In [None]:
!pip install snntorch --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# get_data_splits() reads from the *relative* path "drive/MyDrive/...", so the
# notebook's working directory has to be /content for that path to resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import random
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io as sio

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import snntorch as snn
import snntorch.spikegen as spikegen
import snntorch.surrogate as surrogate

from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [None]:
# Single seed shared by every RNG used in this notebook (Python, NumPy, PyTorch).
RANDOM_SEED = 42  # You can choose any integer

# Seed the Python, NumPy and PyTorch generators for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)  # for multi-GPU
    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Pass this as `worker_init_fn` when creating a DataLoader, for reproducibility.
def seed_worker(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker process.

    The seed is derived from ``torch.initial_seed()``. Inside a worker that
    value is the DataLoader's per-epoch base seed plus the worker id, so
    every worker gets a distinct stream, the stream changes from epoch to
    epoch, and the whole sequence is still reproducible from the
    DataLoader's ``generator``. Seeding with a constant plus ``worker_id``
    would instead replay the same NumPy/`random` numbers every epoch.

    Args:
        worker_id: index of the worker, supplied by the DataLoader. It is
            not used directly because it is already folded into
            ``torch.initial_seed()``.
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated, seeded torch generator. make_dataloaders() passes it to the
# DataLoaders via `generator=g`.
g = torch.Generator()
g = g.manual_seed(RANDOM_SEED)

In [None]:
def get_data_splits(val_idx=0, mabp=False):
    """Load the per-fold ``.mat`` files and split them into train / eval tensors.

    Every ``.mat`` file in the signal folder is expected to hold the arrays
    ``SP``, ``DP`` and ``signal``. Files whose name contains ``fold_<val_idx>``
    form the evaluation split; all other files are stacked into the training
    split.

    Args:
        val_idx: index of the fold held out for evaluation.
        mabp: when True only files whose name contains ``"mabp"`` are used,
            when False only files that do not contain it.

    Returns:
        ``(train_data, eval_data)``: float tensors whose columns are
        ``[SP, DP, signal...]`` (the three arrays stacked horizontally).

    Raises:
        ValueError: if no training file or no evaluation file was found.

    Notes:
        * The folder is a path relative to the current working directory.
        * Files are visited in sorted order. ``os.listdir`` returns entries in
          arbitrary order, which would make the training row order (and so
          every seeded shuffle) vary between machines.
        * The fold number is matched as a whole number, so ``fold_1`` does not
          also select ``fold_10``.
        * If several files match the evaluation fold they are stacked instead
          of the last one silently replacing the others.
    """
    import re  # local import keeps this function self-contained

    SIGNAL_FOLDS_PATH = os.path.join("drive", "MyDrive", "datasets", "bcg_dataset", "signal_data")
    eval_pattern = re.compile(rf"fold_{re.escape(str(val_idx))}(?!\d)")

    keys = ("SP", "DP", "signal")
    train_parts = {key: [] for key in keys}
    eval_parts = {key: [] for key in keys}

    for file_name in sorted(os.listdir(SIGNAL_FOLDS_PATH)):
        if not file_name.endswith(".mat"):
            continue
        if ("mabp" in file_name) != mabp:
            continue

        raw = sio.loadmat(os.path.join(SIGNAL_FOLDS_PATH, file_name))
        bucket = eval_parts if eval_pattern.search(file_name) else train_parts
        for key in keys:
            bucket[key].append(torch.from_numpy(raw[key]).float())

    if not train_parts["SP"] or not eval_parts["SP"]:
        raise ValueError("Missing training or evaluation data")

    def _assemble(parts):
        # One concatenation per array instead of re-stacking after every file.
        return torch.hstack([torch.vstack(parts[key]) for key in keys])

    train_data = _assemble(train_parts)
    eval_data = _assemble(eval_parts)
    return train_data, eval_data

In [None]:
class PPGDataset(Dataset):
    """Map-style dataset over a 2-D tensor laid out as ``[target0, target1, features...]``.

    The first two columns are exposed as the target ``y`` and every remaining
    column as the input ``X``.
    """

    def __init__(self, data: torch.Tensor):
        self.data = data
        # Column split: targets first, signal afterwards.
        self.y, self.X = data[:, :2], data[:, 2:]

    def __len__(self):
        # len() of a tensor is the size of its first dimension (rows).
        return len(self.data)

    def __getitem__(self, idx):
        features = self.X[idx]
        targets = self.y[idx]
        return features, targets

class PPGDatasetExtremeAug(Dataset):
    """Dataset that oversamples rows whose targets lie in the distribution tails.

    ``data`` is laid out as ``[SP, DP, features...]``. A row is "extreme" when
    its SP *or* its DP value is at or below the ``extreme_quantile`` quantile,
    or at or above the ``1 - extreme_quantile`` quantile, of that column.
    Every extreme row appears ``1 + n_aug_per_extreme`` times in the dataset;
    every other row appears once.

    NOTE(review): despite the name, the extra copies are exact duplicates
    (``__getitem__`` only clones the features, no perturbation is applied).
    Confirm whether a signal augmentation was meant to be added here.

    Attributes:
        extreme_idx: sorted row indices of the extreme samples.
        normal_idx: sorted row indices of all remaining samples.
        indices: lookup table used by ``__getitem__``. An ``int`` entry is a
            plain row; a ``(row, copy_number)`` tuple is an oversampled copy.
    """

    def __init__(
        self,
        data: torch.Tensor,
        extreme_quantile: float = 0.05,
        n_aug_per_extreme: int = 5
    ):
        self.data = data
        self.X    = data[:, 2:]
        self.y    = data[:, :2]
        self.n_aug = n_aug_per_extreme

        # Vectorised tail detection (no per-element Python loop over tensors).
        extreme_mask = (
            self._tail_mask(self.y[:, 0], extreme_quantile)
            | self._tail_mask(self.y[:, 1], extreme_quantile)
        )
        # nonzero() yields ascending indices, so the order is deterministic.
        self.extreme_idx = torch.nonzero(extreme_mask).flatten().tolist()
        self.normal_idx = torch.nonzero(~extreme_mask).flatten().tolist()

        self.indices = list(self.normal_idx)
        for row in self.extreme_idx:
            self.indices.append(row)
            self.indices.extend((row, copy_no) for copy_no in range(self.n_aug))

    @staticmethod
    def _tail_mask(values: torch.Tensor, q: float) -> torch.Tensor:
        """Boolean mask of entries at/below the q-quantile or at/above the (1-q)-quantile."""
        low = torch.quantile(values, q)
        high = torch.quantile(values, 1 - q)
        return (values <= low) | (values >= high)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        key = self.indices[idx]
        if isinstance(key, int):
            return self.X[key], self.y[key]
        # Oversampled copy: same sample, features cloned so callers may modify them.
        orig_i, _copy_no = key
        x = self.X[orig_i].clone()
        y = self.y[orig_i]
        return x, y

In [None]:
class Block(nn.Module):
    """Spiking residual block unrolled over ``n_steps`` time steps.

    For each step the main path is BatchNorm -> LIF -> Conv1d(k=3), applied
    twice. A shortcut (identity, or a 1x1 strided convolution when the shape
    changes) is added to the result, followed by dropout. The two LIF layers
    and the convolutions are shared across time steps; the batch-norm layers
    are separate per time step.

    Args:
        n_steps: number of time steps processed in ``forward``.
        in_channels: channels of the incoming tensors.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (and of the shortcut projection).
        config: needs the keys ``"spike_grad"`` and ``"dropout"``.
    """

    def __init__(
            self,
            n_steps: int,
            in_channels: int,
            out_channels: int,
            stride: int,
            config: dict
    ):
        super().__init__()
        self.n_steps = n_steps
        self.spike_grad = config["spike_grad"]

        def step_norms(channels):
            # one independent BatchNorm per time step
            return nn.ModuleList([nn.BatchNorm1d(channels) for _ in range(n_steps)])

        def leaky():
            return snn.Leaky(beta=0.95, learn_beta=True, learn_threshold=True, spike_grad=self.spike_grad)

        # first half of the main path
        self.bn1   = step_norms(in_channels)
        self.lif1  = leaky()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

        # second half of the main path
        self.bn2   = step_norms(out_channels)
        self.lif2  = leaky()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # shortcut: a projection is only needed when length or width changes
        shape_changes = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)
            if shape_changes else None
        )

        self.dropout = nn.Dropout(config["dropout"])

    def forward(self, dati):
        """Run the block over the whole sequence.

        Args:
            dati: ``[x_seq, mem1, mem2]``: the input indexable by time step
                plus the current membrane states of the two LIF layers.

        Returns:
            ``(out_seq, mem1, mem2)``: list of ``n_steps`` outputs and the
            membrane states after the last step.
        """
        x_seq, mem1, mem2 = dati
        shortcut_proj = self.downsample
        outputs = []

        for step in range(self.n_steps):
            inp = x_seq[step]
            shortcut = inp if shortcut_proj is None else shortcut_proj(inp)

            spikes, mem1 = self.lif1(self.bn1[step](inp), mem1)
            hidden = self.conv1(spikes)

            spikes, mem2 = self.lif2(self.bn2[step](hidden), mem2)
            hidden = self.conv2(spikes)

            # main path and shortcut have matching shapes at this point
            outputs.append(self.dropout(hidden + shortcut))

        return outputs, mem1, mem2

