In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

import snntorch as snn
from snntorch import spikegen, utils, surrogate
from snntorch.functional import quant

import brevitas.nn as qnn
# from brevitas.nn import QuantConv1d, QuantIdentity, QuantLinear, BatchNorm1dToQuantScaleBias
from brevitas.quant import Int8WeightPerTensorFloat, Int8ActPerTensorFloat, Int32Bias
from brevitas.export import export_qonnx
from brevitas.core.scaling      import ScalingImplType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.quant        import QuantType


from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType
#from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

import optuna

import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import itertools
import csv, copy, glob
import imblearn, imblearn.over_sampling
from collections import Counter, OrderedDict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, precision_recall_curve, auc

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import label_binarize, normalize
from scipy.signal import butter, lfilter, freqz

import os, sys, time, datetime, argparse, json


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42

# Seed each RNG source explicitly so a fresh-kernel run is repeatable.
np.random.seed(seed)              # NumPy global RNG
torch.manual_seed(seed)           # PyTorch default generator
torch.cuda.manual_seed_all(seed)  # every CUDA device (multi-GPU setups)
# NOTE: Python's built-in `random` module is intentionally left unseeded here
# (it is not imported by this notebook's import cell).

In [3]:
# def normalize_row_to_range(row, low=-4.0, high=4.0):
#     rmin, rmax = row.min(), row.max()
#     denom = (rmax - rmin) if rmax > rmin else 1e-8
#     return low + (row - rmin) * (high - low) / denom

In [5]:
def load_train_test_data(
    mitbih_train_path='/...../data/mitbih_processed_intra_patient_4class_180_center90_filtered/train',
    folders=None, random_state=42,
    max_files_per_folder=None
):
    """Load per-class CSV files from a directory tree and return shuffled tensors.

    Each sub-folder of ``mitbih_train_path`` named in ``folders`` is scanned for
    ``*.csv`` / ``*.CSV`` files. Every row of every file becomes one sample whose
    label is taken from the folder name (normal=0, sveb=1, veb=2, f=3).

    Args:
        mitbih_train_path: Root directory containing one sub-folder per class.
            Despite the name, any split (train or test) can be passed.
        folders: Class sub-folders to read. Defaults to all four classes.
        random_state: Seed for the permutation that shuffles the samples.
        max_files_per_folder: If not ``None``, read at most this many files
            (in sorted path order) from each class folder. ``None`` reads all.

    Returns:
        tuple: ``(X, y)`` where ``X`` is a float32 tensor of shape ``(N, 1, L)``
        and ``y`` is an int64 tensor of shape ``(N,)``.

    Raises:
        ValueError: If two files disagree on the number of columns.
        FileNotFoundError: If no rows could be loaded at all.
    """
    if folders is None:
        folders = ['normal', 'sveb', 'veb', 'f']
    label_mapping = {'normal': 0, 'sveb': 1, 'veb': 2, 'f': 3}

    def scan_folder(base_path):
        """Read every class folder under ``base_path`` into (X, y) NumPy arrays."""
        X_parts, y_parts = [], []
        expected_cols = None

        for folder in folders:
            folder_path = os.path.join(base_path, folder)
            # The two patterns can match the same file where glob matching is
            # case-insensitive; de-duplicate so no file is loaded twice.
            csvs = sorted(set(
                glob.glob(os.path.join(folder_path, '*.csv')) +
                glob.glob(os.path.join(folder_path, '*.CSV'))
            ))
            if not csvs:
                print(f'Warning: no CSVs found in: {folder_path}')
                continue

            if max_files_per_folder is not None:
                csvs = csvs[:max(0, int(max_files_per_folder))]

            for fpath in csvs:
                df = pd.read_csv(
                    fpath, dtype=np.float32, engine='c',
                    usecols=lambda c: c != 'Unnamed: 0'
                )
                if df.shape[0] == 0:
                    # A file with a single line is parsed as header-only; re-read
                    # it treating that line as data.
                    df = pd.read_csv(fpath, header=None, dtype=np.float32, engine='c')

                df = df.dropna(axis=1, how='all')

                arr = df.to_numpy(copy=False)
                if expected_cols is None:
                    expected_cols = arr.shape[1]
                elif arr.shape[1] != expected_cols:
                    raise ValueError(
                        f'Inconsistent column count: {fpath} has {arr.shape[1]}, expected {expected_cols}'
                    )

                X_parts.append(arr)
                y_parts.extend([label_mapping[folder]] * arr.shape[0])

        if not X_parts:
            raise FileNotFoundError(f'No usable CSV rows under {base_path} (folders={folders})')

        X = np.vstack(X_parts).astype(np.float32, copy=False)
        y = np.asarray(y_parts, dtype=np.int64)
        return X, y

    # Load and shuffle
    X_mit, y_mit = scan_folder(mitbih_train_path)
    print(f'Loaded: MIT-BIH {X_mit.shape}')

    rng = np.random.RandomState(random_state)
    perm = rng.permutation(len(y_mit))
    X = X_mit[perm]
    y = y_mit[perm]

    # Replace NaNs BEFORE reporting statistics; otherwise a single NaN turns
    # the printed min/max/mean/std into NaN as well.
    X = np.nan_to_num(X, nan=0.0).astype(np.float32, copy=False)

    # No scaling is applied here. Per the original author's note the data is
    # already z-score normalized by the preprocessing step (not verifiable
    # from this function).
    print(f"Data range after z-score normalization: [{X.min():.2f}, {X.max():.2f}]")
    print(f"Data mean: {X.mean():.4f}, std: {X.std():.4f}")

    # Convert to torch tensors
    X_train_tensor = torch.from_numpy(X).unsqueeze(1)   # (N, 1, L)
    y_train = torch.from_numpy(y)                       # (N,)

    print("Class distribution:", Counter(y_train.tolist()))
    print("X train shape:", tuple(X_train_tensor.shape))
    print("y_train shape:", tuple(y_train.shape))
    print("Data tensor loaded successfully")

    return X_train_tensor, y_train

In [6]:
def create_csnn_datasets(X_train, y_train, X_test=None, y_test=None):
    """
    Wrap feature/label tensors in ``TensorDataset`` objects.

    Args:
        X_train (torch.Tensor): Training inputs.
        y_train (torch.Tensor): Training labels.
        X_test (torch.Tensor, optional): Test inputs. When omitted, only the
            training dataset is built.
        y_test (torch.Tensor, optional): Test labels (used only with X_test).

    Returns:
        The training ``TensorDataset`` alone when ``X_test`` is None,
        otherwise the tuple ``(train_dataset, test_dataset)``.
    """
    wrap = torch.utils.data.TensorDataset
    train_dataset = wrap(X_train, y_train)

    if X_test is not None:
        return train_dataset, wrap(X_test, y_test)

    print("dataset created successfully")
    return train_dataset

In [7]:
# X_train_tensor, y_train = load_train_test_data()

In [9]:
def create_qcsnn_model_4class(num_bits=8, input_size=180, stride4=1, kernel_size=3, 
                               dropout4=0.35, beta4=0.5, slope4=25, 
                               threshold4=0.5, learn_beta4=True, learn_threshold4=True):
    """Factory for the 4-class quantized convolutional SNN.

    Architecture: 3 conv blocks (QuantIdentity -> QuantConv1d -> quantized
    BatchNorm -> Leaky -> MaxPool1d(2, 2)) with channels 1->16->16->24,
    followed by Flatten and 2 linear blocks (flattened -> 128 -> 4), each
    ending in a Leaky neuron. The final Leaky is built with ``output=True``.

    Args:
        num_bits: Bit width passed to every quantizer (activations and weights).
        input_size: Length of the 1-D input signal.
        stride4: Stride of every convolution.
        kernel_size: Kernel size of every convolution. It is used both for the
            layers and for the flattened-size arithmetic, so the two always agree.
        dropout4: Accepted for interface compatibility; no layer uses it.
        beta4, learn_beta4: Decay rate of every Leaky neuron and whether it is learnable.
        slope4: Slope of the fast-sigmoid surrogate gradient.
        threshold4, learn_threshold4: Firing threshold and whether it is learnable.

    Returns:
        torch.nn.Sequential: The model, with layer names unchanged
        (``qcsnet4_cblk{1,2,3}_*``, ``qcsnet4_flatten``, ``qcsnet4_lblk{1,2}_*``).

    Raises:
        ValueError: If the input is too short for three conv+pool stages.
    """
    spike_grad4 = snn.surrogate.fast_sigmoid(slope=slope4)

    # (in_channels, out_channels) of the three conv blocks.
    conv_channels = ((1, 16), (16, 16), (16, 24))

    # Track the feature length through each conv (no padding) + MaxPool(2, 2).
    block_sizes = []
    length = input_size
    for _ in conv_channels:
        length = (length - kernel_size) // stride4 + 1
        length = length // 2  # MaxPool
        block_sizes.append(length)
    output_size1, output_size2, output_size3 = block_sizes

    if min(block_sizes) < 1:
        raise ValueError(
            f"input_size={input_size} is too small for kernel_size={kernel_size}, "
            f"stride={stride4}: block output sizes would be {block_sizes}"
        )

    flattened_size = output_size3 * conv_channels[-1][1]  # channels of 3rd conv block

    print(f"Output sizes: Block1={output_size1}, Block2={output_size2}, Block3={output_size3}")
    print(f"Flattened size: {flattened_size}")

    def quant_input():
        """Activation quantizer placed in front of every conv/linear layer."""
        return qnn.QuantIdentity(bit_width=num_bits, return_quant_tensor=True)

    def lif(**extra):
        """Leaky integrate-and-fire neuron sharing the factory's hyperparameters."""
        return snn.Leaky(beta=beta4, learn_beta=learn_beta4,
                         spike_grad=spike_grad4,
                         threshold=threshold4, learn_threshold=learn_threshold4,
                         init_hidden=True, **extra)

    def conv_block(idx, in_ch, out_ch):
        """Named layers of conv block ``idx`` (modules are created in list order)."""
        prefix = f"qcsnet4_cblk{idx}"
        return [
            (f"{prefix}_input", quant_input()),
            (f"{prefix}_qconv1d",
             qnn.QuantConv1d(
                 in_ch, out_ch, kernel_size, stride=stride4, bias=False,
                 weight_bit_width=num_bits,
                 weight_quant=Int8WeightPerTensorFloat,
                 output_quant=Int8ActPerTensorFloat,
                 return_quant_tensor=True)),
            (f"{prefix}_batch_norm",
             qnn.BatchNorm1dToQuantScaleBias(
                 out_ch,
                 weight_bit_width=num_bits,
                 weight_quant=Int8WeightPerTensorFloat,
                 output_quant=Int8ActPerTensorFloat,
                 bias_quant=Int32Bias,
                 return_quant_tensor=True)),
            (f"{prefix}_leaky", lif()),
            (f"{prefix}_max_pool", torch.nn.MaxPool1d(2, 2)),
        ]

    def linear_block(idx, in_features, out_features, **lif_extra):
        """Named layers of linear block ``idx``."""
        prefix = f"qcsnet4_lblk{idx}"
        return [
            (f"{prefix}_input", quant_input()),
            (f"{prefix}_qlinear",
             qnn.QuantLinear(
                 in_features, out_features, bias=False,
                 weight_bit_width=num_bits,
                 weight_quant=Int8WeightPerTensorFloat,
                 output_quant=Int8ActPerTensorFloat,
                 return_quant_tensor=True)),
            (f"{prefix}_leaky", lif(**lif_extra)),
        ]

    layers = []
    for idx, (in_ch, out_ch) in enumerate(conv_channels, start=1):
        layers.extend(conv_block(idx, in_ch, out_ch))

    # ── Dense head (2 linear blocks) ──────────────────────────────
    layers.append(("qcsnet4_flatten", torch.nn.Flatten()))
    layers.extend(linear_block(1, flattened_size, 128))        # hidden layer, 128 units
    layers.extend(linear_block(2, 128, 4, output=True))        # output layer, 4 classes

    model = torch.nn.Sequential(OrderedDict(layers))

    # Kept from the original ("CRITICAL FIX"): override runtime_shape on all
    # three quantized BatchNorm layers. Presumably this makes the per-channel
    # scale/bias broadcast over (N, C, L) inputs — confirm against Brevitas.
    for idx in range(1, len(conv_channels) + 1):
        getattr(model, f"qcsnet4_cblk{idx}_batch_norm").runtime_shape = (1, -1, 1)

    return model


