In [None]:
!pip install -q wfdb torchinfo

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/91.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m148.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.[0m[31m
[0m

In [None]:
import os
import math
import random
from collections import defaultdict, Counter
from typing import List, Optional

import h5py
import numpy as np
import pandas as pd
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from scipy.signal import butter, filtfilt, find_peaks
from sklearn.model_selection import train_test_split
import wfdb

In [None]:
# Make Google Drive visible inside the Colab VM so the zipped dataset can be read.
import google.colab.drive as drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import zipfile

# Unpack the CinC 2020 archive from Drive onto the Colab VM's local disk.
# Extraction is slow, so a sentinel file records a completed extraction and
# lets "Restart & Run All" skip this step on later runs in the same VM.
zip_path = '/content/drive/MyDrive/CINC2020.zip'
extract_dir = '/content/'
_extract_sentinel = os.path.join(extract_dir, '.cinc2020_extracted')

os.makedirs(extract_dir, exist_ok=True)

if os.path.exists(_extract_sentinel):
    print("Archive already extracted; skipping:", zip_path)
else:
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    # Written only after extractall returns, so an interrupted extraction is redone.
    with open(_extract_sentinel, 'w') as fh:
        fh.write(zip_path + '\n')

In [None]:
# -------------------------
# Config
# -------------------------
SAMPLING_RATE = 500   # Hz; passed as `fs` to the R-peak / RR-beat pipeline
DROP_LAST = True      # NOTE(review): not referenced by the loader helpers in this notebook
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# -------------------------
# utilities: filtering and peak detection
# -------------------------
def denoise(data, wavelet='db5', level=9):
    """Wavelet-denoise a 1-D signal.

    Decomposes ``data`` with ``pywt.wavedec``. The two finest detail bands
    (cD1, cD2) are zeroed. The remaining detail bands are soft-thresholded
    with the universal threshold
    ``median(|cD1|) / 0.6745 * sqrt(2 * ln(len(cD1)))``; the approximation
    band is left untouched. The signal is then reconstructed with
    ``pywt.waverec``.

    Args:
        data: 1-D array-like signal.
        wavelet: pywt wavelet name (default 'db5').
        level: decomposition level (default 9). Must be >= 2 because the two
            finest detail bands are zeroed.

    Returns:
        np.ndarray with exactly ``len(data)`` samples. ``pywt.waverec`` can
        return one extra sample for odd-length inputs; that sample is trimmed
        so the output stays aligned with the input.

    Raises:
        ValueError: if ``level < 2``.
    """
    if level < 2:
        raise ValueError("level must be >= 2")

    coeffs = pywt.wavedec(data=data, wavelet=wavelet, level=level)
    finest = coeffs[-1]  # cD1

    # Universal threshold, estimated from the finest band *before* it is zeroed.
    threshold = (np.median(np.abs(finest)) / 0.6745) * (np.sqrt(2 * np.log(len(finest))))

    coeffs[-1].fill(0)  # cD1
    coeffs[-2].fill(0)  # cD2
    # Soft-threshold (pywt default mode) cD_level .. cD3; coeffs[0] is the approximation.
    for i in range(1, len(coeffs) - 2):
        coeffs[i] = pywt.threshold(coeffs[i], threshold)

    rdata = pywt.waverec(coeffs=coeffs, wavelet=wavelet)
    return rdata[:len(data)]


def pan_tompkins_detector(ecg_signal, fs, lowcut=5.0, highcut=15.0,
                          window_s=0.150, min_rr_s=0.6):
    """Simplified Pan-Tompkins-style QRS detector for a single lead.

    Pipeline:
      1. 1st-order Butterworth band-pass, applied zero-phase with ``filtfilt``.
      2. First difference.
      3. Squaring.
      4. Moving-window integration.
      5. Local maxima at least ``min_rr_s`` seconds apart (``find_peaks``).

    Note: the default ``min_rr_s=0.6`` allows at most one peak per 0.6 s, so
    rhythms faster than 100 bpm (e.g. tachycardias) lose beats. Lower it (the
    original Pan-Tompkins method uses a 200 ms refractory period) when such
    rhythms matter. Returned indices are maxima of the integrated signal,
    which may be offset from the R-wave apex.

    Args:
        ecg_signal: 1-D signal for one lead.
        fs: sampling rate in Hz.
        lowcut, highcut: band-pass edges in Hz.
        window_s: moving-window integration length in seconds.
        min_rr_s: minimum spacing between returned peaks in seconds.

    Returns:
        np.ndarray of integer sample indices into the integrated signal
        (which is one sample shorter than ``ecg_signal`` because of ``np.diff``).
    """
    nyquist = 0.5 * fs
    b, a = butter(1, [lowcut / nyquist, highcut / nyquist], btype='band')
    filtered_ecg = filtfilt(b, a, ecg_signal)
    squared_ecg = np.diff(filtered_ecg) ** 2
    # Guard against zero-length windows / distances at very low sampling rates.
    window_size = max(1, int(window_s * fs))
    mwa_ecg = np.convolve(squared_ecg, np.ones(window_size) / window_size, mode='same')
    peaks, _ = find_peaks(mwa_ecg, distance=max(1, int(min_rr_s * fs)))
    return peaks


def multi_lead_fusion(detected_peaks, fs, fusion_window=0.1, min_leads=None):
    """Fuse per-lead peak detections into one sequence of peak positions.

    All (position, lead) pairs are sorted by position and grouped greedily. A
    pair joins the current group when it lies within ``fusion_window * fs``
    samples of the group's previous pair, so groups can chain. A group is
    kept when it contains peaks from at least ``min_leads`` distinct leads
    (default: ceil(n_leads / 2)). Each kept group is represented by the
    integer-truncated median of its positions.

    Args:
        detected_peaks: sequence (one entry per lead) of peak positions.
        fs: sampling rate in Hz.
        fusion_window: max gap in seconds between consecutive grouped peaks.
        min_leads: distinct-lead quorum for keeping a group.

    Returns:
        np.ndarray of fused positions, ascending.
    """
    if min_leads is None:
        min_leads = int(np.ceil(len(detected_peaks) / 2))
    max_gap = fusion_window * fs

    tagged = sorted(
        ((pos, lead_idx) for lead_idx, lead_peaks in enumerate(detected_peaks) for pos in lead_peaks),
        key=lambda pair: pair[0],
    )

    groups = []
    for pos, lead_idx in tagged:
        if groups and pos - groups[-1][-1][0] <= max_gap:
            groups[-1].append((pos, lead_idx))
        else:
            groups.append([(pos, lead_idx)])

    fused = [
        int(np.median([p for p, _ in grp]))
        for grp in groups
        if len({lead for _, lead in grp}) >= min_leads
    ]
    return np.array(sorted(fused))


def detect_r_peaks(ecg_signals, fs, fusion_window=0.1, min_leads=6):
    """Detect peaks on every lead and fuse them into one R-peak sequence.

    Args:
        ecg_signals: iterable of 1-D lead signals (e.g. an (n_leads, n_samples) array).
        fs: sampling rate in Hz.
        fusion_window: max gap in seconds between detections grouped together.
        min_leads: number of distinct leads that must agree on a beat. Clamped
            to the number of leads supplied. Without the clamp, an input with
            fewer than ``min_leads`` leads could never reach the quorum and
            would always yield no peaks.

    Returns:
        np.ndarray of fused peak sample indices (see ``multi_lead_fusion``).
    """
    detected_peaks = [pan_tompkins_detector(lead, fs) for lead in ecg_signals]
    required = min(min_leads, len(detected_peaks))
    return multi_lead_fusion(detected_peaks, fs, fusion_window=fusion_window, min_leads=required)


def extract_segments_around_peaks(signal, r_peaks, pre_samples, post_samples):
    """Cut windows ``signal[peak - pre_samples : peak + post_samples]`` around each peak.

    Bounds are clipped to the signal. Only windows whose clipped length still
    equals ``pre_samples + post_samples`` are returned, so (for non-negative
    pre/post) peaks too close to either end are skipped.

    Returns:
        list of slices of ``signal``, in peak order.
    """
    window_len = pre_samples + post_samples
    n_total = len(signal)
    bounds = ((max(0, pk - pre_samples), min(n_total, pk + post_samples)) for pk in r_peaks)
    return [signal[lo:hi] for lo, hi in bounds if hi - lo == window_len]


def extract_rr_beats_multi_lead(ecg_signals, fs, denoise_fn=None,
                                min_rr_ms=300, max_rr_ms=1500, min_beats=2):
    """
    Split a multi-lead recording into R-to-R beats.

    Args:
        ecg_signals: array-like of shape (n_leads, n_samples).
        fs: sampling rate in Hz.
        denoise_fn: optional callable applied to each lead before detection.
        min_rr_ms, max_rr_ms: inclusive bounds on the RR interval (ms); beats
            whose interval falls outside them are discarded.
        min_beats: minimum number of fused R-peaks required. N peaks give at
            most N - 1 beats.

    Returns a list of beats. Each beat is a float32 array with shape (T_i, C),
    where C = n_leads and T_i is the RR length in samples. Returns None when
    there are too few peaks or no valid beats.
    """
    ecg = np.array(ecg_signals)
    _leads, _samples = ecg.shape  # raises ValueError unless the input is 2-D

    if denoise_fn is not None:
        ecg = np.array([denoise_fn(lead) for lead in ecg])

    # Fused R-peaks, detected once across all leads.
    r_peaks = detect_r_peaks(ecg, fs)
    if len(r_peaks) < min_beats:
        return None

    rr_ms = (np.diff(r_peaks) / fs) * 1000.0
    keep = (rr_ms >= min_rr_ms) & (rr_ms <= max_rr_ms)
    if not keep.any():
        return None

    beats = []
    for k in np.flatnonzero(keep):
        seg = ecg[:, r_peaks[k]:r_peaks[k + 1]].T   # (T_k, C)
        if seg.shape[0] > 0:
            beats.append(seg.astype(np.float32))

    return beats or None

In [None]:
import os
import wfdb
import numpy as np
from collections import Counter

print("Loading and preprocessing ECG data (CINC)…")

# Root of the extracted CinC 2020 archive; every dataset path is built from it.
DATA_ROOT = "/content/classification-of-12-lead-ecgs-the-physionetcomputing-in-cardiology-challenge-2020-1.0.2"
RECORDS = os.path.join(DATA_ROOT, "RECORDS")

# Folders (as listed in RECORDS) that are skipped entirely.
EXCLUDED_FOLDERS = {"training/ptb/g1/", "training/st_petersburg_incart/g1/"}

# keep only these Dx codes (strings)
available_labels = [
    '10370003', '164889003', '164909002', '164934002',
    '270492004', '284470004', '426177001', '426783006',
    '427084000', '427393009', '59118001'
]

tmp = []              # (beats_list, label_str) per accepted record
skipped = Counter()   # reason -> number of records (or folders) dropped

# --- enumerate dataset folders from CINC/RECORDS ---
with open(RECORDS) as f:
    FOLDERS = [ln.strip() for ln in f if ln.strip()]

for folder in FOLDERS:
    if folder in EXCLUDED_FOLDERS:
        continue

    records_path = os.path.join(DATA_ROOT, folder, "RECORDS")
    if not os.path.exists(records_path):
        skipped["folder without RECORDS"] += 1
        continue

    with open(records_path) as r:
        files = [ln.strip() for ln in r if ln.strip()]

    for file_name in files:
        final_data_path = os.path.join(DATA_ROOT, folder, file_name)

        # --- header: keep only single-Dx records whose code is in available_labels ---
        try:
            hdr = wfdb.rdheader(final_data_path)
        except Exception:
            skipped["unreadable header"] += 1
            continue

        label = None
        for comment in (hdr.comments or []):
            if comment.startswith("Dx:") and "," not in comment:
                # typical: "Dx: 164889003"
                label = comment.replace("Dx:", "").strip()
                break

        if label is None or label not in available_labels:
            skipped["multi-label or unused Dx"] += 1
            continue

        # --- read signal ---
        try:
            record = wfdb.rdrecord(final_data_path)       # p_signal: (n_samples, n_leads)
        except Exception:
            skipped["unreadable signal"] += 1
            continue

        signal = record.p_signal
        if signal is None:
            skipped["no physical signal"] += 1
            continue

        # Filters and RR limits downstream are computed with SAMPLING_RATE, so a
        # record at another rate would be silently mis-timed.
        if record.fs != SAMPLING_RATE:
            skipped["sampling rate mismatch"] += 1
            continue

        # replace NaNs, transpose -> (n_leads, n_samples), float32
        signal = np.nan_to_num(signal, nan=0.0).T.astype(np.float32)

        # multi-lead denoising (one wavelet pass per lead)
        denoised = np.array([denoise(lead) for lead in signal], dtype=np.float32)
        # Check *before* sanitising: checking after nan_to_num can never fire.
        if np.isnan(denoised).any():
            print("WARNING: denoised signal contains NaNs for", final_data_path)
        denoised = np.nan_to_num(denoised, nan=0.0)

        # --- extract variable-length R–R beats across all leads ---
        beats = extract_rr_beats_multi_lead(
            denoised,               # (n_leads, n_samples)
            fs=SAMPLING_RATE,
            denoise_fn=None         # already denoised above
        )

        # beats is None or a Python list of (T_i, C) arrays
        if not beats:
            skipped["too few valid beats"] += 1
            continue

        tmp.append((beats, label))

# --- build label -> index mapping from labels actually present ---
unique_labels = sorted({lbl for (_, lbl) in tmp})
label2idx = {lbl: i for i, lbl in enumerate(unique_labels)}
num_classes = len(unique_labels)
print("num_classes:", num_classes, "labels:", unique_labels)

# --- final data_list: (beats_list, label_idx) ---
# Every label in tmp is in label2idx by construction, so no filtering is needed.
data_list = [(beats, label2idx[lbl]) for beats, lbl in tmp]

print("Prepared records:", len(data_list))
print("Beat counts distribution:", Counter([len(x[0]) for x in data_list]))
print("Skipped:", dict(skipped))


In [None]:
# ---------- Length-bucketing Batch Sampler ----------
class LengthBucketBatchSampler(Sampler):
    """
    Yields lists of indices (batches) where all samples in a batch share the same length-bucket.

    Args:
        lengths: list/array-like of int lengths (len(dataset) entries).
        batch_size: target batch size (must be > 0).
        bin_size: None (or <= 0) for exact-length buckets. If int > 0,
            length_key = (length // bin_size) * bin_size.
        shuffle: shuffle within buckets and shuffle batch order each epoch.
            This uses the global ``random`` module, so ``random.seed`` controls it.
        drop_last: whether to drop the last incomplete batch of each bucket.
            When False, every index is yielded exactly once per epoch, and
            ``len(sampler)`` equals the number of batches actually yielded.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    def __init__(self,
                 lengths: List[int],
                 batch_size: int,
                 bin_size: Optional[int] = None,
                 shuffle: bool = True,
                 drop_last: bool = False):
        self.lengths = list(lengths)
        self.batch_size = int(batch_size)
        if self.batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        self.bin_size = bin_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # length_key -> indices in that bucket, in dataset order
        self._buckets = defaultdict(list)
        for idx, L in enumerate(self.lengths):
            self._buckets[self._length_key(L)].append(idx)

        # Insertion-ordered keys give a stable bucket order for iteration.
        self.bucket_keys = list(self._buckets.keys())

    def _length_key(self, length: int) -> int:
        if self.bin_size is None or self.bin_size <= 0:
            return int(length)   # exact-length bucket
        return (length // self.bin_size) * self.bin_size

    def __iter__(self):
        # For each epoch, build batches from buckets.
        batches = []
        for key in self.bucket_keys:
            idxs = list(self._buckets[key])
            if self.drop_last and len(idxs) < self.batch_size:
                # Such a bucket could only form one partial batch, which would be
                # dropped anyway; skipping before shuffling leaves the RNG stream
                # unchanged. Without drop_last it is kept, matching __len__.
                continue
            if self.shuffle:
                random.shuffle(idxs)
            # chunk into batches
            for i in range(0, len(idxs), self.batch_size):
                batch = idxs[i:i + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                batches.append(batch)

        if self.shuffle:
            random.shuffle(batches)

        yield from batches

    def __len__(self):
        # Must agree with __iter__: full batches only with drop_last, else ceil per bucket.
        total = 0
        for key in self.bucket_keys:
            n = len(self._buckets[key])
            if self.drop_last:
                total += n // self.batch_size
            else:
                total += math.ceil(n / self.batch_size)
        return total

In [None]:
class ECGSegmentDatasetVarLen(Dataset):
    """
    Map-style dataset over pre-extracted beat lists.

    data_list: list of tuples (beats_list, label_idx)
      - beats_list is a Python list of arrays with shapes (T_i, C), i=1..S

    Each item is a dict with keys:
      - "beats": list of tensors
      - "label": int
      - "num_beats": S
    """
    def __init__(self, data_list):
        self.data = data_list
        # S per record; the length-bucketing sampler groups records by this.
        self.num_beats = [len(record[0]) for record in data_list]

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_tensor(beat):
        # from_numpy shares memory with the array; other inputs are copied.
        return torch.from_numpy(beat) if isinstance(beat, np.ndarray) else torch.tensor(beat)

    def __getitem__(self, idx):
        beats_list, label = self.data[idx]
        tensors = list(map(self._to_tensor, beats_list))  # [(T_i, C), ...]
        return {"beats": tensors, "label": int(label), "num_beats": len(tensors)}


def collate_by_num_beats(batch):
    """Collate items with equal S and equal T into a dense (B, S, T, C) tensor.

    Unlike ``pad_collate_varlen``, this does no padding. Each item may carry
    either of two forms:
      - a pre-stacked ``"signal"`` tensor of shape (S, T, C);
      - a ``"beats"`` list of (T, C) tensors, as ``ECGSegmentDatasetVarLen``
        yields. These are stacked here, so all beats must share the same T.

    Returns:
        dict with:
          - "signal": (B, S, T, C)
          - "label": int64 (B,)
          - "num_beats": int64 (B,)

    Raises:
        ValueError: if the batch is empty or items differ in ``num_beats``.
    """
    if len(batch) == 0:
        raise ValueError("collate_by_num_beats received an empty batch")
    s_vals = [item["num_beats"] for item in batch]
    if not all(s == s_vals[0] for s in s_vals):
        raise ValueError("collate_by_num_beats received mixed num_beats in a batch")
    # Dataset items carry "beats" rather than "signal"; stack them per item.
    per_item = [item["signal"] if "signal" in item else torch.stack(list(item["beats"]), dim=0)
                for item in batch]
    signals = torch.stack(per_item, dim=0)  # (B, S, T, C)
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    return {"signal": signals, "label": labels, "num_beats": torch.tensor(s_vals, dtype=torch.long)}


def pad_collate_varlen(batch):
    """Zero-pad variable-length beats into dense tensors.

    All items must share the same number of beats S (the length-bucketing
    sampler guarantees this). Beats may differ in length T.

    Returns a dict with:
        "signal":    float32 (B, S, T_max, C), zero-padded along T;
        "mask":      float32 (B, S, T_max), 1.0 on real samples, 0.0 on padding;
        "label":     int64 (B,);
        "num_beats": int64 (B,).
    """
    n_items = len(batch)
    label_tensor = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    beat_counts = [item["num_beats"] for item in batch]
    if any(count != beat_counts[0] for count in beat_counts):
        raise ValueError("Bucketed sampler should ensure same #beats (S) per batch.")

    n_beats = beat_counts[0]
    # C is taken from the first beat of the first item (e.g. 12 leads).
    n_channels = batch[0]["beats"][0].shape[1]
    t_max = max(beat.shape[0] for item in batch for beat in item["beats"])

    signal = torch.zeros((n_items, n_beats, t_max, n_channels), dtype=torch.float32)
    mask = torch.zeros((n_items, n_beats, t_max), dtype=torch.float32)

    for i, item in enumerate(batch):
        for j, beat in enumerate(item["beats"]):
            length = beat.shape[0]
            signal[i, j, :length] = beat
            mask[i, j, :length] = 1.0

    return {"signal": signal, "mask": mask, "label": label_tensor,
            "num_beats": torch.tensor(beat_counts, dtype=torch.long)}



In [None]:
from torch.utils.data import Subset, DataLoader

# Dataset over every prepared record (beat list + label index).
ds = ECGSegmentDatasetVarLen(data_list)

# -------------------------
# Stratified split: 10% test, then 10% of the remainder as validation
# -------------------------
all_indices = list(range(len(ds)))
all_labels = [record[1] for record in data_list]

trainval_indices, test_indices = train_test_split(
    all_indices, test_size=0.1, random_state=10, stratify=all_labels)

trainval_labels = [all_labels[i] for i in trainval_indices]
train_indices, val_indices = train_test_split(
    trainval_indices, test_size=0.1, random_state=10, stratify=trainval_labels)

# -------------------------
# Subsets sharing the same underlying dataset
# -------------------------
train_ds, val_ds, test_ds = (Subset(ds, idx) for idx in (train_indices, val_indices, test_indices))

def make_loader(subset_ds, batch_size=16, drop_last=None, shuffle=True):
    """Build a DataLoader whose batches share the same number of beats S.

    Args:
        subset_ds: a ``torch.utils.data.Subset`` whose ``.dataset`` exposes a
            ``num_beats`` list (as ``ECGSegmentDatasetVarLen`` does).
        batch_size: target batch size.
        drop_last: drop incomplete batches. ``None`` (the default) means
            "drop only when shuffling": training loaders keep full batches,
            while evaluation loaders (``shuffle=False``) see every sample.
            Pass a bool to override.
        shuffle: shuffle within buckets and batch order each epoch.

    Returns:
        DataLoader yielding ``pad_collate_varlen`` dicts.
    """
    if drop_last is None:
        drop_last = shuffle
    # Read lengths from the Subset's own parent dataset rather than a global `ds`.
    parent = subset_ds.dataset
    lengths = [parent.num_beats[i] for i in subset_ds.indices]
    sampler = LengthBucketBatchSampler(lengths, batch_size, bin_size=None,
                                       shuffle=shuffle, drop_last=drop_last)
    return DataLoader(subset_ds, batch_sampler=sampler,
                      collate_fn=pad_collate_varlen,
                      num_workers=0,
                      # pinning only helps (and only avoids a warning) with CUDA
                      pin_memory=torch.cuda.is_available())

In [None]:
# Loaders for the fixed split above (used to verify stratification).
# Evaluation loaders keep partial batches so val/test cover every sample.
train_loader = make_loader(train_ds, batch_size=16)
val_loader   = make_loader(val_ds, batch_size=16, shuffle=False, drop_last=False)
test_loader  = make_loader(test_ds, batch_size=16, shuffle=False, drop_last=False)

def get_label_distribution(indices, data=None):
    """Count class labels for the given record indices.

    Args:
        indices: iterable of integer indices into ``data``.
        data: sequence of ``(beats, label_idx)`` records. Defaults to the
            notebook-global ``data_list`` so existing calls keep working.

    Returns:
        collections.Counter mapping label index -> count.
    """
    records = data_list if data is None else data
    return Counter(records[i][1] for i in indices)

for _split_name, _split_idx in (("Train", train_indices), ("Val", val_indices), ("Test", test_indices)):
    print(f"{_split_name} distribution:", get_label_distribution(_split_idx))

Train distribution: Counter({7: 7364, 10: 1299, 1: 879, 4: 610, 5: 436, 6: 370, 0: 220, 8: 185, 2: 175, 3: 129, 9: 103})
Val distribution: Counter({7: 818, 10: 144, 1: 98, 4: 68, 5: 48, 6: 41, 0: 25, 8: 21, 2: 20, 3: 14, 9: 11})
Test distribution: Counter({7: 910, 10: 160, 1: 109, 4: 75, 5: 54, 6: 45, 0: 27, 8: 23, 2: 22, 3: 16, 9: 13})


In [None]:
# -------------------------
# Attention / helper layers
# -------------------------
class ChannelAttention(nn.Module):
    """
    CBAM-style channel attention for 1-D feature maps.

    Expects x shape = (batch, channels, seq_len). A shared two-layer MLP
    scores the time-averaged and the time-maxed channel descriptors. The two
    scores are summed, passed through a sigmoid, and used to rescale each
    channel.
    """
    def __init__(self, channels, ratio=8):
        super().__init__()
        hidden = max(1, channels // ratio)
        layers = [
            nn.Linear(channels, hidden, bias=True),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        avg_desc = x.mean(dim=2)             # (B, C)
        max_desc = x.max(dim=2).values       # (B, C)
        scores = self.mlp(avg_desc) + self.mlp(max_desc)
        gate = torch.sigmoid(scores).unsqueeze(-1)   # (B, C, 1)
        return x * gate                              # (B, C, L)

class SegmentAttention(nn.Module):
    """
    Additive attention pooling over the time axis.

    For each step t:
        score_t = u . tanh(W x_t + b)
        alphas  = softmax(scores) over t
        output  = sum_t alphas_t * x_t
    Masked positions receive zero weight.
    """
    def __init__(self, input_dim, units):
        super().__init__()
        self.linear = nn.Linear(input_dim, units, bias=True)
        self.u = nn.Parameter(torch.randn(units))

    def forward(self, inputs, mask=None):
        """
        inputs: (B, T, D)
        mask:   (B, T) with 1=valid, 0=pad (or None)

        Returns (pooled (B, D), alphas (B, T)).
        """
        scores = torch.matmul(torch.tanh(self.linear(inputs)), self.u)   # (B, T)
        if mask is not None:
            # -inf makes softmax give exactly 0 weight to padded steps.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # A fully padded row softmaxes to NaN; map it to all-zero weights.
        weights = torch.nan_to_num(F.softmax(scores, dim=1), nan=0.0)
        pooled = (inputs * weights.unsqueeze(-1)).sum(dim=1)              # (B, D)
        return pooled, weights

class TimeDistributedSegmentAttention(nn.Module):
    """
    Apply one shared SegmentAttention independently to every segment.

    The segment axis S is folded into the batch axis. Attention then pools
    over T, and the result is unfolded back to (B, S, ...).
    """
    def __init__(self, input_dim, units):
        super().__init__()
        self.segment_attention = SegmentAttention(input_dim, units)

    def forward(self, inputs, mask=None):
        """
        inputs: (B, S, T, D)
        mask:   (B, S, T) or None

        Returns (outputs (B, S, D), alphas (B, S, T)).
        """
        B, S, T, D = inputs.shape
        # reshape (not view) so non-contiguous inputs, e.g. after a permute, also work.
        flat = inputs.reshape(B * S, T, D)                                  # (B*S, T, D)
        mask_flat = None if mask is None else mask.reshape(B * S, T)        # (B*S, T)
        outputs, alphas = self.segment_attention(flat, mask_flat)           # (B*S, D), (B*S, T)
        return outputs.reshape(B, S, D), alphas.reshape(B, S, T)

class HANWithAttention(nn.Module):
    """
    Hierarchical attention network over beat-segmented multi-lead ECG.

    Per beat:
        Conv1d(in_channels -> 256, k=25) -> BatchNorm -> ReLU
        -> channel attention -> MaxPool(k=3, s=2)
        -> LSTM(256 -> 512) -> additive attention over time.
    Across beats:
        LSTM(512 -> 256) -> additive attention over beats
        -> FC(256 -> 512) -> ReLU -> dropout -> linear classifier.

    NOTE(review): the mask is only applied in the within-beat attention. Padded
    time steps still pass through the conv, BatchNorm and channel-attention
    pooling, so padding length can influence features of valid steps.

    Args:
        num_classes: number of output logits.
        in_channels: number of ECG leads C in the input (default 12).
    """
    def __init__(self, num_classes=11, in_channels=12):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=256, kernel_size=25, padding=12)
        self.channel_attention = ChannelAttention(256, ratio=8)
        self.bn1 = nn.BatchNorm1d(256)
        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.lstm_segment = nn.LSTM(input_size=256, hidden_size=512, batch_first=True)
        self.time_distributed_attention = TimeDistributedSegmentAttention(input_dim=512, units=512)
        self.lstm_sequence = nn.LSTM(input_size=512, hidden_size=256, batch_first=True)
        self.final_attention = SegmentAttention(input_dim=256, units=256)
        self.fc = nn.Linear(256, 512)
        self.dropout = nn.Dropout(0.24)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x, mask=None):
        """Return logits (B, num_classes); see ``forward_with_attention`` for shapes."""
        logits, _, _ = self.forward_with_attention(x, mask)
        return logits

    def forward_with_attention(self, x, mask=None):
        """
        x:    (B, S, T, C)
        mask: (B, S, T) with 1=valid, 0=pad (or None)

        Returns a tuple:
          - logits (B, num_classes)
          - beat-level alphas (B, S)
          - within-beat alphas (B, S, T2), where T2 = ceil(T / 2) after pooling.
        """
        B, S, T, C = x.shape
        beats = x.reshape(B * S, T, C).permute(0, 2, 1)   # (B*S, C, T)

        feats = F.relu(self.bn1(self.conv1d(beats)))      # (B*S, 256, T)
        feats = self.channel_attention(feats)             # (B*S, 256, T)
        pooled = self.pool(feats).permute(0, 2, 1)        # (B*S, T2, 256)

        # Downsample the mask with the same pooling parameters as the features.
        if mask is not None:
            m = mask.reshape(B * S, 1, T)                                  # (B*S, 1, T)
            m2 = F.max_pool1d(m, kernel_size=3, stride=2, padding=1)       # (B*S, 1, T2)
            m2 = (m2 > 0.0).float().squeeze(1).reshape(B, S, -1)           # (B, S, T2)
        else:
            m2 = None

        seg_lstm_out, _ = self.lstm_segment(pooled)       # (B*S, T2, 512)
        seg_lstm_out = seg_lstm_out.reshape(B, S, seg_lstm_out.shape[1], seg_lstm_out.shape[2])  # (B, S, T2, 512)

        segment_outputs, segment_alphas = self.time_distributed_attention(seg_lstm_out, mask=m2)  # (B, S, 512), (B, S, T2)
        seq_lstm_out, _ = self.lstm_sequence(segment_outputs)                                     # (B, S, 256)
        final_output, final_alphas = self.final_attention(seq_lstm_out)                           # (B, 256), (B, S)

        hidden = self.dropout(F.relu(self.fc(final_output)))   # (B, 512)
        logits = self.classifier(hidden)                        # (B, num_classes)
        return logits, final_alphas, segment_alphas


# Single instantiation, sized to the number of labels actually found in the data.
model = HANWithAttention(num_classes=num_classes).to(DEVICE)
device = DEVICE  # alias kept for cells that refer to `device`

from torchinfo import summary
summary(model,
        input_size=(2, 10, 300, 12),    # (batch, segments, timesteps, channels)
        col_names=("input_size", "output_size", "num_params", "trainable"),
        depth=4,
        device=device.type)


Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Trainable
HANWithAttention                              [2, 10, 300, 12]          [2, 11]                   --                        True
├─Conv1d: 1-1                                 [20, 12, 300]             [20, 256, 300]            77,056                    True
├─BatchNorm1d: 1-2                            [20, 256, 300]            [20, 256, 300]            512                       True
├─ChannelAttention: 1-3                       [20, 256, 300]            [20, 256, 300]            --                        True
│    └─Sequential: 2-1                        [20, 256]                 [20, 256]                 --                        True
│    │    └─Linear: 3-1                       [20, 256]                 [20, 32]                  8,224                     True
│    │    └─ReLU: 3-2                         [20, 32]                  [20, 32]            

In [None]:
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score

from tqdm import tqdm

def evaluate_metrics(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Batches are dicts with "signal", "label" and optionally "mask". The mask is
    forwarded to the model exactly as ``train`` does, so padded time steps
    are handled the same way at evaluation as during training.

    Returns:
        (avg_loss, accuracy, precision, recall, auc, f1)
          - precision, recall and F1 are support-weighted.
          - auc is the mean one-vs-rest ROC AUC over classes that have both
            positive and negative samples (NaN if there are none).
          - Every value is NaN when the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    all_labels, all_preds, all_probs = [], [], []

    with torch.no_grad():
        for batch in loader:
            inputs = batch["signal"].to(device)
            labels = batch["label"].to(device)
            mask = batch.get("mask")

            outputs = model(inputs, mask=mask.to(device)) if mask is not None else model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            all_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        nan = float("nan")
        return nan, nan, nan, nan, nan, nan

    avg_loss = running_loss / len(all_labels)
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
    recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
    f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)

    y_score = np.asarray(all_probs)
    n_out = y_score.shape[1]
    # One-hot width comes from the model output, not from a global class count.
    y_true = np.eye(n_out)[np.asarray(all_labels)]
    # roc_auc_score needs both classes present; skip degenerate one-vs-rest columns.
    auc_list = [
        roc_auc_score(y_true[:, k], y_score[:, k])
        for k in range(n_out)
        if 0 < y_true[:, k].sum() < len(y_true)
    ]
    auc = float(np.mean(auc_list)) if auc_list else float("nan")

    return avg_loss, acc, precision, recall, auc, f1

def _ovr_auc(labels, probs):
    """Mean one-vs-rest ROC AUC over classes with both positives and negatives (NaN if none)."""
    import numpy as np
    from sklearn.metrics import roc_auc_score

    y_score = np.asarray(probs)
    if y_score.ndim != 2 or len(y_score) == 0:
        return float("nan")
    # One-hot width comes from the probability matrix, not from a global class count.
    y_true = np.eye(y_score.shape[1])[np.asarray(labels)]
    aucs = [
        roc_auc_score(y_true[:, k], y_score[:, k])
        for k in range(y_score.shape[1])
        if 0 < y_true[:, k].sum() < len(y_true)
    ]
    return float(np.mean(aucs)) if aucs else float("nan")


def train(model, train_loader, val_loader, optimizer, criterion, device,
          epochs, scheduler=None):
    """Train ``model`` for ``epochs`` epochs, optionally validating after each.

    Each batch is a dict with "signal", "label" and optionally "mask". The mask
    is forwarded to the model when present.

    After every epoch, the following are recorded for the training pass:
    loss, accuracy, support-weighted precision/recall/F1 and the one-vs-rest
    AUC. If ``val_loader`` is given, the same metrics are computed with
    ``evaluate_metrics``, and ``scheduler.step(val_loss)`` is called. The
    scheduler must therefore accept a metric, like ``ReduceLROnPlateau``.

    Returns:
        (history, skip_seed)
          - ``history`` maps metric names to per-epoch lists. Validation
            entries are None when ``val_loader`` is None.
          - ``skip_seed`` is always False; it is kept for backward compatibility.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches in an epoch.
    """
    from sklearn.metrics import precision_score, recall_score, f1_score
    from tqdm import tqdm

    model.to(device)

    history = {
        "train_loss": [], "train_acc": [], "train_f1": [], "train_precision": [], "train_recall": [], "train_auc": [],
        "val_loss": [],   "val_acc": [],   "val_f1": [],   "val_precision": [],   "val_recall": [], "val_auc": []
    }
    # Same order as the tuple returned by evaluate_metrics.
    val_keys = ("val_loss", "val_acc", "val_precision", "val_recall", "val_auc", "val_f1")

    skip_seed = False  # never set to True

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        all_preds, all_probs, all_labels = [], [], []

        loop = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for batch in loop:
            inputs = batch["signal"].to(device)
            labels = batch["label"].to(device)
            mask = batch.get("mask", None)
            if mask is not None:
                mask = mask.to(device)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(inputs, mask=mask) if mask is not None else model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(torch.softmax(outputs, dim=1).detach().cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            loop.set_postfix(loss=f"{loss.item():.4f}", acc=f"{correct/total:.4f}")

        if total == 0:
            # e.g. every length bucket is smaller than batch_size with drop_last=True
            raise RuntimeError(
                "train_loader yielded no batches; check batch_size / drop_last "
                "against the length-bucket sizes"
            )

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        epoch_precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
        epoch_recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
        epoch_f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
        epoch_auc = _ovr_auc(all_labels, all_probs)

        for key, value in (("train_loss", epoch_loss), ("train_acc", epoch_acc),
                           ("train_precision", epoch_precision), ("train_recall", epoch_recall),
                           ("train_f1", epoch_f1), ("train_auc", epoch_auc)):
            history[key].append(value)

        print(
            f"Epoch {epoch}/{epochs} | "
            f"train_loss {epoch_loss:.4f} | train_f1 {epoch_f1:.4f}"
        )

        if val_loader is None:
            for key in val_keys:
                history[key].append(None)
            continue

        vloss, vacc, vprecision, vrecall, vauc, vf1 = evaluate_metrics(
            model, val_loader, criterion, device
        )
        for key, value in zip(val_keys, (vloss, vacc, vprecision, vrecall, vauc, vf1)):
            history[key].append(value)

        print(
            f"Epoch {epoch} | "
            f"val_loss {vloss:.4f} | val_f1 {vf1:.4f} | "
            f"val_precision {vprecision:.4f} | val_recall {vrecall:.4f}"
        )

        if scheduler is not None:
            scheduler.step(vloss)

    return history, skip_seed



In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm
from tqdm.auto import tqdm
import random

# Shared loss and default hyper-parameters for the seed sweep below.
criterion = nn.CrossEntropyLoss()
BATCH_SIZE = 32  # read by run_experiments when building loaders
EPOCHS = 30      # NOTE(review): the run_experiments call below passes epochs=25 instead


def set_seed(seed):
    """Seed the torch (CPU and all CUDA devices), NumPy and Python RNGs, and make cuDNN deterministic."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


def run_experiments(model_fn,
                    optimizer_fn, criterion, device, epochs, num_classes,
                    seeds=10, seed_start=40):
    """Repeat split -> train -> test over several seeds and report mean/std.

    For every seed:
      1. Seed all RNGs.
      2. Make a fresh stratified 81/9/10 train/val/test split.
      3. Train a fresh model with a ReduceLROnPlateau scheduler.
      4. Evaluate it on the test split.

    Args:
        model_fn: zero-argument callable returning a fresh model.
        optimizer_fn: callable ``model -> optimizer``.
        criterion: loss module shared by training and evaluation.
        device: torch device.
        epochs: training epochs per seed.
        num_classes: kept for API compatibility; not read in this function.
        seeds: an int N, meaning run seeds ``seed_start .. seed_start + N - 1``,
            or an explicit iterable of seed values. (Previously this argument
            was ignored and seeds 40..59 always ran.)
        seed_start: first seed when ``seeds`` is an int.

    Returns:
        dict mapping "acc", "precision", "recall", "auc", "loss" and "f1" to
        per-seed test values.
    """
    if isinstance(seeds, int):
        seed_list = list(range(seed_start, seed_start + seeds))
    else:
        seed_list = list(seeds)

    results = {"acc": [], "precision": [], "recall": [], "auc": [], "loss": [], "f1": []}

    # The dataset and label list do not depend on the seed, so build them once.
    ds = ECGSegmentDatasetVarLen(data_list)
    all_indices = list(range(len(ds)))
    all_labels = [data_list[i][1] for i in all_indices]

    for seed in seed_list:
        print(f"\n=== Seed {seed} ===")
        set_seed(seed)

        # Train/val/test split (stratified, seed-dependent)
        trainval_indices, test_indices = train_test_split(
            all_indices, test_size=0.1, random_state=seed, stratify=all_labels
        )
        trainval_labels = [all_labels[i] for i in trainval_indices]
        train_indices, val_indices = train_test_split(
            trainval_indices, test_size=0.1, random_state=seed, stratify=trainval_labels
        )

        # Datasets and loaders
        train_ds = Subset(ds, train_indices)
        val_ds = Subset(ds, val_indices)
        test_ds = Subset(ds, test_indices)
        train_loader = make_loader(train_ds, batch_size=BATCH_SIZE)
        # Evaluation must see every sample, so partial batches are never dropped here.
        val_loader = make_loader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        test_loader = make_loader(test_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        print("Train distribution:", get_label_distribution(train_indices))
        print("Val distribution:", get_label_distribution(val_indices))
        print("Test distribution:", get_label_distribution(test_indices))

        model = model_fn().to(device)
        optimizer = optimizer_fn(model)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=2
        )

        # --- Train ---
        train(model, train_loader, val_loader, optimizer, criterion, device, epochs, scheduler)

        # --- Evaluate on test set ---
        loss, acc, precision, recall, auc, f1 = evaluate_metrics(model, test_loader, criterion, device)
        print(f"Test (seed {seed}) — loss: {loss:.4f}, acc: {acc:.4f}, "
              f"precision: {precision:.4f}, recall: {recall:.4f}, auc: {auc:.4f}, f1: {f1:.4f}")

        for key, value in (("loss", loss), ("acc", acc), ("precision", precision),
                           ("recall", recall), ("auc", auc), ("f1", f1)):
            results[key].append(value)

    # --- Aggregate results ---
    print("\n=== Final Results (across seeds) ===")
    for k, v in results.items():
        arr = np.array(v, dtype=np.float32)
        print(f"{k}: mean={arr.mean():.4f}, std={arr.std():.4f}")

    return results

In [None]:
def model_fn():
    """Return a fresh HANWithAttention sized to the notebook's ``num_classes``."""
    return HANWithAttention(num_classes=num_classes)


def optimizer_fn(model):
    """Return Adam (lr=1e-4, weight_decay=1e-3 added to the gradients) over all of ``model``'s parameters."""
    params = model.parameters()
    return torch.optim.Adam(params, lr=1e-4, weight_decay=0.001)

In [None]:
# Seed sweep; `results` holds the per-seed test metrics.
experiment_kwargs = dict(
    model_fn=model_fn,
    optimizer_fn=optimizer_fn,
    criterion=criterion,
    device=DEVICE,
    epochs=25,   # NOTE(review): differs from the EPOCHS constant (30); confirm which is intended
    num_classes=num_classes,
    seeds=10,
)
results = run_experiments(**experiment_kwargs)