In [None]:
class SpikingDABlock(nn.Module):
    """Fuse several spike sequences into one by projecting, spiking and averaging.

    Each input stream gets its own 1x1 strided convolution, its own set of
    per-time-step batch norms and its own LIF layer. At every time step the
    spikes of all streams are averaged and passed through dropout.
    """

    def __init__(
        self,
        n_steps: int,
        in_channels_list: list[int],
        out_channels: int,
        downsample_strides: list[int],
        config: dict
    ):
        """
        Args:
            n_steps: number of time steps in every input sequence
            in_channels_list: channel count of each input stream
            out_channels: channel count of the fused output
            downsample_strides: stride of the 1x1 convolution of each stream
            config: needs the keys "spike_grad" and "dropout"
        """
        super().__init__()
        assert len(in_channels_list) == len(downsample_strides), \
            "in_channels_list and downsample_strides must match in length"

        self.n_steps    = n_steps
        self.num_inputs = len(in_channels_list)
        self.spike_grad = config["spike_grad"]

        self.transforms = nn.ModuleList()
        self.bns        = nn.ModuleList()
        self.lifs       = nn.ModuleList()

        self.dropout = nn.Dropout(config["dropout"])

        # Per stream: strided 1x1 projection, per-step batch norms, one LIF.
        for stream_channels, stream_stride in zip(in_channels_list, downsample_strides):
            projection = nn.Conv1d(stream_channels, out_channels, kernel_size=1, stride=stream_stride)
            self.transforms.append(projection)
            step_norms = nn.ModuleList([nn.BatchNorm1d(out_channels) for _ in range(n_steps)])
            self.bns.append(step_norms)
            self.lifs.append(snn.Leaky(beta=0.95, learn_beta=True, spike_grad=self.spike_grad))

    def forward(self, x_seq_list: list[list[torch.Tensor]], mem_list: list[torch.Tensor]):
        """
        Args:
            x_seq_list: one sequence (indexable by time step) per input stream
            mem_list: one membrane state per input stream; updated in place
        Returns:
            out_seq: list with one fused tensor per time step
            mem_list: the same list object, holding the final membrane states
        """
        n_streams = len(x_seq_list)
        fused_seq = []

        for step in range(self.n_steps):
            spike_total = 0
            for stream in range(n_streams):
                projected = self.transforms[stream](x_seq_list[stream][step])
                normed = self.bns[stream][step](projected)
                spikes, mem_list[stream] = self.lifs[stream](normed, mem_list[stream])
                spike_total = spike_total + spikes

            # average over the streams, then regularise
            fused_seq.append(self.dropout(spike_total / n_streams))

        return fused_seq, mem_list