def forward_pass(model, num_steps, data):
    """Present ``data`` to the SNN ``num_steps`` times and record every step.

    The hidden state of the network's LIF neurons is reset once up front, then
    the same input is fed at each step.

    Returns:
        tuple: ``(spikes, membranes)``, each stacked along a new leading
        time dimension of length ``num_steps``.
    """
    utils.reset(model)  # clear hidden state left over from the previous batch

    spikes, membranes = [], []
    for _ in range(num_steps):
        spk_t, mem_t = model(data)
        spikes.append(spk_t)
        membranes.append(mem_t)

    return torch.stack(spikes), torch.stack(membranes)


In [12]:
def train_epoch(model4, dataloader, loss_func, optimizer4, device, num_steps):
    """Train ``model4`` for one pass over ``dataloader``.

    The loss is evaluated on the output of every timestep and averaged
    (the "TET" scheme in the original notes); accuracy uses the argmax of the
    timestep-averaged output.

    Returns:
        tuple: ``(loss_sum, n_correct)`` where ``loss_sum`` is the sum of
        per-batch mean losses weighted by batch size and ``n_correct`` is the
        number of correctly classified samples. Both are un-normalised; the
        caller divides by the dataset size.
    """
    model4.train()
    loss_sum = 0.0
    n_correct = 0

    for batch_x, batch_y in dataloader:
        optimizer4.zero_grad()
        batch_x = batch_x.to(device, non_blocking=True)
        labels = batch_y.to(device, non_blocking=True).long()

        # step_outputs: [num_steps, batch, 4]
        step_outputs, _ = forward_pass(model4, num_steps, batch_x)

        # Per-timestep loss, then mean over timesteps.
        step_losses = [loss_func(step_outputs[t], labels) for t in range(num_steps)]
        batch_loss = torch.stack(step_losses).mean()

        batch_loss.backward()
        optimizer4.step()

        predictions = step_outputs.mean(0).argmax(1)
        loss_sum += batch_loss.item() * labels.size(0)
        n_correct += (predictions == labels).sum().item()

    return loss_sum, n_correct

In [13]:
@torch.no_grad()
def validation_epoch(model4, dataloader, loss_func, device, num_steps):
    """
    Evaluate a 4-class model on ``dataloader`` without gradient tracking.

    Returns a dict with keys, in this order: 'loss' (per-sample average),
    'acc' (percent), 'precision', 'recall', 'specificity', 'f1-score'
    (each a 4-element per-class list), then 'precision_macro',
    'recall_macro', 'specificity_macro', 'f1_macro'.
    """
    model4.eval()

    class_ids = (0, 1, 2, 3)
    eps = 1e-8  # keeps the ratios finite when a denominator is zero

    # One-vs-rest confusion counts, one slot per class.
    tp4 = [0 for _ in class_ids]
    fp4 = [0 for _ in class_ids]
    tn4 = [0 for _ in class_ids]
    fn4 = [0 for _ in class_ids]

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device, non_blocking=True)
        labels = batch_y.to(device, non_blocking=True).long()

        # step_outputs: [num_steps, batch, 4]
        step_outputs, _ = forward_pass(model4, num_steps, batch_x)

        # Loss at each timestep, averaged over timesteps.
        step_losses = [loss_func(step_outputs[t], labels) for t in range(num_steps)]
        batch_loss = torch.stack(step_losses).mean()

        # Prediction from the timestep-averaged output.
        predicted = step_outputs.mean(0).argmax(1)

        n_batch = labels.size(0)
        loss_sum += batch_loss.item() * n_batch
        n_correct += (predicted == labels).sum().item()
        n_seen += n_batch

        for c in class_ids:
            is_pred = predicted == c
            is_true = labels == c
            tp4[c] += (is_pred & is_true).sum().item()
            fp4[c] += (is_pred & (labels != c)).sum().item()
            tn4[c] += ((predicted != c) & (labels != c)).sum().item()
            fn4[c] += ((predicted != c) & is_true).sum().item()

    # Per-class ratios.
    precision = [tp4[c] / (tp4[c] + fp4[c] + eps) for c in class_ids]
    recall = [tp4[c] / (tp4[c] + fn4[c] + eps) for c in class_ids]
    specificity = [tn4[c] / (tn4[c] + fp4[c] + eps) for c in class_ids]
    f1 = [2*p*r / (p + r + eps) for p, r in zip(precision, recall)]

    denom = max(n_seen, 1)
    metrics4 = {
        'loss': loss_sum / denom,
        'acc': 100.0 * n_correct / denom,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1-score': f1,
    }
    # Unweighted (macro) averages over the four classes.
    metrics4.update({
        'precision_macro': sum(precision)/len(precision),
        'recall_macro':    sum(recall)/len(recall),
        'specificity_macro': sum(specificity)/len(specificity),
        'f1_macro':        sum(f1)/len(f1),
    })

    return metrics4

In [14]:
import gc
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import StratifiedKFold


def _clear_cuda_cache():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# NOTE: keep your existing Brevitas warmup if you already have one.
# This stub assumes forward_pass(model, num_steps, x) exists.
@torch.no_grad()
def _brevitas_warmup(model, dataset, device, num_steps=10, bs=8, n_steps=10):
    loader = DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=0,
                        pin_memory=(device.type == "cuda"))
    it = iter(loader)
    for _ in range(n_steps):
        try:
            x, _ = next(it)
        except StopIteration:
            it = iter(loader)
            x, _ = next(it)
        x = x.to(device, non_blocking=True)
        out, _ = forward_pass(model, num_steps, x)  # (T,B,2) expected
        _ = out.mean(0)  # just force execution


In [15]:
def train_model_cv_4class(model_factory, epochs, dataset, device,
                          loss_func, optimizer_class, optimizer_kwargs,
                          num_steps=10, k_folds=6, batch_size=128,
                          monitor="f1_macro", mode="max", patience=5):
    """
    Stratified K-fold CV with a fresh model per fold for 4-class classification.
    Classes: 0=Normal, 1=SVEB, 2=VEB, 3=F

    A template model is built once and warmed up with ``_brevitas_warmup``; its
    state_dict is copied to CPU and loaded into every fold's freshly built
    model, so all folds start from the same state.

    Args:
        model_factory: Zero-argument callable returning a new model.
        epochs: Maximum number of epochs per fold.
        dataset: Full dataset; split with StratifiedKFold (shuffle, seed 42).
        device: ``torch.device`` or device string.
        loss_func: IGNORED. Kept for interface compatibility; each fold builds
            its own class-weighted ``CrossEntropyLoss`` from the fold's
            training labels.
        optimizer_class: Optimizer constructor, called with the model's
            parameters and ``**optimizer_kwargs``.
        optimizer_kwargs: Keyword arguments for ``optimizer_class``.
        num_steps: SNN simulation timesteps per forward pass.
        k_folds: Number of CV folds.
        batch_size: Batch size for train and validation loaders.
        monitor: Validation metric used for checkpointing/early stopping.
            Either a key of the dict returned by ``validation_epoch`` (e.g.
            "f1_macro", "loss") or a per-class alias "<metric>_<class>" with
            metric in {recall, precision, f1, specificity} and class in
            {normal, sveb, veb, f}.
        mode: "max" if a larger ``monitor`` is better; anything else is
            treated as "smaller is better".
        patience: Stop a fold after this many consecutive epochs without
            improvement. ``None`` disables early stopping. Defaults to 5,
            the value previously hard-coded.

    Returns:
        tuple: ``(fold_history, fold_ckpts)``; both are dicts keyed by
        "fold1", "fold2", ... ``fold_history`` holds per-epoch metric lists,
        ``fold_ckpts`` holds best epoch/score, the monitor name and the best
        state_dict (on CPU; ``None`` if no epoch ever improved).
    """

    if isinstance(device, str):
        device = torch.device(device)

    # ---- Brevitas warmup template ----
    print("Creating template model for warmup...")
    model_template = model_factory().to(device)
    model_template.train()

    print("Running Brevitas warmup...")
    _brevitas_warmup(model_template, dataset, device, num_steps=num_steps, bs=8, n_steps=10)

    with torch.no_grad():
        warmup_state = {k: v.detach().cpu().clone()
                        for k, v in model_template.state_dict().items()}

    scaling_keys = sum('scaling_impl.value' in k for k in warmup_state.keys())
    print(f"Scaling keys after warmup: {scaling_keys}")
    if scaling_keys == 0:
        print("WARNING: No scaling_impl.value keys found!")

    del model_template
    _clear_cuda_cache()

    # ---- Label helpers ----
    def all_labels(ds):
        """Labels of every sample; fast path for TensorDataset-like objects."""
        if hasattr(ds, 'tensors') and len(ds.tensors) >= 2:
            return ds.tensors[1].detach().cpu().long().numpy()
        return np.asarray([ds[i][1] for i in range(len(ds))], dtype=np.int64)

    def subset_labels(sub: Subset):
        """Labels of the samples selected by a Subset."""
        base, idx = sub.dataset, sub.indices
        if hasattr(base, 'tensors') and len(base.tensors) >= 2:
            return base.tensors[1][idx].detach().cpu().long().numpy()
        return np.asarray([base[i][1] for i in idx], dtype=np.int64)

    # ---- Monitor score helper ----
    # "<alias>_<class>" -> (metrics key, class index), e.g. "f1_veb" -> ("f1-score", 2)
    class_names = ("normal", "sveb", "veb", "f")
    metric_aliases = {
        "recall": "recall",
        "precision": "precision",
        "f1": "f1-score",
        "specificity": "specificity",
    }
    per_class_map = {
        f"{alias}_{cls_name}": (metric_key, cls_idx)
        for cls_idx, cls_name in enumerate(class_names)
        for alias, metric_key in metric_aliases.items()
    }

    def extract_score(metrics, monitor_key):
        """Scalar to monitor: per-class entry if aliased, else a top-level key."""
        if monitor_key in per_class_map:
            k, i = per_class_map[monitor_key]
            return metrics[k][i]
        return metrics[monitor_key]

    # ---- CV setup ----
    y_all = all_labels(dataset)
    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)

    fold_history = {}
    fold_ckpts = {}

    better = (lambda a, b: a > b) if mode == "max" else (lambda a, b: a < b)

    for fold, (train_idx, valid_idx) in enumerate(skf.split(np.arange(len(y_all)), y_all), 1):
        print(f"\n{'='*60}")
        print(f"Fold {fold}/{k_folds}")
        print(f"{'='*60}")

        # FRESH model instance per fold, initialised from the warmed-up state
        model_fold = model_factory()                 # CPU
        model_fold.load_state_dict(warmup_state)     # CPU load
        model_fold.to(device)
        model_fold.train()

        # Fresh optimizer
        optimizer_fold = optimizer_class(model_fold.parameters(), **optimizer_kwargs)

        # Subsets
        train_subset = Subset(dataset, train_idx)
        valid_subset = Subset(dataset, valid_idx)

        # Per-fold inverse-frequency class weights: total / (4 * count),
        # with count clamped to 1 so an absent class cannot divide by zero.
        y_train = subset_labels(train_subset)
        total = len(y_train)
        counts = [int((y_train == c).sum()) for c in range(4)]
        weights = [total / (4 * max(n, 1)) for n in counts]
        n0, n1, n2, n3 = counts
        w0, w1, w2, w3 = weights

        print(f"Train counts: Normal={n0}, SVEB={n1}, VEB={n2}, F={n3}")
        print(f"Class weights: w0={w0:.3f}, w1={w1:.3f}, w2={w2:.3f}, w3={w3:.3f}")

        class_w = torch.tensor(weights, dtype=torch.float32, device=device)
        loss_func_fold = torch.nn.CrossEntropyLoss(weight=class_w)

        # Loaders
        train_loader = DataLoader(
            train_subset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=(device.type == 'cuda'),
            num_workers=0
        )
        valid_loader = DataLoader(
            valid_subset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=(device.type == 'cuda'),
            num_workers=0
        )

        epoch_history = {
            'train4_loss': [], 'train4_acc': [],
            'valid4_loss': [], 'valid4_acc': [],
            'valid4_prec': [], 'valid4_rec': [], 'valid4_spec': [], 'valid4_f1': [],
            'valid4_prec_macro': [], 'valid4_rec_macro': [],
            'valid4_spec_macro': [], 'valid4_f1_macro': []
        }

        best_score = -float("inf") if mode == "max" else float("inf")
        best_epoch = None
        best_state_cpu = None

        no_improve = 0  # consecutive epochs without improvement

        for epoch in range(1, epochs + 1):
            tr4_loss_sum, tr4_corr = train_epoch(
                model_fold, train_loader, loss_func_fold, optimizer_fold, device, num_steps
            )

            v_metrics4 = validation_epoch(
                model_fold, valid_loader, loss_func_fold, device, num_steps
            )

            # train_epoch returns sums; normalise by the fold's training size.
            n_train = len(train_subset)
            tr4_loss = tr4_loss_sum / max(n_train, 1)
            tr4_acc  = 100.0 * tr4_corr / max(n_train, 1)

            print(
                f"Epoch {epoch:2d}/{epochs} | "
                f"Train  L={tr4_loss:.4f}  A={tr4_acc:5.2f}% | "
                f"Valid  L={v_metrics4['loss']:.4f}  A={v_metrics4['acc']:5.2f}% | "
                f"Rec: N={v_metrics4['recall'][0]:.3f} "
                f"S={v_metrics4['recall'][1]:.3f} "
                f"V={v_metrics4['recall'][2]:.3f} "
                f"F={v_metrics4['recall'][3]:.3f} | "
                f"F1(M)={v_metrics4['f1_macro']:.4f}"
            )

            epoch_history['train4_loss'].append(tr4_loss)
            epoch_history['train4_acc'].append(tr4_acc)
            epoch_history['valid4_loss'].append(v_metrics4['loss'])
            epoch_history['valid4_acc'].append(v_metrics4['acc'])

            epoch_history['valid4_prec'].append(v_metrics4['precision'])
            epoch_history['valid4_rec'].append(v_metrics4['recall'])
            epoch_history['valid4_spec'].append(v_metrics4['specificity'])
            epoch_history['valid4_f1'].append(v_metrics4['f1-score'])

            epoch_history['valid4_prec_macro'].append(v_metrics4['precision_macro'])
            epoch_history['valid4_rec_macro'].append(v_metrics4['recall_macro'])
            epoch_history['valid4_spec_macro'].append(v_metrics4['specificity_macro'])
            epoch_history['valid4_f1_macro'].append(v_metrics4['f1_macro'])

            score = extract_score(v_metrics4, monitor)

            if better(score, best_score):
                best_score = score
                best_epoch = epoch
                no_improve = 0
                # Snapshot on CPU so the checkpoint does not hold GPU memory.
                with torch.no_grad():
                    best_state_cpu = {k: v.detach().cpu().clone()
                                      for k, v in model_fold.state_dict().items()}
            else:
                no_improve += 1
                if patience is not None and no_improve >= patience:
                    print(f"Early stopping: no improvement in {patience} epochs.")
                    break

        fold_key = f"fold{fold}"
        fold_history[fold_key] = epoch_history
        fold_ckpts[fold_key] = {
            "best_epoch": best_epoch,
            "best_score": float(best_score),
            "monitor": monitor,
            "state_dict_cpu": best_state_cpu
        }

        print(f"\n[Fold {fold}] Best {monitor}={best_score:.4f} at epoch {best_epoch}\n")

        # Release per-fold objects before the next fold allocates its own.
        del optimizer_fold, model_fold
        del train_loader, valid_loader
        _clear_cuda_cache()

    return fold_history, fold_ckpts