In [None]:
class SpikingResNet(nn.Module):
    """Spiking 1-D ResNet with dense-aggregation (DA) fusion and a two-output regression head.

    Pipeline (every stage is unrolled over ``n_steps`` time steps):

    1. Stem: Conv1d(k=7, s=2) -> per-step BatchNorm -> LIF.
    2. Four residual stages (64, 128, 256, 512 channels), built from ``block``.
    3. After each stage a ``dablock`` fuses the stem output, the earlier
       fused outputs and the current stage output; the fusion result is
       added back to form the input of the next stage.
    4. Head: global average pool -> dropout -> two ``Linear(512, 1)`` layers,
       each feeding a non-resetting LIF whose membrane potential is the output.

    Args:
        n_steps: number of simulation time steps.
        block: residual block class, called as
            ``block(n_steps, in_channels, out_channels, stride, config)``.
        dablock: fusion block class, called as
            ``dablock(n_steps, in_channels_list, out_channels, downsample_strides=..., config=...)``.
        layers: number of residual blocks in each of the four stages.
        signal_channels: channel count of the input signal.
        num_classes: accepted for interface compatibility but not used; the
            head always produces exactly two outputs.
        config: needs ``"spike_grad"``, ``"encoding"`` and ``"dropout"``.
            Optional ``"stem_len"`` (default 313) sets the length of the
            stem LIF's per-position ``beta`` vector.
    """

    def __init__(
        self,
        n_steps: int,
        block: nn.Module,
        dablock: nn.Module,
        layers: list[int],
        signal_channels: int,
        num_classes: int,
        config: dict
    ):
        super().__init__()
        self.n_steps     = n_steps
        self.in_channels = 64
        self.block_cls   = block
        self.dablock_cls = dablock

        self.spike_grad = config["spike_grad"]
        self.encoding = config["encoding"]

        # Single dropout module, used by the classifier head.
        self.dropout = nn.Dropout(config["dropout"])

        # Stem
        # One random initial decay per stem output position.
        # NOTE(review): presumably this has to equal the temporal length after
        # conv0 so that it broadcasts against the stem activations (313 would
        # match an input of 625/626 samples) -- confirm against the dataset.
        stem_len = config.get("stem_len", 313)
        beta0 = torch.rand(stem_len)

        self.conv0 = nn.Conv1d(signal_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bns0   = nn.ModuleList([nn.BatchNorm1d(64) for _ in range(self.n_steps)])
        self.lif0  = snn.Leaky(beta=beta0, learn_beta=True, learn_threshold=True, spike_grad=self.spike_grad)

        # Residual stages
        self.stage1 = self._make_stage(layers[0],  64, stride=1, config=config)
        self.stage2 = self._make_stage(layers[1],  128, stride=2, config=config)
        self.stage3 = self._make_stage(layers[2], 256, stride=2, config=config)
        self.stage4 = self._make_stage(layers[3], 512, stride=2, config=config)

        # DA fusion blocks; the strides bring every input to the resolution
        # of the stage whose output is being fused.
        self.da12   = self.dablock_cls(n_steps, [64, 64],                64, downsample_strides=[1, 1],          config=config)
        self.da123  = self.dablock_cls(n_steps, [64, 64, 128],            128, downsample_strides=[2, 2, 1],       config=config)
        self.da1234 = self.dablock_cls(n_steps, [64, 64, 128, 256],      256, downsample_strides=[4, 4, 2, 1],    config=config)
        self.da_all = self.dablock_cls(n_steps, [64, 64, 128, 256, 512], 512, downsample_strides=[8, 8, 4, 2, 1], config=config)

        # Classifier head: random initial decay / threshold for the two readout neurons
        beta_out_1   = torch.rand(1)
        thr_out_1    = torch.rand(1)
        beta_out_2   = torch.rand(1)
        thr_out_2    = torch.rand(1)

        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # reset_mechanism="none": the membrane is never reset, so it can be
        # read out as a continuous regression value.
        self.fc_out_1   = nn.Linear(512, 1)
        self.lif_out_1 = snn.Leaky(beta=beta_out_1, threshold=thr_out_1, learn_beta=True, learn_threshold=True, spike_grad=self.spike_grad, reset_mechanism="none")

        self.fc_out_2   = nn.Linear(512, 1)
        self.lif_out_2 = snn.Leaky(beta=beta_out_2, threshold=thr_out_2, learn_beta=True, learn_threshold=True, spike_grad=self.spike_grad, reset_mechanism="none")

    @staticmethod
    def _run_stage(stage, seq, mems):
        """Feed ``seq`` through every block of ``stage``, updating ``mems`` in place."""
        for idx, blk in enumerate(stage):
            m1, m2 = mems[idx]
            seq, m1, m2 = blk([seq, m1, m2])
            mems[idx] = [m1, m2]
        return seq

    def forward(self, x):
        """Simulate the network for ``n_steps`` steps.

        Args:
            x: when ``config["encoding"]`` is true, a tensor indexed by time
                step on its first axis (``x[t]`` is the input of step ``t``);
                otherwise the same tensor is fed at every step.

        Returns:
            Tensor with the membrane potentials of the two readout neurons for
            every time step, stacked on axis 0 (time) and concatenated on the
            last axis (output 1, output 2).
        """
        # helper to reset membrane states
        def reset(lif): return lif.reset_mem()

        # reset every LIF layer and keep the fresh membrane states
        mem_stem    = reset(self.lif0)
        mem_stage1, mem_stage2, mem_stage3, mem_stage4 = [
            [[reset(blk.lif1), reset(blk.lif2)] for blk in stage]
            for stage in (self.stage1, self.stage2, self.stage3, self.stage4)
        ]
        mem_da12   = [reset(l) for l in self.da12.lifs]
        mem_da123  = [reset(l) for l in self.da123.lifs]
        mem_da1234 = [reset(l) for l in self.da1234.lifs]
        mem_da_all = [reset(l) for l in self.da_all.lifs]

        mem_cls_1  = reset(self.lif_out_1)
        mem_cls_2  = reset(self.lif_out_2)

        # f0 / stem
        stem_seq  = []

        for t in range(self.n_steps):
            if self.encoding:
                x_step = x[t]
            else:
                x_step = x

            cur_stem = self.bns0[t](self.conv0(x_step))
            spk, mem_stem = self.lif0(cur_stem, mem_stem)
            stem_seq.append(spk)

        # stage 1
        seq = self._run_stage(self.stage1, stem_seq, mem_stage1)
        out_seq1_da, mem_da12 = self.da12([stem_seq, seq], mem_da12)
        # NOTE(review): the fusion result is added to the *stem* output here,
        # whereas stages 2-4 add it to the stage output (`seq`). Confirm this
        # asymmetry is intentional.
        out_seq1p = [f0 + d for f0, d in zip(stem_seq, out_seq1_da)]

        # stage 2
        seq = self._run_stage(self.stage2, out_seq1p, mem_stage2)
        out_seq2_da, mem_da123 = self.da123([stem_seq, out_seq1p, seq], mem_da123)
        out_seq2p = [f2 + d for f2, d in zip(seq, out_seq2_da)]

        # stage 3
        seq = self._run_stage(self.stage3, out_seq2p, mem_stage3)
        out_seq3_da, mem_da1234 = self.da1234([
            stem_seq, out_seq1p, out_seq2p, seq
        ], mem_da1234)
        out_seq3p = [f3 + d for f3, d in zip(seq, out_seq3_da)]

        # stage 4
        seq = self._run_stage(self.stage4, out_seq3p, mem_stage4)
        out_seq4_da, mem_da_all = self.da_all([
            stem_seq, out_seq1p, out_seq2p, out_seq3p, seq
        ], mem_da_all)
        out_seq4p = [f4 + d for f4, d in zip(seq, out_seq4_da)]

        # classifier head
        mem_seq_1 = []
        mem_seq_2 = []
        for t in range(self.n_steps):

            pooled = self.avg_pool(out_seq4p[t]).squeeze(-1)
            pooled = self.dropout(pooled)

            # only the membrane potentials are used; the spikes are discarded
            cur_cls_1 = self.fc_out_1(pooled)
            _, mem_cls_1 = self.lif_out_1(cur_cls_1, mem_cls_1)

            cur_cls_2 = self.fc_out_2(pooled)
            _, mem_cls_2 = self.lif_out_2(cur_cls_2, mem_cls_2)

            mem_seq_1.append(mem_cls_1)
            mem_seq_2.append(mem_cls_2)

        mem_seq_1 = torch.stack(mem_seq_1, dim=0)
        mem_seq_2 = torch.stack(mem_seq_2, dim=0)

        mem_seq = torch.cat([mem_seq_1, mem_seq_2], dim=2)

        # membrane potentials of both outputs for every time step
        return mem_seq

    def _make_stage(self, num_blocks, out_channels, stride, config):
        """Build one residual stage; only its first block uses ``stride``.

        Updates ``self.in_channels`` so the next stage starts from this
        stage's output width.
        """
        blocks = []
        for i in range(num_blocks):
            blk_stride = stride if i == 0 else 1

            blocks.append(self.block_cls(self.n_steps, self.in_channels, out_channels, blk_stride, config))

            self.in_channels = out_channels

        return nn.ModuleList(blocks)

In [None]:
# Architecture settings, passed to SpikingResNet by build_model().
#   num_steps  : number of simulation time steps (also the rate-coding length)
#   spike_grad : surrogate gradient given to every snn.Leaky layer
#   encoding   : True  -> inputs are rate-coded with spikegen.rate before the model
#                False -> the same scaled signal is fed at every time step
#   layers     : number of residual blocks in each of the four stages
#   dropout    : dropout probability used throughout the model
MODEL_CONFIG = {
    "num_steps": 24,
    "spike_grad": surrogate.atan(alpha=8),
    "encoding": True,
    "layers": [1, 1, 1, 1],
    "dropout": 0.3,
}

# Training settings.
#   lr is the *peak* learning rate of the OneCycleLR schedule built in run_fold().
# NOTE(review): the recorded run of this notebook stopped with a CUDA
# out-of-memory error; batch_size and num_steps are the obvious knobs to
# lower -- confirm on the target GPU.
TRAIN_CONFIG = {
    "batch_size":         256,
    "n_epochs":           100,
    "lr":                 5e-4,
    "weight_decay":       5e-5,
    "feature_scaling":    True,
    "target_scaling":     True,
    "device":             torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # augmentation & oversampling params
    # (only read by the PPGDatasetExtremeAug call that is commented out in make_dataloaders)
    "extreme_quantile":   0.05,
    "n_aug_per_extreme":  3
}

In [None]:
def scale_data(train_data, eval_data):
    """Scale features and targets, fitting every scaler on the training split only.

    Both tensors are modified in place (and also returned). Columns 0-1 are
    the targets, columns 2+ the features.

    * ``TRAIN_CONFIG["feature_scaling"]``: standardise the features; if
      ``MODEL_CONFIG["encoding"]`` is also set, min-max scale them afterwards
      with a (0, 1) range fitted on the training split.
    * ``TRAIN_CONFIG["target_scaling"]``: standardise the targets.

    Returns:
        ``(train_data, eval_data, target_scaler)`` where ``target_scaler`` is
        the fitted target ``StandardScaler`` or ``None`` when target scaling
        is disabled.
    """
    feature_cols = slice(2, None)
    target_cols = slice(None, 2)

    def fit_and_apply(scaler, cols):
        # fit on train, then overwrite the same columns of train and eval
        scaler.fit(train_data[:, cols])
        for split in (train_data, eval_data):
            split[:, cols] = torch.from_numpy(scaler.transform(split[:, cols])).float()
        return scaler

    if TRAIN_CONFIG["feature_scaling"]:
        fit_and_apply(StandardScaler(), feature_cols)
        if MODEL_CONFIG["encoding"]:
            fit_and_apply(MinMaxScaler((0,1)), feature_cols)

    target_scaler = None
    if TRAIN_CONFIG["target_scaling"]:
        target_scaler = fit_and_apply(StandardScaler(), target_cols)

    return train_data, eval_data, target_scaler

def make_dataloaders(train_data, eval_data):
    """Wrap both splits in ``PPGDataset`` and build their DataLoaders.

    The training loader shuffles, the validation loader does not. Both use
    ``TRAIN_CONFIG["batch_size"]``, ``seed_worker`` and the shared generator ``g``.

    The oversampling dataset is currently not used; to enable it build the
    training set as ``PPGDatasetExtremeAug(data=train_data,
    extreme_quantile=TRAIN_CONFIG["extreme_quantile"],
    n_aug_per_extreme=TRAIN_CONFIG["n_aug_per_extreme"])``.

    Returns:
        ``(train_loader, val_loader)``
    """
    train_ds = PPGDataset(train_data)
    val_ds = PPGDataset(eval_data)

    # options common to both loaders
    common = dict(
        batch_size=TRAIN_CONFIG["batch_size"],
        worker_init_fn=seed_worker,
        generator=g,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    return train_loader, val_loader

# Model
def build_model():
    """Create a ``SpikingResNet`` from ``MODEL_CONFIG`` and move it to the training device."""
    network = SpikingResNet(
        n_steps=MODEL_CONFIG["num_steps"],
        block=Block,
        dablock=SpikingDABlock,
        layers=MODEL_CONFIG["layers"],
        signal_channels=1,
        num_classes=2,
        config=MODEL_CONFIG,
    )
    return network.to(TRAIN_CONFIG["device"])

# Training / Validation Epochs
def train_one_epoch(model, loader, loss_fn, optimizer, scheduler, targ_scl):
    """Train ``model`` for one pass over ``loader``.

    The model output at the last time step (``model(X)[-1]``) is compared
    with the targets. The scheduler is stepped once per batch.

    Args:
        model: network returning one prediction per time step.
        loader: DataLoader yielding ``(features, targets)`` batches.
        loss_fn: loss applied to ``(prediction, target)``.
        optimizer: optimiser updating ``model``.
        scheduler: per-batch learning-rate scheduler.
        targ_scl: fitted target scaler used to report the MAE in original
            units, or ``None`` if the targets are unscaled.

    Returns:
        ``(mean_loss, sp_mae, dp_mae)`` averaged over all samples of the
        dataset; the two MAE values refer to target columns 0 and 1.
    """
    model.train()
    total_loss = sp_mae = dp_mae = 0.0
    for Xb, yb in loader:
        X = Xb.to(TRAIN_CONFIG["device"]).float()
        y = yb.to(TRAIN_CONFIG["device"]).float()
        # Batch size must come from the targets: once the input is rate-coded
        # its first axis is the time axis (num_steps), not the batch.
        bsz = y.size(0)
        if MODEL_CONFIG["encoding"]:
            X = spikegen.rate(data=X.unsqueeze(1), num_steps=MODEL_CONFIG["num_steps"])
        else:
            X = X.unsqueeze(1)
        optimizer.zero_grad()
        yp = model(X)[-1]
        loss = loss_fn(yp, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # weight the batch-mean loss by the batch size so the epoch value is a per-sample mean
        total_loss += loss.item() * bsz
        if targ_scl is not None:
            # report errors in the original (unscaled) target units
            yt = torch.from_numpy(targ_scl.inverse_transform(y.cpu().numpy())).to(y.device)
            yh = torch.from_numpy(targ_scl.inverse_transform(yp.detach().cpu().numpy())).to(y.device)
        else:
            yt, yh = y, yp.detach()
        sp_mae += torch.abs(yh[:,0] - yt[:,0]).sum().item()
        dp_mae += torch.abs(yh[:,1] - yt[:,1]).sum().item()
    n = len(loader.dataset)
    return total_loss/n, sp_mae/n, dp_mae/n

def validate_one_epoch(model, loader, loss_fn, targ_scl):
    """Evaluate ``model`` on ``loader`` without updating it.

    Uses the model output at the last time step (``model(X)[-1]``).

    Args:
        model: network returning one prediction per time step.
        loader: DataLoader yielding ``(features, targets)`` batches.
        loss_fn: loss applied to ``(prediction, target)``.
        targ_scl: fitted target scaler used to report the MAE in original
            units, or ``None`` if the targets are unscaled.

    Returns:
        ``(mean_loss, sp_mae, dp_mae)`` averaged over all samples of the
        dataset; the two MAE values refer to target columns 0 and 1.
    """
    model.eval()
    total_loss = sp_mae = dp_mae = 0.0
    with torch.no_grad():
        for Xb, yb in loader:
            X = Xb.to(TRAIN_CONFIG["device"]).float()
            y = yb.to(TRAIN_CONFIG["device"]).float()
            # Batch size must come from the targets: once the input is
            # rate-coded its first axis is the time axis, not the batch.
            bsz = y.size(0)
            if MODEL_CONFIG["encoding"]:
                X = spikegen.rate(data=X.unsqueeze(1), num_steps=MODEL_CONFIG["num_steps"])
            else:
                X = X.unsqueeze(1)
            yp = model(X)[-1]
            total_loss += loss_fn(yp, y).item() * bsz
            if targ_scl is not None:
                # report errors in the original (unscaled) target units
                yt = torch.from_numpy(targ_scl.inverse_transform(y.cpu().numpy())).to(y.device)
                yh = torch.from_numpy(targ_scl.inverse_transform(yp.cpu().numpy())).to(y.device)
            else:
                yt, yh = y, yp
            sp_mae += torch.abs(yh[:,0] - yt[:,0]).sum().item()
            dp_mae += torch.abs(yh[:,1] - yt[:,1]).sum().item()
    n = len(loader.dataset)
    return total_loss/n, sp_mae/n, dp_mae/n

def predict_on_loader(model, loader, targ_scl=None):
    """Collect targets and last-time-step predictions for every sample of ``loader``.

    Args:
        model: network returning one prediction per time step.
        loader: DataLoader yielding ``(features, targets)`` batches.
        targ_scl: optional target scaler; when given, targets and predictions
            are mapped back through ``inverse_transform``.

    Returns:
        ``(y_true, y_pred)``: two NumPy arrays with one row per sample.
    """
    model.eval()
    trues = []
    preds = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            inputs = batch_x.to(TRAIN_CONFIG["device"]).float().unsqueeze(1)
            targets = batch_y.to(TRAIN_CONFIG["device"]).float()
            if MODEL_CONFIG["encoding"]:
                # rate-code the scaled signal into num_steps spike frames
                inputs = spikegen.rate(data=inputs, num_steps=MODEL_CONFIG["num_steps"])

            final_out = model(inputs)[-1]

            true_np = targets.cpu().numpy()
            pred_np = final_out.cpu().numpy()
            if targ_scl:
                true_np = targ_scl.inverse_transform(true_np)
                pred_np = targ_scl.inverse_transform(pred_np)

            trues.append(true_np)
            preds.append(pred_np)

    return np.vstack(trues), np.vstack(preds)

# Fold / K-Fold Runner
def run_fold(val_idx, fold_num=None):
    """Train and evaluate one cross-validation fold.

    Loads and scales the data with fold ``val_idx`` held out, trains a fresh
    model for ``TRAIN_CONFIG["n_epochs"]`` epochs with AdamW and a one-cycle
    learning-rate schedule, saves the weights and predicts on the held-out fold.

    No early stopping or best-epoch selection is active: the file
    ``best_model_fold{fold_num}.pt`` holds the weights after the *last*
    epoch, and the returned MAE values are those of the last epoch.

    Args:
        val_idx: index of the fold used for validation.
        fold_num: label used in log lines and in the checkpoint file name.

    Returns:
        ``(val_sp_mae, val_dp_mae, metrics, y_trues, y_preds)`` where
        ``metrics`` maps ``train_/val_`` + ``loss/sp/dp`` to per-epoch lists.
    """
    if fold_num is not None:
        print(f"\n=== Starting fold {fold_num} ===")

    # ---- data -------------------------------------------------------------
    raw_train, raw_eval = get_data_splits(val_idx=val_idx, mabp=False)
    train_data, eval_data, targ_scl = scale_data(raw_train, raw_eval)
    train_loader, val_loader = make_dataloaders(train_data, eval_data)

    # ---- model / optimisation --------------------------------------------
    model = build_model()
    loss_fn = nn.L1Loss()

    peak_lr = TRAIN_CONFIG["lr"]
    optimizer = optim.AdamW(
        model.parameters(),
        lr=peak_lr / 50.0,
        weight_decay=TRAIN_CONFIG["weight_decay"],
    )
    # the scheduler is stepped once per batch, hence epochs * batches steps
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=peak_lr,
        total_steps=TRAIN_CONFIG["n_epochs"] * len(train_loader),
        div_factor=50,
        pct_start=0.3,
        final_div_factor=1e3,
    )

    # ---- training loop ------------------------------------------------------
    metrics = {
        name: []
        for name in ("train_loss", "val_loss", "train_sp", "val_sp", "train_dp", "val_dp")
    }
    for ep in range(1, TRAIN_CONFIG["n_epochs"] + 1):
        tr_loss, tr_sp, tr_dp = train_one_epoch(model, train_loader, loss_fn, optimizer, scheduler, targ_scl)
        vl_loss, vl_sp, vl_dp = validate_one_epoch(model, val_loader, loss_fn, targ_scl)

        if ep % 2 == 0:
            print(f"Fold {fold_num} Epoch {ep}/{TRAIN_CONFIG['n_epochs']}  Train L={tr_loss:.4f} (SP={tr_sp:.3f}, DP={tr_dp:.3f})  Val  L={vl_loss:.4f} (SP={vl_sp:.3f}, DP={vl_dp:.3f})")

        epoch_values = (tr_loss, vl_loss, tr_sp, vl_sp, tr_dp, vl_dp)
        for name, value in zip(metrics, epoch_values):
            metrics[name].append(value)

    # ---- persist and predict ---------------------------------------------------
    torch.save(model.state_dict(), f"best_model_fold{fold_num}.pt")
    y_trues, y_preds = predict_on_loader(model, val_loader, targ_scl=targ_scl)

    return vl_sp, vl_dp, metrics, y_trues, y_preds

def run_kfold(n_folds=5):
    """Run ``run_fold`` for folds ``0 .. n_folds-1`` and summarise the results.

    Prints the per-fold and the mean +/- sample-standard-deviation MAE, writes
    the per-epoch curves of all folds to ``learning_curves.json`` (keyed by the
    0-based fold index) and pools the validation predictions of every fold.

    Returns:
        ``(mean_sbp, std_sbp, mean_dbp, std_dbp, y_true_all, y_pred_all)``
    """
    sbp_maes = []
    dbp_maes = []
    logs = {}
    pooled_true = []
    pooled_pred = []

    for fold_idx in range(n_folds):
        sp_mae, dp_mae, curves, fold_true, fold_pred = run_fold(val_idx=fold_idx, fold_num=fold_idx+1 )
        print(f"→ Fold {fold_idx+1} result: "
        f"SBP MAE={sp_mae:.3f}, DBP MAE={dp_mae:.3f}")

        sbp_maes.append(sp_mae)
        dbp_maes.append(dp_mae)
        logs[fold_idx] = curves

        # kept for the pooled analysis over all folds
        pooled_true.append(fold_true)
        pooled_pred.append(fold_pred)

    # mean and sample standard deviation across folds
    mean_sbp = np.mean(sbp_maes)
    mean_dbp = np.mean(dbp_maes)
    std_sbp = np.std(sbp_maes, ddof=1)
    std_dbp = np.std(dbp_maes, ddof=1)

    print(f"\n=== K-Fold Summary ({n_folds} folds) ===")
    print(f"SBP MAE = {mean_sbp:.3f} ± {std_sbp:.3f}  •  "
    f"DBP MAE = {mean_dbp:.3f} ± {std_dbp:.3f}")

    # one row per validation sample, all folds stacked
    y_true_all = np.vstack(pooled_true)
    y_pred_all = np.vstack(pooled_pred)

    with open("learning_curves.json", "w") as curves_file:
        json.dump(logs, curves_file, indent=2)

    return mean_sbp, std_sbp, mean_dbp, std_dbp, y_true_all, y_pred_all

In [None]:
# Run the full 5-fold cross-validation. The pooled truths / predictions are
# reused by the plotting cells below, and mean_sbp / mean_dbp by the final
# metrics dump.
mean_sbp, std_sbp, mean_dbp, std_dbp, y_true_all, y_pred_all = run_kfold(n_folds=5)
print(f"Overall SBP MAE = {mean_sbp:.3f} ± {std_sbp:.3f}")
print(f"Overall DBP MAE = {mean_dbp:.3f} ± {std_dbp:.3f}")


=== Starting fold 1 ===


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 22.16 GiB of which 7.38 MiB is free. Process 6779 has 22.15 GiB memory in use. Of the allocated memory 21.93 GiB is allocated by PyTorch, and 5.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Reload the per-fold learning curves written by run_kfold()
with open("learning_curves.json", "r") as f:
    logs = json.load(f)

# metric suffix -> (y-axis label, figure title)
metrics = {
    "loss": ("Loss", "Training and Validation Loss"),
    "sp": ("SBP MAE", "SBP Mean Absolute Error"),
    "dp": ("DBP MAE", "DBP Mean Absolute Error")
}

# One figure per metric, one panel per fold
n_folds = len(logs)
for key, (ylabel, title) in metrics.items():
    # squeeze=False always returns a 2-D axes array, so a single fold needs no special case
    fig, axs = plt.subplots(1, n_folds, figsize=(5 * n_folds, 4), sharey=True, squeeze=False)
    fig.suptitle(title, fontsize=16)

    for ax, (fold, data) in zip(axs[0], logs.items()):
        ax.plot(data[f"train_{key}"], label="Train", color="blue", linewidth=2)
        ax.plot(data[f"val_{key}"], label="Val", color="orange", linestyle="--", linewidth=2)
        ax.set_title(f"Fold {fold}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # leave room at the top for the figure title
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

def plot_error_histograms(y_true, y_pred):
    """Draw two side-by-side histograms of the absolute error: SBP (column 0) and DBP (column 1)."""
    abs_err = np.abs(y_pred - y_true)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    for col, name in enumerate(("SBP", "DBP")):
        panel = axes[col]
        panel.hist(abs_err[:, col], bins=50, alpha=0.7)
        panel.set_title(f"{name} Absolute Error Histogram")
        panel.set_xlabel("Absolute Error (mmHg)")
        panel.set_ylabel("Count")

    plt.tight_layout()
    plt.show()


def plot_bland_altman(y_true, y_pred):
    """Bland–Altman plots (one figure each for SBP and DBP).

    Plots the signed error (predicted − true) against the mean of prediction
    and truth, with lines at the mean error and at ±1.96 sample standard
    deviations.
    """
    diff = y_pred - y_true
    mean_of_pair = (y_pred + y_true) / 2

    for col, label in enumerate(["SBP", "DBP"]):
        err = diff[:, col]
        bias = np.mean(err)
        spread = 1.96 * np.std(err, ddof=1)
        upper = bias + spread
        lower = bias - spread

        plt.figure(figsize=(6, 4))
        plt.scatter(mean_of_pair[:, col], err, alpha=0.5)
        plt.axhline(bias, linestyle='-', label=f"Mean Error = {bias:.2f}")
        plt.axhline(upper, linestyle='--', label=f"+1.96 SD = {upper:.2f}")
        plt.axhline(lower, linestyle='--', label=f"-1.96 SD = {lower:.2f}")
        plt.title(f"Bland–Altman Plot ({label})")
        plt.xlabel("Average of True & Predicted (mmHg)")
        plt.ylabel("Error (Predicted − True) (mmHg)")
        plt.legend()
        plt.tight_layout()
        plt.show()


def plot_scatter_with_r2(y_true, y_pred):
    """Predicted-vs-true scatter per target with the identity line and R² in the title."""
    for col, label in enumerate(["SBP", "DBP"]):
        actual = y_true[:, col]
        estimate = y_pred[:, col]
        score = r2_score(actual, estimate)

        # shared square range so the identity line runs corner to corner
        lowest = min(actual.min(), estimate.min())
        highest = max(actual.max(), estimate.max())
        lims = [lowest, highest]

        plt.figure(figsize=(6, 6))
        plt.scatter(actual, estimate, alpha=0.5)
        plt.plot(lims, lims, 'k--', linewidth=1)
        plt.xlim(lims)
        plt.ylim(lims)
        plt.title(f"{label}: Predicted vs True (R² = {score:.3f})")
        plt.xlabel("True (mmHg)")
        plt.ylabel("Predicted (mmHg)")
        plt.tight_layout()
        plt.show()

def plot_value_distributions(y_true, y_pred, bins=100):
    """
    Overlay the histograms of true and predicted values, one panel for SBP
    (column 0) and one for DBP (column 1).
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    series = (("True", y_true), ("Predicted", y_pred))

    for col, (panel, label) in enumerate(zip(axes, ["SBP", "DBP"])):
        for legend_name, values in series:
            panel.hist(values[:, col], bins=bins, alpha=0.6, label=legend_name, edgecolor='black')
        panel.set_title(f"{label} Value Distribution")
        panel.set_xlabel("Blood Pressure (mmHg)")
        panel.set_ylabel("Count")
        panel.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Diagnostic plots over the pooled validation predictions of all folds
# (y_true_all / y_pred_all come from the run_kfold() cell above).
plot_error_histograms(y_true_all, y_pred_all)
plot_bland_altman(y_true_all, y_pred_all)
plot_scatter_with_r2(y_true_all, y_pred_all)
plot_value_distributions(y_true_all, y_pred_all)

In [None]:
def make_serializable(d):
    """Return a shallow copy of ``d`` whose values can be passed to ``json.dump``.

    * ``torch.device`` values become their string form.
    * Callables, and anything that is not an int / float / bool / str / list /
      dict (surrogate-gradient objects, tuples, ``None`` ...), become ``repr(value)``.
    * Everything else is kept unchanged (nested containers are not inspected).
    """
    passthrough_types = (int, float, bool, str, list, dict)

    def to_plain(value):
        if isinstance(value, torch.device):
            return str(value)
        if callable(value) or not isinstance(value, passthrough_types):
            return repr(value)
        return value

    return {name: to_plain(value) for name, value in d.items()}

# JSON-friendly copies of both configs (devices / surrogate objects become strings)
model_cfg_ser = make_serializable(MODEL_CONFIG)
train_cfg_ser = make_serializable(TRAIN_CONFIG)

# Append a human-readable run summary: the headline MAE followed by both configs.
# NOTE(review): despite the .json extension the file mixes free text with JSON
# fragments and is opened in append mode, so it is not parseable as JSON --
# confirm whether a plain-text log (or one JSON record per run) was intended.
with open("final_metrics.json","a") as f:
    f.write(f"\nFinal K-Fold SBP MAE: {mean_sbp:.4f}, DBP MAE: {mean_dbp:.4f}\n")
    f.write("\nModel Configuration:\n")
    json.dump(model_cfg_ser, f, indent=2); f.write("\n")
    f.write("\nTraining Configuration:\n")
    json.dump(train_cfg_ser, f, indent=2)