In [16]:
# Run configuration for this cell.
# NOTE(review): `batch_size`, `epochs`, `num_steps` and `k_folds` defined here
# are NOT passed to train_model_cv_4class below; the call uses its own
# literals (epochs=80, batch_size=256, num_steps=10, k_folds=6). The values of
# `epochs` and `batch_size` here therefore differ from what the run uses.
# Confirm which values are intended before relying on these variables.
batch_size=128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_steps = 10
epochs = 15
k_folds=6

# Step 1: load the training tensors (default path inside load_train_test_data)
train_data, train_targets = load_train_test_data()

# Step 2: wrap the tensors in a TensorDataset
dataset = create_csnn_datasets(train_data, train_targets) 

print(f"Using device: {device}")

# No model, optimizer or loss is created here: train_model_cv_4class builds a
# fresh model and optimizer for every fold and constructs its own
# class-weighted CrossEntropyLoss, which is why loss_func=None is passed.

# Run CV - this creates fresh models per fold
fold_history, fold_ckpts = train_model_cv_4class(
    model_factory=create_qcsnn_model_4class,
    epochs=80,
    dataset=dataset,
    device=device,
    loss_func=None,  # Created per-fold
    optimizer_class=torch.optim.Adam,
    optimizer_kwargs={'lr': 0.0001},
    num_steps=10,
    k_folds=6,
    batch_size=256,
    monitor="f1_macro",  # Or "recall_veb" for VEB-specific
    mode="max"
)

Loaded: MIT-BIH (80557, 180)
Data range after z-score normalization: [-4.87, 6.05]
Data mean: 0.0000, std: 1.0000
Class distribution: Counter({0: 72073, 2: 5616, 1: 2224, 3: 644})
X train shape: (80557, 1, 180)
y_train shape: (80557,)
Data tensor loaded successfully
dataset created successfully
Using device: cuda
Creating template model for warmup...
Output sizes: Block1=89, Block2=43, Block3=20
Flattened size: 480
Running Brevitas warmup...


  return super().rename(names)


Scaling keys after warmup: 13

Fold 1/6
Output sizes: Block1=89, Block2=43, Block3=20
Flattened size: 480
Train counts: Normal=60060, SVEB=1854, VEB=4680, F=536
Class weights: w0=0.279, w1=9.052, w2=3.586, w3=31.311
Epoch  1/80 | Train  L=1.3863  A=89.47% | Valid  L=1.3863  A=89.47% | Rec: N=1.000 S=0.000 V=0.000 F=0.000 | F1(M)=0.2361
Epoch  2/80 | Train  L=1.3863  A=89.47% | Valid  L=1.3863  A=89.47% | Rec: N=1.000 S=0.000 V=0.000 F=0.000 | F1(M)=0.2361
Epoch  3/80 | Train  L=1.3863  A=89.47% | Valid  L=1.3863  A=89.47% | Rec: N=1.000 S=0.000 V=0.000 F=0.000 | F1(M)=0.2361
Epoch  4/80 | Train  L=1.3413  A=83.88% | Valid  L=1.2647  A=82.47% | Rec: N=0.830 S=0.608 V=0.851 F=0.722 | F1(M)=0.5360
Epoch  5/80 | Train  L=1.1674  A=82.22% | Valid  L=1.0515  A=81.49% | Rec: N=0.817 S=0.646 V=0.862 F=0.704 | F1(M)=0.5226
Epoch  6/80 | Train  L=0.9650  A=81.54% | Valid  L=0.9555  A=82.44% | Rec: N=0.827 S=0.651 V=0.862 F=0.750 | F1(M)=0.5411
Epoch  7/80 | Train  L=0.9423  A=82.82% | Valid  L=0

In [17]:
# Load the held-out test split with the same loader used for training data.
# Note: load_train_test_data shuffles the rows and its log lines are labelled
# "X train shape" / "y_train shape" even when it is given the test path.
mitbih_test_path='/.../data/mitbih_processed_intra_patient_4class_180_center90_filtered/test'

test_data, test_targets = load_train_test_data(mitbih_test_path)
test_dataset = create_csnn_datasets(test_data, test_targets)

# Per-class sample counts (labels: 0=normal, 1=sveb, 2=veb, 3=f)
print(f"Test set size: {len(test_data)}")
print(f"Test normal samples: {(test_targets == 0).sum()}")
print(f"Test sveb samples: {(test_targets == 1).sum()}")
print(f"Test veb samples: {(test_targets == 2).sum()}")
print(f"Test f samples: {(test_targets == 3).sum()}")

Loaded: MIT-BIH (20161, 180)
Data range after z-score normalization: [-5.26, 6.12]
Data mean: 0.0000, std: 1.0000
Class distribution: Counter({0: 18052, 2: 1393, 1: 557, 3: 159})
X train shape: (20161, 1, 180)
y_train shape: (20161,)
Data tensor loaded successfully
dataset created successfully
Test set size: 20161
Test normal samples: 18052
Test sveb samples: 557
Test veb samples: 1393
Test f samples: 159


In [53]:


import os
import math
import numpy as np
import torch

import brevitas.nn as qnn
import snntorch as snn

# -----------------------
# Utilities
# -----------------------

def _ensure_dir(d):
    os.makedirs(d, exist_ok=True)

def _to_one(x):
    if isinstance(x, torch.Tensor):
        if x.numel() == 1:
            return float(x.detach().cpu().item())
        return x.detach().cpu().numpy()
    if isinstance(x, (float, int)):
        return x
    return x

def _get_bit_scale_zp_from_quant(q):
    """Return (bit_width, scale, zero_point, signed) from a brevitas quantizer-like object."""
    bit_width = None
    scale = None
    zero_point = None
    signed = True
    # bit width
    for k in ('bit_width', 'bit_width_impl', 'bit_width_f'):
        v = getattr(q, k, None)
        if v is None: continue
        try:
            v = v() if callable(v) else v
            bit_width = int(_to_one(v))
            break
        except Exception:
            pass
    # scale
    for k in ('scale', 'tensor_scale', 'scale_impl', 'act_scale', 'weight_scale'):
        v = getattr(q, k, None)
        if v is None: continue
        try:
            v = v() if callable(v) else v
            v = _to_one(v)
            if isinstance(v, (float, int)):
                scale = float(v)
                break
            if isinstance(v, np.ndarray) and v.size == 1:
                scale = float(v.item())
                break
        except Exception:
            pass
    # zero-point
    for k in ('zero_point', 'zero_point_impl', 'zp'):
        v = getattr(q, k, None)
        if v is None: continue
        try:
            v = v() if callable(v) else v
            zero_point = int(round(_to_one(v)))
            break
        except Exception:
            pass
    # signed
    sattr = getattr(q, 'signed', None)
    if isinstance(sattr, bool):
        signed = sattr
    elif zero_point is None:
        signed = True
    # defaults
    if bit_width is None: bit_width = 8
    if scale is None:     scale = 1.0
    if zero_point is None: zero_point = 0
    return bit_width, float(scale), int(zero_point), bool(signed)

def _quantize_multiplier(real_multiplier: float):
    """TFLite-style integer multiplier/shift approximation for a positive real multiplier."""
    if real_multiplier <= 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(real_multiplier)  # real = mantissa * 2^exponent, mantissa in [0.5,1)
    q = int(round(mantissa * (1 << 31)))
    if q == (1 << 31):
        q //= 2
        exponent += 1
    shift = 31 - exponent
    if shift < 0:
        q <<= (-shift)
        shift = 0
    return int(q), int(shift)

# Integer helpers: per-channel weight sums and float -> integer bias/weight quantization
def _sum_weights_per_out(W_int8: np.ndarray) -> np.ndarray:
    # W_int8 shape: [OUT_CH, IN_CH, K]
    # returns int32 sums per OUT_CH
    return W_int8.reshape(W_int8.shape[0], -1).sum(axis=1).astype(np.int32)

def _bias_int32_vector(b_f: np.ndarray, s_in: float, s_w: float, out_ch: int) -> np.ndarray:
    # Quantize bias or return zeros if no bias provided
    if b_f is None:
        return np.zeros(out_ch, dtype=np.int32)
    s_bias = s_in * s_w if (s_in and s_w) else 1.0
    return np.round(b_f / s_bias).astype(np.int32)


def _qt_weight_int8_per_tensor(W_f: np.ndarray, scale: float, zero_point: int = 0):
    Wq = np.round(W_f / scale) + zero_point
    return np.clip(Wq, -128, 127).astype(np.int8)

def _bias_int32_from_float(b_f: np.ndarray, s_in: float, s_w: float):
    s_bias = s_in * s_w
    if s_bias == 0.0: s_bias = 1.0
    bq = np.round(b_f / s_bias).astype(np.int32)
    return bq

def _as1(x):
    if isinstance(x, (list, tuple)):
        return int(x[0])
    return int(x)

def _guard_out_scale(name: str, out_scale: float):
    if out_scale is None or not np.isfinite(out_scale) or out_scale <= 0.0:
        raise ValueError(f"{name}: invalid out_scale={out_scale}")

def _guard_in_scale(name: str, in_scale: float):
    if in_scale is None or not np.isfinite(in_scale) or in_scale <= 0.0:
        raise ValueError(f"{name}: invalid in_scale={in_scale}")

def _guard_weight_scale(name: str, w_scale: float):
    if w_scale is None or not np.isfinite(w_scale) or w_scale <= 0.0:
        raise ValueError(f"{name}: invalid weight_scale={w_scale}")

def _id_guard_macro(base: str):
    return base.upper().replace('/', '_').replace('.', '_')

def _sym(name: str):
    """C identifier from layer name (keep as-is but safe for C)."""
    return name.replace('/', '_').replace('.', '_')

def _fmt_int_list(vals, per_line=16):
    out = []
    line = []
    for i, v in enumerate(vals):
        line.append(str(int(v)))
        if (i + 1) % per_line == 0:
            out.append(", ".join(line))
            line = []
    if line:
        out.append(", ".join(line))
    return ",\n    ".join(out)

def _fmt_array_2d(arr2d):
    """Format a 2-D sequence as a brace initializer: one ``{ ... }`` row per inner sequence."""
    body = ",\n  ".join("{ " + _fmt_int_list(row) + " }" for row in arr2d)
    return "{\n  " + body + "\n}"

def _fmt_array_3d(arr3d):
    """Format a 3-D sequence as a brace initializer made of 2-D brace blocks."""
    planes = ",\n".join(_fmt_array_2d(plane) for plane in arr3d)
    return "{\n" + planes + "\n}"

# -----------------------
# Emitters
# -----------------------

def _emit_header_open(fp, guard, ns="hls4csnn1d_cblk_sd"):
    fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
    fp.write("#include <ap_int.h>\n\n")
    fp.write(f"namespace {ns} {{\n\n")

def _emit_header_close(fp, guard, ns="hls4csnn1d_cblk_sd"):
    fp.write(f"}} // namespace\n#endif // {guard}\n")

def _emit_conv1d_header(path, lname, W_int8, rq_mult, rq_shift,
                        bias_int32_vec, input_zp, weight_sum_vec, m):
    """Write the C++ header holding the integer parameters of one Conv1d layer.

    Args:
        path: destination file; it is overwritten.
        lname: layer name, used (after ``_sym``) as prefix of every emitted symbol.
        W_int8: weight array unpacked as ``(OUT_CH, IN_CH, K)``.
        rq_mult, rq_shift: per-tensor requantization pair, repeated once per output channel.
        bias_int32_vec: one bias value per output channel.
        input_zp: input zero-point, emitted as ``ap_int<8>``.
        weight_sum_vec: one weight sum per output channel.
        m: the layer module; only ``m.stride`` is read.

    The header now declares ``using acc32_t = ap_int<32>;`` itself, exactly as
    the linear-layer header does. Previously the bias and weight-sum arrays
    used ``acc32_t`` although this header neither defined the alias nor
    included anything that does, so it only compiled when another header had
    introduced the alias first.
    """
    # Guard: QCSNET2_CBLK1_QCONV1D_WEIGHTS_H style
    guard = _id_guard_macro(f"{_sym(lname)}_WEIGHTS_H")
    with open(path, "w") as fp:
        _emit_header_open(fp, guard)  # writes includes + namespace line

        sym   = _sym(lname)
        OC, IC, K = W_int8.shape
        stride = _as1(m.stride)

        # Structural constants (optional but handy for TBs)
        fp.write(f"const int {sym}_OUT_CH = {OC};\n")
        fp.write(f"const int {sym}_IN_CH  = {IC};\n")
        fp.write(f"const int {sym}_KERNEL_SIZE = {K};\n")
        fp.write(f"const int {sym}_STRIDE = {stride};\n\n")

        # Input ZP (INT8)
        fp.write(f"const ap_int<8> {sym}_input_zero_point = {int(input_zp)};\n\n")

        # Requantization arrays (duplicated per OUT_CH to match your API)
        fp.write(f"const ap_int<32> {sym}_scale_multiplier[{OC}] = {{\n  ")
        fp.write(_fmt_int_list([rq_mult]*OC))
        fp.write("\n};\n\n")

        fp.write(f"const int {sym}_right_shift[{OC}] = {{\n  ")
        fp.write(_fmt_int_list([rq_shift]*OC))
        fp.write("\n};\n\n")

        # Accumulator alias used by the two arrays below (same alias the linear header emits)
        fp.write("using acc32_t = ap_int<32>;\n")

        # Bias (INT32)
        fp.write(f"const acc32_t {sym}_bias[{OC}] = {{\n  ")
        fp.write(_fmt_int_list(bias_int32_vec))
        fp.write("\n};\n\n")

        # Weight sums for asymmetric correction (INT32)
        fp.write(f"const acc32_t {sym}_weight_sum[{OC}] = {{\n  ")
        fp.write(_fmt_int_list(weight_sum_vec))
        fp.write("\n};\n\n")

        # Weights (INT8): [OUT_CH][IN_CH][K]
        fp.write(f"const ap_int<8> {sym}_weights[{OC}][{IC}][{K}] = ")
        fp.write(_fmt_array_3d(W_int8))
        fp.write(";\n\n")

        _emit_header_close(fp, guard)  # closes namespace + guard


def _emit_linear_header(path, lname, W_int8, rq_mult, rq_shift,
                        bias_int32_vec, input_zp, weight_sum_vec):
    """Write the C++ header holding the integer parameters of one linear layer.

    Args:
        path: destination file; it is overwritten.
        lname: layer name, used (after ``_sym``) as prefix of every emitted symbol.
        W_int8: weight array unpacked as ``(OUT, IN)``.
        rq_mult, rq_shift: per-tensor requantization pair, repeated once per output.
        bias_int32_vec: one bias value per output.
        input_zp: input zero-point, emitted as ``ap_int<8>``.
        weight_sum_vec: one weight sum per output.
    """
    guard = _id_guard_macro(f"{_sym(lname)}_WEIGHTS_H")
    with open(path, "w") as fp:
        # Prologue: guard, includes, namespace
        fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
        fp.write("#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n")
        fp.write("namespace hls4csnn1d_cblk_sd {\n\n")

        sym = _sym(lname)
        OUT, IN = W_int8.shape

        def emit_vector(ctype, field, values):
            # One `const <ctype> <sym>_<field>[OUT] = { ... };` definition
            fp.write(f"const {ctype} {sym}_{field}[{OUT}] = {{\n  ")
            fp.write(_fmt_int_list(values))
            fp.write("\n};\n\n")

        # Layer dimensions
        fp.write(f"const int {sym}_OUTPUT_SIZE = {OUT};\n")
        fp.write(f"const int {sym}_INPUT_SIZE  = {IN};\n\n")

        # Input zero-point
        fp.write(f"const ap_int<8> {sym}_input_zero_point = {int(input_zp)};\n\n")

        # Requantization constants, one copy per output
        emit_vector("ap_int<32>", "scale_multiplier", [rq_mult] * OUT)
        emit_vector("int", "right_shift", [rq_shift] * OUT)

        # Accumulator-domain vectors
        fp.write(f"using acc32_t = ap_int<32>;\n")
        emit_vector("acc32_t", "bias", bias_int32_vec)
        emit_vector("acc32_t", "weight_sum", weight_sum_vec)

        # Weight matrix [OUT][IN]
        fp.write(f"const ap_int8_c {sym}_weights[{OUT}][{IN}] = ")
        fp.write(_fmt_array_2d(W_int8))
        fp.write(";\n\n")

        # Epilogue
        fp.write("} // namespace hls4csnn1d_cblk_sd\n")
        fp.write(f"#endif // {guard}\n")



def _emit_bn_header(path, lname, w_q, b32, mult_arr, shift_arr):
    """Write the C++ header for one batch-norm (scale/bias) layer.

    Emits the channel count ``len(w_q)`` followed by four arrays of that
    declared length: weight, bias, scale multiplier and right shift.
    """
    guard = _id_guard_macro(f"{_sym(lname)}_BN_H")
    with open(path, "w") as fp:
        # Prologue: guard, includes, namespace
        fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
        fp.write("#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n")
        fp.write("namespace hls4csnn1d_cblk_sd {\n\n")

        sym = _sym(lname)
        C = len(w_q)
        fp.write(f"const int {sym}_C = {C};\n\n")

        # (C type, symbol suffix, values) for every emitted array, in output order
        arrays = (
            ("ap_int8_c", "weight", w_q),
            ("ap_int<32>", "bias", b32),
            ("ap_int<32>", "scale_multiplier", mult_arr),
            ("int", "right_shift", shift_arr),
        )
        for ctype, field, values in arrays:
            fp.write(f"const {ctype} {sym}_{field}[{C}] = {{\n  {_fmt_int_list(values)}\n}};\n\n")

        # Epilogue
        fp.write("} // namespace hls4csnn1d_cblk_sd\n")
        fp.write(f"#endif // {guard}\n")



def _emit_lif_header_scalar_sd(path, lname, beta_q, theta_q, scale_q, frac_bits):
    """Write the C++ header for a LIF layer with a single beta/threshold pair.

    Emits ``<sym>_FRAC_BITS`` as an enum constant plus three ``ap_int<16>``
    scalars: ``_beta_int``, ``_theta_int`` and ``_scale_int``.
    """
    guard = _id_guard_macro(f"{_sym(lname)}_LIF_H")
    with open(path, "w") as fp:
        # Prologue: guard, includes, namespace
        fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
        fp.write("#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n")
        fp.write("namespace hls4csnn1d_cblk_sd {\n\n")

        sym = _sym(lname)
        fp.write(f"enum {{ {sym}_FRAC_BITS = {int(frac_bits)} }};\n")

        # (padded symbol suffix, value, line terminator); padding aligns the '=' signs
        scalars = (
            ("beta_int   ", beta_q, "\n"),
            ("theta_int  ", theta_q, "\n"),
            ("scale_int  ", scale_q, "\n\n"),
        )
        for label, value, tail in scalars:
            fp.write(f"const ap_int<16> {sym}_{label}= {int(value)};{tail}")

        # Epilogue
        fp.write("} // namespace hls4csnn1d_cblk_sd\n")
        fp.write(f"#endif // {guard}\n")


def _emit_lif_header_vector_sd(path, lname, beta_arr_q, theta_arr_q, scale_q, frac_bits):
    """Write the C++ header for a LIF layer with per-channel beta/threshold arrays.

    Emits ``<sym>_FRAC_BITS`` and ``<sym>_OUT_CH`` as enum constants, the
    scalar ``<sym>_scale_int`` and two ``ap_int<16>`` arrays of length
    ``<sym>_OUT_CH``.

    Raises:
        ValueError: if the two arrays differ in length (checked before the
            file is opened).
    """
    guard = _id_guard_macro(f"{_sym(lname)}_LIF_H")
    sym = _sym(lname)
    N = len(beta_arr_q)
    if N != len(theta_arr_q):
        raise ValueError("beta/theta array lengths must match")

    def _braced(values, per_line=16):
        # Brace initializer with `per_line` ints per indented row
        lines = [
            "    " + ", ".join(str(int(v)) for v in values[start:start + per_line])
            for start in range(0, len(values), per_line)
        ]
        return "{\n" + ",\n".join(lines) + "\n}"

    with open(path, "w") as fp:
        # Prologue: guard, includes, namespace
        fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
        fp.write("#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n")
        fp.write("namespace hls4csnn1d_cblk_sd {\n\n")

        fp.write(f"enum {{ {sym}_FRAC_BITS = {int(frac_bits)}, {sym}_OUT_CH = {int(N)} }};\n")
        fp.write(f"const ap_int<16> {sym}_scale_int = {int(scale_q)};\n\n")
        for field, values in (("beta_int", beta_arr_q), ("theta_int", theta_arr_q)):
            fp.write(f"const ap_int<16> {sym}_{field}[{sym}_OUT_CH] = " + _braced(values) + ";\n\n")

        # Epilogue
        fp.write("} // namespace hls4csnn1d_cblk_sd\n")
        fp.write(f"#endif // {guard}\n")



INT16_MIN, INT16_MAX = -32768, 32767

def _to_q_i16(x: float, Q: int) -> int:
    return int(np.clip(round(float(x) * Q), INT16_MIN, INT16_MAX))

def _tensor_to_q_i16_list(t: torch.Tensor, Q: int):
    """Flatten ``t`` and return a list with every element converted by ``_to_q_i16(value, Q)``."""
    values = t.detach().float().reshape(-1).cpu().tolist()
    quantized = []
    for value in values:
        quantized.append(_to_q_i16(value, Q))
    return quantized


_Q_SCALE = 1 << 12   # fixed-point factor applied to activation scales (12 fractional bits)

def _emit_qparams_header(path, lname, bit_w, scale, zp):
    """Write the optional C++ header with the activation quantization parameters of one layer.

    Emits the bit width, the zero-point and the activation scale encoded as
    ``ap_int<16>`` via ``_to_q_i16(scale, _Q_SCALE)``. The float scale itself
    appears only inside a C++ comment, for reference.
    """
    sym = _sym(lname)
    guard = _id_guard_macro(f"QPARAMS_{sym}_H")

    # Fixed-point encoding of the scale; _Q_SCALE itself is not written to the header
    scale_fixed = _to_q_i16(float(scale), _Q_SCALE)

    with open(path, "w") as fp:
        _emit_header_open(fp, guard)

        fp.write("// Activation quantization parameters (optional for kernels)\n")
        fp.write(f"const int   {sym}_bit_width = {int(bit_w)};\n")
        fp.write(f"// const float {sym}_scale     = {float(scale):.10g};  // kept for reference only\n")
        fp.write(f"const ap_int<16> {sym}_act_scale_int = {int(scale_fixed)};\n")
        fp.write(f"const int   {sym}_zero_point= {int(zp)};\n\n")

        _emit_header_close(fp, guard)


# -----------------------
# Orchestrator
# -----------------------

def emit_headers_for_model(model: torch.nn.Module,
                           example_input: torch.Tensor,
                           out_dir: str = "headers_int",
                           lif_frac_bits: int = 12):
    """Walk ``model`` and write one integer C++ header per supported layer into ``out_dir``.

    The model is switched to eval mode and run once on ``example_input``; a
    failure of that dry run is reported as a ``RuntimeWarning`` (it used to be
    swallowed silently) and the export continues. Sub-modules are then visited
    in ``named_modules()`` order while the quantization parameters of the
    "current activation" (bit width, scale, zero-point) are carried from layer
    to layer:

    * ``qnn.QuantIdentity``               -> ``qparams_<name>.h``
    * ``qnn.QuantConv1d``                 -> ``<name>_weights.h``
    * ``qnn.BatchNorm1dToQuantScaleBias`` -> ``<name>_bn.h``
    * ``qnn.QuantLinear``                 -> ``<name>_weights.h``
    * ``snn.Leaky``                       -> ``<name>_lif.h`` (scalar or vector form)

    Other module types are skipped.

    Args:
        model: network to export.
        example_input: tensor passed to ``model`` for the dry forward pass.
        out_dir: output directory, created if missing.
        lif_frac_bits: number of fractional bits used for the LIF constants.

    Raises:
        ValueError: if an input, weight or output scale is non-finite or not
            positive, or if a LIF layer has beta/threshold sizes that cannot
            be broadcast against each other.
    """
    import warnings  # local import: only needed for the dry-run diagnostic

    model.eval()
    _ensure_dir(out_dir)

    # Run one dry forward to initialize any lazy buffers (ignore output tuples)
    with torch.no_grad():
        try:
            _ = model(example_input)
        except Exception as exc:
            # Best effort only: keep exporting, but make the failure visible.
            warnings.warn(
                f"emit_headers_for_model: dry forward pass failed ({exc!r}); "
                "exported quantization parameters may come from uninitialized state",
                RuntimeWarning,
            )

    # Track current activation qparams (propagated as in your graph)
    current_act = {"bit_width": 8, "scale": 1.0, "zero_point": 0}

    for name, m in model.named_modules():
        if m is model:
            continue

        # QuantIdentity (export activation qparams as optional header)
        if isinstance(m, qnn.QuantIdentity):
            aq = getattr(m, 'act_quant', getattr(m, 'output_quant', None))
            bit_w, s, zp, _ = _get_bit_scale_zp_from_quant(aq)
            _emit_qparams_header(os.path.join(out_dir, f"qparams_{name}.h"), name, bit_w, s, zp)
            current_act = {"bit_width": bit_w, "scale": s, "zero_point": zp}
            continue

        if isinstance(m, qnn.QuantConv1d):
            # Float → INT8 weights
            Wf = m.weight.detach().cpu().numpy()            # [OUT, IN, K]
            wq = getattr(m, 'weight_quant', None)
            _, s_w, z_w, _ = _get_bit_scale_zp_from_quant(wq)

            _guard_in_scale(name, current_act['scale'])
            _guard_weight_scale(name, s_w)

            W_int8 = _qt_weight_int8_per_tensor(Wf, s_w, z_w)

            # Output activation qparams (for requant)
            oq = getattr(m, 'output_quant', None)
            ob, s_out, z_out, _ = _get_bit_scale_zp_from_quant(oq)
            _guard_out_scale(name, s_out)

            # Integer requant constants (per-tensor → repeat per OUT_CH)
            M = (current_act['scale'] * s_w) / s_out
            rq_mult, rq_shift = _quantize_multiplier(M)

            # Bias (if present) → INT32 vector (length OUT_CH); else zeros
            b_f = m.bias.detach().cpu().numpy() if (hasattr(m, 'bias') and m.bias is not None) else None
            bias_int32_vec = _bias_int32_vector(b_f, current_act['scale'], s_w, W_int8.shape[0])

            # Weight sums per output channel (for asymmetric correction)
            weight_sum_vec = _sum_weights_per_out(W_int8)

            # Input zero point (INT8) for asymmetric correction
            input_zp = current_act['zero_point']

            # Emit header that matches your Conv1D_SD::forward signature
            _emit_conv1d_header(
                os.path.join(out_dir, f"{name}_weights.h"),
                name,
                W_int8,
                rq_mult, rq_shift,
                bias_int32_vec,
                input_zp,
                weight_sum_vec,
                m
            )

            # Advance activation qparams for the next layer
            current_act = {"bit_width": ob, "scale": s_out, "zero_point": z_out}
            continue

        # BN -> ScaleBias (for BatchNorm1dToQuantScaleBias)
        if hasattr(qnn, 'BatchNorm1dToQuantScaleBias') and isinstance(m, qnn.BatchNorm1dToQuantScaleBias):
            # gamma, beta (float, per channel)
            gamma = _to_one(getattr(m, 'weight', None)) if getattr(m, 'weight', None) is not None else _to_one(getattr(m, 'scale', None))
            beta  = _to_one(getattr(m, 'bias',   None)) if getattr(m, 'bias',   None) is not None else _to_one(getattr(m, 'beta',  None))
            if gamma is None: gamma = 1.0
            if beta  is None: beta  = 0.0
            gamma = np.array(gamma, dtype=np.float32).reshape(-1)
            beta  = np.array(beta,  dtype=np.float32).reshape(-1)
            C = gamma.shape[0]

            # Input/output quant
            s_in = float(current_act['scale']);  z_in = int(current_act['zero_point']);  _guard_in_scale(name, s_in)
            oq   = getattr(m, 'output_quant', None)
            _, s_out, _, _ = _get_bit_scale_zp_from_quant(oq);  _guard_out_scale(name, s_out)

            # Brevitas BN quantized weight
            qW = m.quant_weight()  # IntQuantTensor
            w_q = qW.int().detach().cpu().numpy().astype(np.int8)    # [C]
            s_w = float(_to_one(qW.scale))                            # scalar

            # Shared requant scale S and its integer pair
            S = s_w * (s_in / s_out)
            mult_S, shift_S = _quantize_multiplier(S)

            # Int32 bias per channel (no int8 clipping)
            # M_c = (s_in/s_out) * gamma_c  = S * w_q[c]  (approximately)
            M = gamma * (s_in / s_out)                          # [C]
            b32 = np.round((beta / s_out - M * z_in) / S).astype(np.int32)  # [C]

            _emit_bn_header(
                os.path.join(out_dir, f"{name}_bn.h"),
                name,
                w_q.tolist(),
                b32.tolist(),
                [mult_S] * C,
                [shift_S] * C
            )

            # Advance activation qparams
            current_act = {"bit_width": current_act['bit_width'], "scale": s_out, "zero_point": 0}
            continue

        # QuantLinear
        if isinstance(m, qnn.QuantLinear):
            # Float → INT8 weights
            Wf = m.weight.detach().cpu().numpy()             # [OUT, IN]
            wq = getattr(m, 'weight_quant', None)
            _, s_w, z_w, _ = _get_bit_scale_zp_from_quant(wq)

            _guard_in_scale(name, current_act['scale'])
            _guard_weight_scale(name, s_w)

            W_int8 = _qt_weight_int8_per_tensor(Wf, s_w, z_w)

            # Output activation qparams (for requant)
            oq = getattr(m, 'output_quant', None)
            ob, s_out, z_out, _ = _get_bit_scale_zp_from_quant(oq)
            _guard_out_scale(name, s_out)

            # Requant: M = (s_in * s_w) / s_out
            M = (current_act['scale'] * s_w) / s_out
            rq_mult, rq_shift = _quantize_multiplier(M)

            # Bias (if present) → INT32; else zeros
            b_f = m.bias.detach().cpu().numpy() if (hasattr(m, 'bias') and m.bias is not None) else None
            bias_int32_vec = _bias_int32_vector(b_f, current_act['scale'], s_w, W_int8.shape[0])

            # Weight sums for asymmetric correction: sum over input dim
            weight_sum_vec = W_int8.sum(axis=1).astype(np.int32)

            # Input zero-point for asymmetric correction
            input_zp = current_act['zero_point']

            # Emit header that matches Linear1D_SD::forward
            _emit_linear_header(
                os.path.join(out_dir, f"{name}_weights.h"),
                name,
                W_int8,
                rq_mult, rq_shift,
                bias_int32_vec,
                input_zp,
                weight_sum_vec
            )

            # Advance activation qparams
            current_act = {"bit_width": ob, "scale": s_out, "zero_point": z_out}
            continue

        if isinstance(m, snn.Leaky):
            scale_in = float(current_act['scale'])
            Q = 1 << lif_frac_bits

            # Grab beta/threshold as tensors (handles both scalar and vector)
            beta_t  = m.beta if isinstance(m.beta, torch.Tensor) else torch.as_tensor(m.beta)
            thr_t   = m.threshold if isinstance(m.threshold, torch.Tensor) else torch.as_tensor(m.threshold)

            beta_t  = beta_t.detach().float()
            print("beta: ", beta_t)
            thr_t   = thr_t.detach().float()

            # snntorch clips beta to [0, 1] during the forward pass, so a stored
            # value such as 1.0003 acts as 1.0 in the trained model. Export the
            # effective value so the hardware constant matches the software model.
            beta_t  = beta_t.clamp(0.0, 1.0)

            # Number of “neurons” (channels) from parameter size
            n_beta  = int(beta_t.numel())
            n_thr   = int(thr_t.numel())
            if n_beta != n_thr and n_beta != 1 and n_thr != 1:
                raise ValueError(f"LIF param size mismatch: beta has {n_beta}, threshold has {n_thr}")

            # Broadcast if one is scalar
            out_ch = max(n_beta, n_thr)
            if n_beta == 1 and out_ch > 1:
                beta_t = beta_t.expand(out_ch)
            if n_thr == 1 and out_ch > 1:
                thr_t = thr_t.expand(out_ch)

            # Quantize
            beta_arr_q  = _tensor_to_q_i16_list(beta_t, Q)
            theta_arr_q = _tensor_to_q_i16_list(thr_t, Q)
            scale_q     = _to_q_i16(scale_in, Q)   # keep scale scalar for now

            # Emit vector or scalar header depending on out_ch
            out_path = os.path.join(out_dir, f"{name}_lif.h")
            if out_ch > 1:
                _emit_lif_header_vector_sd(
                    out_path, name, beta_arr_q, theta_arr_q, scale_q, lif_frac_bits
                )
            else:
                _emit_lif_header_scalar_sd(
                    out_path, name, beta_arr_q[0], theta_arr_q[0], scale_q, lif_frac_bits
                )

            # LIF outputs binary spikes {0,1} → treat next op input scale as 1.0
            current_act = {"bit_width": 8, "scale": 1.0, "zero_point": 0}
            continue

    print(f"[emit] C++ headers written to: {os.path.abspath(out_dir)}")


In [55]:
# Extract weights from Fold 5 (4-class model)
# Builds a fresh model, restores the fold-5 checkpoint and switches to eval mode.
# `create_qcsnn_model_4class` and `fold_ckpts` come from earlier cells.
model_stage1 = create_qcsnn_model_4class()
model_stage1.load_state_dict(fold_ckpts['fold5']['state_dict_cpu'])
model_stage1.eval()

# Create example input
# Random tensor used only for the exporter's dry forward pass.
# presumably (batch, channels, samples) = (1, 1, 180) — TODO confirm against the dataset window
example_input = torch.randn(1, 1, 180)

# Extract to C++ headers
# NOTE(review): the variable is named "stage1" but the directory is "headers_stage2" — confirm which is intended.
emit_headers_for_model(
    model=model_stage1,
    example_input=example_input,
    out_dir="weights_sd/intra_patient/headers_stage2"
)

Output sizes: Block1=89, Block2=43, Block3=20
Flattened size: 480
beta:  tensor(0.4599)
beta:  tensor(0.5905)
beta:  tensor(0.7678)
beta:  tensor(0.8207)
beta:  tensor(1.0003)
[emit] C++ headers written to: /home/velox-217533/Projects/fau_projects/research/snn_quant/model4/intra_patient_models/weights_sd/intra_patient/headers_stage2


In [59]:
# 2. 4-class model (Fold 5)
# Rebuild the model, restore the fold-5 checkpoint and save its state dict to disk.
# `create_qcsnn_model_4class` and `fold_ckpts` come from earlier cells.
model_4class = create_qcsnn_model_4class()
model_4class.load_state_dict(fold_ckpts['fold5']['state_dict_cpu'])
# NOTE(review): the 'weights_sd' directory is not created here; presumably it exists
# because the header-export cell already wrote into it — verify on a fresh run.
torch.save(model_4class.state_dict(), 'weights_sd/fold5_4class_fpga_weights.pth')
print("Saved: fold5_4class_fpga_weights.pth")

Output sizes: Block1=89, Block2=43, Block3=20
Flattened size: 480
Saved: fold5_4class_fpga_weights.pth


In [None]:


import os
import math
import numpy as np
import torch

import brevitas.nn as qnn
import snntorch as snn

# -----------------------
# Utilities
# -----------------------

def _ensure_dir(d):
    os.makedirs(d, exist_ok=True)

def _to_one(x):
    if isinstance(x, torch.Tensor):
        if x.numel() == 1:
            return float(x.detach().cpu().item())
        return x.detach().cpu().numpy()
    if isinstance(x, (float, int)):
        return x
    return x

def _get_bit_scale_zp_from_quant(q):
    """Return (bit_width, scale, zero_point, signed) from a brevitas quantizer-like object."""
    bit_width = None
    scale = None
    zero_point = None
    signed = True
    # bit width
    for k in ('bit_width', 'bit_width_impl', 'bit_width_f'):
        v = getattr(q, k, None)
        if v is None: continue
        try:
            v = v() if callable(v) else v
            bit_width = int(_to_one(v))
            break
        except Exception:
            pass
    # scale
    for k in ('scale', 'tensor_scale', 'scale_impl', 'act_scale', 'weight_scale'):
        v = getattr(q, k, None)
        if v is None: continue
        try:
            v = v() if callable(v) else v
            v = _to_one(v)
            if isinstance(v, (float, int)):
                scale = float(v)
                break
            if isinstance(v, np.ndarray) and v.size == 1:
                scale = float(v.item())
                break
        except Exception:
            pass
    # zero-point
    for k in ('zero_point', 'zero_point_impl', 'zp'):
        v = getattr(q, k, None)
        if v is None: continue
        try:
            v = v() if callable(v) else v
            zero_point = int(round(_to_one(v)))
            break
        except Exception:
            pass
    # signed
    sattr = getattr(q, 'signed', None)
    if isinstance(sattr, bool):
        signed = sattr
    elif zero_point is None:
        signed = True
    # defaults
    if bit_width is None: bit_width = 8
    if scale is None:     scale = 1.0
    if zero_point is None: zero_point = 0
    return bit_width, float(scale), int(zero_point), bool(signed)

def _quantize_multiplier(real_multiplier: float):
    """TFLite-style integer multiplier/shift approximation for a positive real multiplier."""
    if real_multiplier <= 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(real_multiplier)  # real = mantissa * 2^exponent, mantissa in [0.5,1)
    q = int(round(mantissa * (1 << 31)))
    if q == (1 << 31):
        q //= 2
        exponent += 1
    shift = 31 - exponent
    if shift < 0:
        q <<= (-shift)
        shift = 0
    return int(q), int(shift)

# Integer helpers: per-channel weight sums and float -> integer bias/weight quantization
def _sum_weights_per_out(W_int8: np.ndarray) -> np.ndarray:
    # W_int8 shape: [OUT_CH, IN_CH, K]
    # returns int32 sums per OUT_CH
    return W_int8.reshape(W_int8.shape[0], -1).sum(axis=1).astype(np.int32)

def _bias_int32_vector(b_f: np.ndarray, s_in: float, s_w: float, out_ch: int) -> np.ndarray:
    # Quantize bias or return zeros if no bias provided
    if b_f is None:
        return np.zeros(out_ch, dtype=np.int32)
    s_bias = s_in * s_w if (s_in and s_w) else 1.0
    return np.round(b_f / s_bias).astype(np.int32)


def _qt_weight_int8_per_tensor(W_f: np.ndarray, scale: float, zero_point: int = 0):
    Wq = np.round(W_f / scale) + zero_point
    return np.clip(Wq, -128, 127).astype(np.int8)

def _bias_int32_from_float(b_f: np.ndarray, s_in: float, s_w: float):
    s_bias = s_in * s_w
    if s_bias == 0.0: s_bias = 1.0
    bq = np.round(b_f / s_bias).astype(np.int32)
    return bq

def _as1(x):
    if isinstance(x, (list, tuple)):
        return int(x[0])
    return int(x)

def _guard_out_scale(name: str, out_scale: float):
    if out_scale is None or not np.isfinite(out_scale) or out_scale <= 0.0:
        raise ValueError(f"{name}: invalid out_scale={out_scale}")

def _guard_in_scale(name: str, in_scale: float):
    if in_scale is None or not np.isfinite(in_scale) or in_scale <= 0.0:
        raise ValueError(f"{name}: invalid in_scale={in_scale}")

def _guard_weight_scale(name: str, w_scale: float):
    if w_scale is None or not np.isfinite(w_scale) or w_scale <= 0.0:
        raise ValueError(f"{name}: invalid weight_scale={w_scale}")

def _id_guard_macro(base: str):
    return base.upper().replace('/', '_').replace('.', '_')

def _sym(name: str):
    """C identifier from layer name (keep as-is but safe for C)."""
    return name.replace('/', '_').replace('.', '_')

def _fmt_int_list(vals, per_line=16):
    out = []
    line = []
    for i, v in enumerate(vals):
        line.append(str(int(v)))
        if (i + 1) % per_line == 0:
            out.append(", ".join(line))
            line = []
    if line:
        out.append(", ".join(line))
    return ",\n    ".join(out)

def _fmt_array_2d(arr2d):
    """Format a 2-D sequence as a brace initializer: one ``{ ... }`` row per inner sequence."""
    body = ",\n  ".join("{ " + _fmt_int_list(row) + " }" for row in arr2d)
    return "{\n  " + body + "\n}"

def _fmt_array_3d(arr3d):
    """Format a 3-D sequence as a brace initializer made of 2-D brace blocks."""
    planes = ",\n".join(_fmt_array_2d(plane) for plane in arr3d)
    return "{\n" + planes + "\n}"

# -----------------------
# Emitters
# -----------------------

def _emit_header_open(fp, guard, ns="hls4csnn1d_cblk_sd"):
    fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
    fp.write("#include <ap_int.h>\n\n")
    fp.write(f"namespace {ns} {{\n\n")

def _emit_header_close(fp, guard, ns="hls4csnn1d_cblk_sd"):
    fp.write(f"}} // namespace\n#endif // {guard}\n")

def _emit_conv1d_header(path, lname, W_int8, rq_mult, rq_shift,
                        bias_int32_vec, input_zp, weight_sum_vec, m):
    """Write the C++ header holding the integer parameters of one Conv1d layer.

    Args:
        path: destination file; it is overwritten.
        lname: layer name, used (after ``_sym``) as prefix of every emitted symbol.
        W_int8: weight array unpacked as ``(OUT_CH, IN_CH, K)``.
        rq_mult, rq_shift: per-tensor requantization pair, repeated once per output channel.
        bias_int32_vec: one bias value per output channel.
        input_zp: input zero-point, emitted as ``ap_int<8>``.
        weight_sum_vec: one weight sum per output channel.
        m: the layer module; only ``m.stride`` is read.

    The header now declares ``using acc32_t = ap_int<32>;`` itself, exactly as
    the linear-layer header does. Previously the bias and weight-sum arrays
    used ``acc32_t`` although this header neither defined the alias nor
    included anything that does, so it only compiled when another header had
    introduced the alias first.
    """
    # Guard: QCSNET2_CBLK1_QCONV1D_WEIGHTS_H style
    guard = _id_guard_macro(f"{_sym(lname)}_WEIGHTS_H")
    with open(path, "w") as fp:
        _emit_header_open(fp, guard)  # writes includes + namespace line

        sym   = _sym(lname)
        OC, IC, K = W_int8.shape
        stride = _as1(m.stride)

        # Structural constants (optional but handy for TBs)
        fp.write(f"const int {sym}_OUT_CH = {OC};\n")
        fp.write(f"const int {sym}_IN_CH  = {IC};\n")
        fp.write(f"const int {sym}_KERNEL_SIZE = {K};\n")
        fp.write(f"const int {sym}_STRIDE = {stride};\n\n")

        # Input ZP (INT8)
        fp.write(f"const ap_int<8> {sym}_input_zero_point = {int(input_zp)};\n\n")

        # Requantization arrays (duplicated per OUT_CH to match your API)
        fp.write(f"const ap_int<32> {sym}_scale_multiplier[{OC}] = {{\n  ")
        fp.write(_fmt_int_list([rq_mult]*OC))
        fp.write("\n};\n\n")

        fp.write(f"const int {sym}_right_shift[{OC}] = {{\n  ")
        fp.write(_fmt_int_list([rq_shift]*OC))
        fp.write("\n};\n\n")

        # Accumulator alias used by the two arrays below (same alias the linear header emits)
        fp.write("using acc32_t = ap_int<32>;\n")

        # Bias (INT32)
        fp.write(f"const acc32_t {sym}_bias[{OC}] = {{\n  ")
        fp.write(_fmt_int_list(bias_int32_vec))
        fp.write("\n};\n\n")

        # Weight sums for asymmetric correction (INT32)
        fp.write(f"const acc32_t {sym}_weight_sum[{OC}] = {{\n  ")
        fp.write(_fmt_int_list(weight_sum_vec))
        fp.write("\n};\n\n")

        # Weights (INT8): [OUT_CH][IN_CH][K]
        fp.write(f"const ap_int<8> {sym}_weights[{OC}][{IC}][{K}] = ")
        fp.write(_fmt_array_3d(W_int8))
        fp.write(";\n\n")

        _emit_header_close(fp, guard)  # closes namespace + guard


def _emit_linear_header(path, lname, W_int8, rq_mult, rq_shift,
                        bias_int32_vec, input_zp, weight_sum_vec):
    """Write the C++ header holding the integer parameters of one linear layer.

    Args:
        path: destination file; it is overwritten.
        lname: layer name, used (after ``_sym``) as prefix of every emitted symbol.
        W_int8: weight array unpacked as ``(OUT, IN)``.
        rq_mult, rq_shift: per-tensor requantization pair, repeated once per output.
        bias_int32_vec: one bias value per output.
        input_zp: input zero-point, emitted as ``ap_int<8>``.
        weight_sum_vec: one weight sum per output.
    """
    guard = _id_guard_macro(f"{_sym(lname)}_WEIGHTS_H")
    with open(path, "w") as fp:
        # Prologue: guard, includes, namespace
        fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
        fp.write("#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n")
        fp.write("namespace hls4csnn1d_cblk_sd {\n\n")

        sym = _sym(lname)
        OUT, IN = W_int8.shape

        def emit_vector(ctype, field, values):
            # One `const <ctype> <sym>_<field>[OUT] = { ... };` definition
            fp.write(f"const {ctype} {sym}_{field}[{OUT}] = {{\n  ")
            fp.write(_fmt_int_list(values))
            fp.write("\n};\n\n")

        # Layer dimensions
        fp.write(f"const int {sym}_OUTPUT_SIZE = {OUT};\n")
        fp.write(f"const int {sym}_INPUT_SIZE  = {IN};\n\n")

        # Input zero-point
        fp.write(f"const ap_int<8> {sym}_input_zero_point = {int(input_zp)};\n\n")

        # Requantization constants, one copy per output
        emit_vector("ap_int<32>", "scale_multiplier", [rq_mult] * OUT)
        emit_vector("int", "right_shift", [rq_shift] * OUT)

        # Accumulator-domain vectors
        fp.write(f"using acc32_t = ap_int<32>;\n")
        emit_vector("acc32_t", "bias", bias_int32_vec)
        emit_vector("acc32_t", "weight_sum", weight_sum_vec)

        # Weight matrix [OUT][IN]
        fp.write(f"const ap_int8_c {sym}_weights[{OUT}][{IN}] = ")
        fp.write(_fmt_array_2d(W_int8))
        fp.write(";\n\n")

        # Epilogue
        fp.write("} // namespace hls4csnn1d_cblk_sd\n")
        fp.write(f"#endif // {guard}\n")



def _emit_bn_header(path, lname, w_q, b32, mult_arr, shift_arr):
    """Write the C++ header with the integer constants of one scale/bias (BN) layer.

    Inside ``namespace hls4csnn1d_cblk_sd`` the header declares ``<sym>_C``
    (``len(w_q)``) and four arrays of ``<sym>_C`` elements: ``<sym>_weight``
    (``ap_int8_c``), ``<sym>_bias`` and ``<sym>_scale_multiplier``
    (``ap_int<32>``) and ``<sym>_right_shift`` (``int``), where ``<sym>`` is
    ``_sym(lname)``.

    Args:
        path: destination file; overwritten if it already exists.
        lname: layer name, turned into the symbol prefix and the include guard.
        w_q: per-channel weights; its length defines the channel count.
        b32: per-channel bias values.
        mult_arr: per-channel requantization multipliers.
        shift_arr: per-channel requantization right shifts.

    Raises:
        ValueError: if ``b32``, ``mult_arr`` or ``shift_arr`` does not have
            the same number of entries as ``w_q``.
    """
    sym = _sym(lname)
    guard = _id_guard_macro(f"{sym}_BN_H")
    C = len(w_q)

    # Every array below is declared with size C. A shorter initializer list
    # still compiles in C++, so a length mismatch would otherwise go unnoticed.
    for label, values in (("b32", b32), ("mult_arr", mult_arr), ("shift_arr", shift_arr)):
        if len(values) != C:
            raise ValueError(
                f"{lname}: {label} has {len(values)} entries but w_q has {C}"
            )

    with open(path, "w") as fp:
        fp.write(f"#ifndef {guard}\n#define {guard}\n\n")
        fp.write("#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n")
        fp.write("namespace hls4csnn1d_cblk_sd {\n\n")

        fp.write(f"const int {sym}_C = {C};\n\n")

        fp.write(f"const ap_int8_c {sym}_weight[{C}] = {{\n  {_fmt_int_list(w_q)}\n}};\n\n")
        fp.write(f"const ap_int<32> {sym}_bias[{C}] = {{\n  {_fmt_int_list(b32)}\n}};\n\n")
        fp.write(f"const ap_int<32> {sym}_scale_multiplier[{C}] = {{\n  {_fmt_int_list(mult_arr)}\n}};\n\n")
        fp.write(f"const int {sym}_right_shift[{C}] = {{\n  {_fmt_int_list(shift_arr)}\n}};\n\n")

        fp.write("} // namespace hls4csnn1d_cblk_sd\n")
        fp.write(f"#endif // {guard}\n")



def _emit_lif_header_scalar_sd(path, lname, beta_q, theta_q, scale_q, frac_bits):
    """Write the LIF header for a layer with a single beta/threshold pair.

    The header declares, inside ``namespace hls4csnn1d_cblk_sd``, an enum
    ``<sym>_FRAC_BITS`` and three ``ap_int<16>`` constants ``<sym>_beta_int``,
    ``<sym>_theta_int`` and ``<sym>_scale_int``, where ``<sym>`` is
    ``_sym(lname)``. Every numeric argument is passed through ``int()``
    before being written.

    Args:
        path: destination file; overwritten if it already exists.
        lname: layer name used for the symbol prefix and the include guard.
        beta_q: value written as ``<sym>_beta_int``.
        theta_q: value written as ``<sym>_theta_int``.
        scale_q: value written as ``<sym>_scale_int``.
        frac_bits: value written as ``<sym>_FRAC_BITS``.
    """
    sym = _sym(lname)
    guard = _id_guard_macro(f"{sym}_LIF_H")

    # Assemble the header text piece by piece, then write it in one go.
    chunks = [
        f"#ifndef {guard}\n#define {guard}\n\n",
        "#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n",
        "namespace hls4csnn1d_cblk_sd {\n\n",
        f"enum {{ {sym}_FRAC_BITS = {int(frac_bits)} }};\n",
        f"const ap_int<16> {sym}_beta_int   = {int(beta_q)};\n",
        f"const ap_int<16> {sym}_theta_int  = {int(theta_q)};\n",
        f"const ap_int<16> {sym}_scale_int  = {int(scale_q)};\n\n",
        "} // namespace hls4csnn1d_cblk_sd\n",
        f"#endif // {guard}\n",
    ]
    with open(path, "w") as out:
        out.writelines(chunks)


def _emit_lif_header_vector_sd(path, lname, beta_arr_q, theta_arr_q, scale_q, frac_bits):
    """Write the LIF header for a layer with per-channel beta/threshold arrays.

    The header declares, inside ``namespace hls4csnn1d_cblk_sd``, an enum with
    ``<sym>_FRAC_BITS`` and ``<sym>_OUT_CH`` (the array length), the scalar
    ``<sym>_scale_int`` and two ``ap_int<16>`` arrays ``<sym>_beta_int`` and
    ``<sym>_theta_int``, where ``<sym>`` is ``_sym(lname)``.

    Args:
        path: destination file; overwritten if it already exists.
        lname: layer name used for the symbol prefix and the include guard.
        beta_arr_q: sequence written as ``<sym>_beta_int``.
        theta_arr_q: sequence written as ``<sym>_theta_int``.
        scale_q: value written as ``<sym>_scale_int``.
        frac_bits: value written as ``<sym>_FRAC_BITS``.

    Raises:
        ValueError: if the two arrays differ in length.
    """
    guard = _id_guard_macro(f"{_sym(lname)}_LIF_H")
    sym = _sym(lname)
    N = len(beta_arr_q)
    if N != len(theta_arr_q):
        raise ValueError("beta/theta array lengths must match")

    def _braced(values, width=16):
        # Brace-enclosed initializer list, `width` integers per indented row.
        rows = [
            "    " + ", ".join(str(int(v)) for v in values[start:start + width])
            for start in range(0, len(values), width)
        ]
        return "{\n" + ",\n".join(rows) + "\n}"

    pieces = (
        f"#ifndef {guard}\n#define {guard}\n\n",
        "#include <hls_stream.h>\n#include <ap_int.h>\n#include \"../constants_sd.h\"\n\n",
        "namespace hls4csnn1d_cblk_sd {\n\n",
        f"enum {{ {sym}_FRAC_BITS = {int(frac_bits)}, {sym}_OUT_CH = {int(N)} }};\n",
        f"const ap_int<16> {sym}_scale_int = {int(scale_q)};\n\n",
        f"const ap_int<16> {sym}_beta_int[{sym}_OUT_CH] = " + _braced(beta_arr_q) + ";\n\n",
        f"const ap_int<16> {sym}_theta_int[{sym}_OUT_CH] = " + _braced(theta_arr_q) + ";\n\n",
        "} // namespace hls4csnn1d_cblk_sd\n",
        f"#endif // {guard}\n",
    )
    with open(path, "w") as out:
        for piece in pieces:
            out.write(piece)



INT16_MIN, INT16_MAX = -32768, 32767

def _to_q_i16(x: float, Q: int) -> int:
    return int(np.clip(round(float(x) * Q), INT16_MIN, INT16_MAX))

def _tensor_to_q_i16_list(t: torch.Tensor, Q: int):
    flat = t.detach().float().reshape(-1).cpu().tolist()
    return [_to_q_i16(v, Q) for v in flat]


# Factor used to encode activation scales as 16-bit integers (12 fractional bits).
_Q_SCALE = 1 << 12


def _emit_qparams_header(path, lname, bit_w, scale, zp):
    """Write the activation quantization parameters of one layer as a C++ header.

    Between the preamble written by ``_emit_header_open`` and the trailer
    written by ``_emit_header_close`` the header declares ``<sym>_bit_width``,
    ``<sym>_act_scale_int`` (``scale`` encoded by ``_to_q_i16`` with
    ``_Q_SCALE``) and ``<sym>_zero_point``. The float scale is written only
    as a commented-out line for reference. ``_Q_SCALE`` itself is not written
    to the header.

    Args:
        path: destination file; overwritten if it already exists.
        lname: layer name used for the symbol prefix and the include guard.
        bit_w: activation bit width.
        scale: activation scale.
        zp: activation zero point.
    """
    sym = _sym(lname)
    guard = _id_guard_macro(f"QPARAMS_{_sym(lname)}_H")
    scale_as_float = float(scale)
    act_scale_int = _to_q_i16(scale_as_float, _Q_SCALE)

    body = [
        "// Activation quantization parameters (optional for kernels)\n",
        f"const int   {sym}_bit_width = {int(bit_w)};\n",
        f"// const float {sym}_scale     = {scale_as_float:.10g};  // kept for reference only\n",
        f"const ap_int<16> {sym}_act_scale_int = {int(act_scale_int)};\n",
        f"const int   {sym}_zero_point= {int(zp)};\n\n",
    ]

    with open(path, "w") as out:
        _emit_header_open(out, guard)
        out.writelines(body)
        _emit_header_close(out, guard)


# -----------------------
# Orchestrator
# -----------------------

def emit_headers_for_model(model: torch.nn.Module,
                           example_input: torch.Tensor,
                           out_dir: str = "headers_int",
                           lif_frac_bits: int = 12):
    """Walk ``model`` and write one integer C++ header per supported layer.

    The modules returned by ``model.named_modules()`` are visited in order.
    The activation quantization parameters (bit width, scale, zero point)
    produced by the most recently handled layer are carried along and used as
    the input parameters of the next one. Files written into ``out_dir``:

    * ``qnn.QuantIdentity``               -> ``qparams_<name>.h``
    * ``qnn.QuantConv1d``                 -> ``<name>_weights.h``
    * ``qnn.BatchNorm1dToQuantScaleBias`` -> ``<name>_bn.h``
    * ``qnn.QuantLinear``                 -> ``<name>_weights.h``
    * ``snn.Leaky``                       -> ``<name>_lif.h``

    Modules of any other type are ignored.

    Args:
        model: network to export; it is switched to eval mode.
        example_input: tensor fed through the model once before exporting.
        out_dir: destination directory for the generated headers.
        lif_frac_bits: number of fractional bits used to encode the LIF beta,
            threshold and input scale (``Q = 1 << lif_frac_bits``).

    Raises:
        ValueError: if an ``snn.Leaky`` layer has beta and threshold tensors
            of different sizes and neither of them has a single element.
    """
    model.eval()
    _ensure_dir(out_dir)

    # One dry forward pass, intended to initialize any lazy buffers. It stays
    # best-effort (the export continues on failure), but the failure is now
    # reported instead of being silently discarded.
    with torch.no_grad():
        try:
            _ = model(example_input)
        except Exception as exc:
            print(f"[emit] warning: dry forward pass failed "
                  f"({type(exc).__name__}: {exc}); continuing with the export")

    # Activation qparams seen by the next layer (propagated along the walk)
    current_act = {"bit_width": 8, "scale": 1.0, "zero_point": 0}

    for name, m in model.named_modules():
        if m is model:
            continue

        # QuantIdentity: export its activation qparams and make them current
        if isinstance(m, qnn.QuantIdentity):
            aq = getattr(m, 'act_quant', getattr(m, 'output_quant', None))
            bit_w, s, zp, _ = _get_bit_scale_zp_from_quant(aq)
            _emit_qparams_header(os.path.join(out_dir, f"qparams_{name}.h"), name, bit_w, s, zp)
            current_act = {"bit_width": bit_w, "scale": s, "zero_point": zp}
            continue

        # QuantConv1d
        if isinstance(m, qnn.QuantConv1d):
            # Float -> INT8 weights
            Wf = m.weight.detach().cpu().numpy()            # [OUT, IN, K]
            wq = getattr(m, 'weight_quant', None)
            _, s_w, z_w, _ = _get_bit_scale_zp_from_quant(wq)

            _guard_in_scale(name, current_act['scale'])
            _guard_weight_scale(name, s_w)

            W_int8 = _qt_weight_int8_per_tensor(Wf, s_w, z_w)

            # Output activation qparams (for requant)
            oq = getattr(m, 'output_quant', None)
            ob, s_out, z_out, _ = _get_bit_scale_zp_from_quant(oq)
            _guard_out_scale(name, s_out)

            # Integer requant constants: M = (s_in * s_w) / s_out
            M = (current_act['scale'] * s_w) / s_out
            rq_mult, rq_shift = _quantize_multiplier(M)

            # Bias (if present) -> INT32 vector of length OUT_CH
            b_f = m.bias.detach().cpu().numpy() if getattr(m, 'bias', None) is not None else None
            bias_int32_vec = _bias_int32_vector(b_f, current_act['scale'], s_w, W_int8.shape[0])

            # Weight sums per output channel (for asymmetric correction)
            weight_sum_vec = _sum_weights_per_out(W_int8)

            # Input zero point (for asymmetric correction)
            input_zp = current_act['zero_point']

            _emit_conv1d_header(
                os.path.join(out_dir, f"{name}_weights.h"),
                name,
                W_int8,
                rq_mult, rq_shift,
                bias_int32_vec,
                input_zp,
                weight_sum_vec,
                m
            )

            # Advance activation qparams for the next layer
            current_act = {"bit_width": ob, "scale": s_out, "zero_point": z_out}
            continue

        # BN -> ScaleBias (only when this Brevitas build provides the class)
        if hasattr(qnn, 'BatchNorm1dToQuantScaleBias') and isinstance(m, qnn.BatchNorm1dToQuantScaleBias):
            # gamma, beta (float), with fallbacks to the `scale` / `beta` attributes
            gamma = _to_one(getattr(m, 'weight', None)) if getattr(m, 'weight', None) is not None else _to_one(getattr(m, 'scale', None))
            beta  = _to_one(getattr(m, 'bias',   None)) if getattr(m, 'bias',   None) is not None else _to_one(getattr(m, 'beta',  None))
            if gamma is None:
                gamma = 1.0
            if beta is None:
                beta = 0.0
            gamma = np.array(gamma, dtype=np.float32).reshape(-1)
            beta  = np.array(beta,  dtype=np.float32).reshape(-1)
            C = gamma.shape[0]

            # Input/output quant
            s_in = float(current_act['scale'])
            z_in = int(current_act['zero_point'])
            _guard_in_scale(name, s_in)
            oq = getattr(m, 'output_quant', None)
            _, s_out, _, _ = _get_bit_scale_zp_from_quant(oq)
            _guard_out_scale(name, s_out)

            # Brevitas BN quantized weight
            qW = m.quant_weight()
            w_q = qW.int().detach().cpu().numpy().astype(np.int8)
            s_w = float(_to_one(qW.scale))

            # Shared requant scale S and its integer (multiplier, shift) pair
            S = s_w * (s_in / s_out)
            mult_S, shift_S = _quantize_multiplier(S)

            # Int32 bias per channel (no int8 clipping)
            # M_c = (s_in/s_out) * gamma_c  = S * w_q[c]  (approximately)
            M = gamma * (s_in / s_out)
            b32 = np.round((beta / s_out - M * z_in) / S).astype(np.int32)

            _emit_bn_header(
                os.path.join(out_dir, f"{name}_bn.h"),
                name,
                w_q.tolist(),
                b32.tolist(),
                [mult_S] * C,
                [shift_S] * C
            )

            # Advance activation qparams (zero point is reset to 0 here)
            current_act = {"bit_width": current_act['bit_width'], "scale": s_out, "zero_point": 0}
            continue

        # QuantLinear
        if isinstance(m, qnn.QuantLinear):
            # Float -> INT8 weights
            Wf = m.weight.detach().cpu().numpy()             # [OUT, IN]
            wq = getattr(m, 'weight_quant', None)
            _, s_w, z_w, _ = _get_bit_scale_zp_from_quant(wq)

            _guard_in_scale(name, current_act['scale'])
            _guard_weight_scale(name, s_w)

            W_int8 = _qt_weight_int8_per_tensor(Wf, s_w, z_w)

            # Output activation qparams (for requant)
            oq = getattr(m, 'output_quant', None)
            ob, s_out, z_out, _ = _get_bit_scale_zp_from_quant(oq)
            _guard_out_scale(name, s_out)

            # Requant: M = (s_in * s_w) / s_out
            M = (current_act['scale'] * s_w) / s_out
            rq_mult, rq_shift = _quantize_multiplier(M)

            # Bias (if present) -> INT32; else whatever _bias_int32_vector yields for None
            b_f = m.bias.detach().cpu().numpy() if getattr(m, 'bias', None) is not None else None
            bias_int32_vec = _bias_int32_vector(b_f, current_act['scale'], s_w, W_int8.shape[0])

            # Weight sums for asymmetric correction: sum over the input dimension
            weight_sum_vec = W_int8.sum(axis=1).astype(np.int32)

            # Input zero point for asymmetric correction
            input_zp = current_act['zero_point']

            _emit_linear_header(
                os.path.join(out_dir, f"{name}_weights.h"),
                name,
                W_int8,
                rq_mult, rq_shift,
                bias_int32_vec,
                input_zp,
                weight_sum_vec
            )

            # Advance activation qparams
            current_act = {"bit_width": ob, "scale": s_out, "zero_point": z_out}
            continue

        # Leaky integrate-and-fire neuron
        if isinstance(m, snn.Leaky):
            scale_in = float(current_act['scale'])
            Q = 1 << lif_frac_bits

            # Grab beta/threshold as tensors (handles both scalar and vector)
            # NOTE(review): snnTorch's Leaky appears to clamp beta to [0, 1] in
            # its state update; the raw value is exported here. TODO confirm
            # whether the exported beta should be clamped the same way.
            beta_t = m.beta if isinstance(m.beta, torch.Tensor) else torch.as_tensor(m.beta)
            thr_t  = m.threshold if isinstance(m.threshold, torch.Tensor) else torch.as_tensor(m.threshold)

            beta_t = beta_t.detach().float()
            thr_t  = thr_t.detach().float()

            # Number of entries in each parameter
            n_beta = int(beta_t.numel())
            n_thr  = int(thr_t.numel())
            if n_beta != n_thr and n_beta != 1 and n_thr != 1:
                raise ValueError(f"LIF param size mismatch: beta has {n_beta}, threshold has {n_thr}")

            # Broadcast if one of them is a single element
            out_ch = max(n_beta, n_thr)
            if n_beta == 1 and out_ch > 1:
                beta_t = beta_t.expand(out_ch)
            if n_thr == 1 and out_ch > 1:
                thr_t = thr_t.expand(out_ch)

            # Quantize (the input scale stays a scalar)
            beta_arr_q  = _tensor_to_q_i16_list(beta_t, Q)
            theta_arr_q = _tensor_to_q_i16_list(thr_t, Q)
            scale_q     = _to_q_i16(scale_in, Q)

            # Emit the vector or the scalar header depending on out_ch
            out_path = os.path.join(out_dir, f"{name}_lif.h")
            if out_ch > 1:
                _emit_lif_header_vector_sd(
                    out_path, name, beta_arr_q, theta_arr_q, scale_q, lif_frac_bits
                )
            else:
                _emit_lif_header_scalar_sd(
                    out_path, name, beta_arr_q[0], theta_arr_q[0], scale_q, lif_frac_bits
                )

            # LIF outputs binary spikes {0,1} -> treat next op input scale as 1.0
            current_act = {"bit_width": 8, "scale": 1.0, "zero_point": 0}
            continue

    print(f"[emit] C++ headers written to: {os.path.abspath(out_dir)}")


In [None]:
# emit_headers_for_model(qcsnet2_eval, torch.randn(1,1,180), out_dir="weights_sd/headers_int")

In [None]:
# Export the fold-5 model to C++ headers.
# `create_qcsnn_model` and `fold_ckpts` are not defined in this cell, so they
# must already exist in the kernel. The weights are read from the
# 'state_dict_cpu' entry of the 'fold5' checkpoint.
model_stage1 = create_qcsnn_model()
model_stage1.load_state_dict(fold_ckpts['fold5']['state_dict_cpu'])
model_stage1.eval()

# Random example input of shape (1, 1, 180); the exporter runs it through the
# model once in a dry forward pass before writing the headers
example_input = torch.randn(1, 1, 180)

# Write the C++ headers into weights_sd/intra_patient/headers_stage1
# (lif_frac_bits is left at its default)
emit_headers_for_model(
    model=model_stage1,
    example_input=example_input,
    out_dir="weights_sd/intra_patient/headers_stage1"
)