# **Use GPU**

In [2]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


# **Hide Warnings**

In [3]:
import warnings
warnings.filterwarnings("ignore")

# Install Libraries

In [4]:
pip install torch torchvision torchaudio torchinfo einops scikit-learn pandas

Note: you may need to restart the kernel to use updated packages.


In [5]:
!pip install -U mne==1.0.0 scipy==1.13.1 numpy==1.26.4



# Dataset Details

In [6]:
# Path to file
file_path1 = '/kaggle/input/PhysioNet.pkl'

In [7]:
import pickle

# Load the dataset
file_path = file_path1
with open(file_path, 'rb') as file:
    data = pickle.load(file)

# Check basic information
print(f"Loaded {len(data)} subjects.")
print(f"First subject summary:\n{data[0]}")

Loaded 109 subjects.
First subject summary:
<Epochs |  90 events (all good), 0 - 4 sec, baseline off, ~28.2 MB, data loaded, with metadata,
 'left_hand': 23
 'right_hand': 22
 'both_hands': 21
 'feet': 24>


In [8]:
import pickle

# Load the dataset
file_path = file_path1
with open(file_path, 'rb') as file:
    data = pickle.load(file)

# Print basic information about the dataset
print(f"Number of subjects loaded: {len(data)}")

for idx, epochs in enumerate(data):
    print(f"\n--- Subject {idx + 1} ---")
    print(f"Shape of raw data: {epochs.get_data().shape}")  # (n_epochs, n_channels, n_times)
    print(f"Sampling frequency: {epochs.info['sfreq']} Hz")
    print(f"Number of channels: {len(epochs.info['ch_names'])}")
    print(f"Channel names: {epochs.info['ch_names']}")
    print(f"Event types: {set(event[2] for event in epochs.events)}")

Number of subjects loaded: 109

--- Subject 1 ---
Shape of raw data: (90, 64, 641)
Sampling frequency: 160.0 Hz
Number of channels: 64
Channel names: ['FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10', 'TP7', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'Iz']
Event types: {1, 2, 3, 4}

--- Subject 2 ---
Shape of raw data: (90, 64, 641)
Sampling frequency: 160.0 Hz
Number of channels: 64
Channel names: ['FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10'

# **Import Libraries**

In [9]:
import os
import mne
import time
import math
import pickle
import random
import datetime
import scipy.io
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from einops import rearrange
from torchinfo import summary
import matplotlib.pyplot as plt
from sklearn.utils import class_weight
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from einops.layers.torch import Rearrange, Reduce
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import LeaveOneGroupOut, train_test_split

In [None]:
import os
import random
import pickle
import math
import csv
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, LeaveOneGroupOut

from einops.layers.torch import Rearrange
from tqdm import tqdm

# -----------------------
# 0) Basic imports & seed
# -----------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# -----------------------
# 1) Pad & resize helper
# -----------------------
def pad_and_resize_eeg(eeg_2d: np.ndarray, target_h: int = 64, target_w: int = 201) -> np.ndarray:
    """
    Zero-pad or truncate EEG trial to (target_h, target_w)
    eeg_2d shape: (channels, time_samples)
    """
    out = np.zeros((target_h, target_w), dtype=eeg_2d.dtype)
    h, w = eeg_2d.shape
    used_h = min(h, target_h)
    used_w = min(w, target_w)
    out[:used_h, :used_w] = eeg_2d[:used_h, :used_w]
    return out

# -----------------------
# 2) Load & preprocess
# -----------------------
def load_and_preprocess_data(
    pickle_path: str,
    target_h: int = 64,
    target_w: int = 641,
    l_freq: float = 7.0,
    h_freq: float = 30.0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load list-of-MNE-Epochs pickle and return standardized arrays.
    Returns:
        all_X: (N_trials, 1, target_h, target_w) float32
        all_y: (N_trials,) int64 with labels 0/1
        all_subjects: (N_trials,) subject index
    """
    print(f"Loading data from: {pickle_path}")
    with open(pickle_path, "rb") as f:
        raw_data = pickle.load(f)  # list of MNE Epochs objects (one per subject)

    all_X_list = []
    all_y_list = []
    all_subj_list = []

    for subj_idx, epochs in enumerate(raw_data):
        # Some Epochs objects may not have filter method available if already filtered; guard with try/catch
        try:
            epochs = epochs.copy().filter(l_freq=l_freq, h_freq=h_freq)
        except Exception:
            # if filtering fails, continue without throwing (user can prefilter)
            pass

        X = epochs.get_data()  # (trials, channels, time_samples)
        events = epochs.events  # (trials, 3) usually [sample, 0, event_id]
        # event_id likely in column 2; guard for shape
        if events.shape[1] >= 3:
            y_raw = events[:, 2]
        else:
            # fallback if events stored differently
            y_raw = epochs.events[:, -1]

        # Map event codes to labels: 1->0, 2->1, 3->2, 4->3
        # This includes all observed events {1,2,3,4} as 4 classes (left_hand, right_hand, both_hands, feet)
        y_mapped = np.zeros_like(y_raw) - 1  # Initialize to -1 for invalid
        y_mapped[y_raw == 1] = 0
        y_mapped[y_raw == 2] = 1
        y_mapped[y_raw == 3] = 2
        y_mapped[y_raw == 4] = 3
        # Filter only valid labels (0-3 after mapping)
        valid_mask = (y_mapped >= 0) & (y_mapped < 4)
        X = X[valid_mask]
        y_mapped = y_mapped[valid_mask]

        # resize each trial to (target_h, target_w)
        padded = [pad_and_resize_eeg(trial, target_h, target_w) for trial in X]
        padded = np.stack(padded, axis=0)  # (trials, channels, time)

        padded = np.expand_dims(padded, axis=1)  # (trials, 1, channels, time)
        all_X_list.append(padded)
        all_y_list.append(y_mapped)
        all_subj_list += [subj_idx] * len(y_mapped)

        # print(f"Subject {subj_idx+1}: trials={padded.shape[0]}, shape(per_trial)={(1,target_h,target_w)}")

    all_X = np.concatenate(all_X_list, axis=0)
    all_y = np.concatenate(all_y_list, axis=0)
    all_subjects = np.array(all_subj_list)

    # Standardize across channels*time per-trial (zero mean, unit var)
    nsamples, _, H, W = all_X.shape
    scaler = StandardScaler()
    X_flat = all_X.reshape(nsamples, -1)
    X_scaled = scaler.fit_transform(X_flat)
    all_X = X_scaled.reshape(nsamples, 1, H, W).astype(np.float32)

    # print(f"Total trials: {all_X.shape[0]}, Subjects: {len(np.unique(all_subjects))}")
    return all_X, all_y.astype(np.int64), all_subjects

# -----------------------
# 3) Dataset without augmentation
# -----------------------
class MultiSubjectBCIDataset(Dataset):
    def __init__(self, X: np.ndarray, y: np.ndarray, augment: bool = False):
        """
        X: (N, 1, C, T)
        y: (N,)
        augment: Kept for compatibility, but no augmentations are applied
        """
        self.X = X
        self.y = y
        self.augment = augment

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x_np = self.X[idx]  # (1, C, T)
        label = int(self.y[idx])
        x_t = torch.from_numpy(x_np).float()
        return x_t, torch.tensor(label, dtype=torch.long)

# -----------------------
# 4) Window helpers (Swin1D utils)
# -----------------------
def pad_sequence_1d(x: torch.Tensor, window_size: int):
    B, L, C = x.shape
    remainder = L % window_size
    if remainder == 0:
        return x, L, 0
    pad_len = window_size - remainder
    pad_vec = torch.zeros(B, pad_len, C, dtype=x.dtype, device=x.device)
    x_padded = torch.cat([x, pad_vec], dim=1)
    return x_padded, L, pad_len

def window_partition_1d(x: torch.Tensor, window_size: int):
    B, L, C = x.shape
    x_padded, orig_L, pad_len = pad_sequence_1d(x, window_size)
    Bp, Lp, Cp = x_padded.shape
    num_windows = Lp // window_size
    x_padded = x_padded.view(Bp, num_windows, window_size, Cp)
    x_windows = x_padded.reshape(Bp * num_windows, window_size, Cp)
    return x_windows, (orig_L, pad_len, num_windows)

def window_reverse_1d(x_windows: torch.Tensor, window_size: int, pad_info):
    orig_L, pad_len, num_windows = pad_info
    BnW, WS, C = x_windows.shape
    B = BnW // num_windows
    x_reshaped = x_windows.view(B, num_windows, WS, C)
    x_merged = x_reshaped.reshape(B, num_windows * WS, C)
    if pad_len > 0:
        x_merged = x_merged[:, :orig_L, :]
    return x_merged

def cyclic_shift_1d(x: torch.Tensor, shift_size: int):
    return torch.roll(x, shifts=-shift_size, dims=1)

def cyclic_shift_back_1d(x: torch.Tensor, shift_size: int):
    return torch.roll(x, shifts=shift_size, dims=1)

# -----------------------
# 5) Attention & MLP
# -----------------------
class Attention(nn.Module):
    def __init__(self, dim: int, num_heads: int, attn_dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(attn_dropout)

    def forward(self, x):
        B, L, C = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, L, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class MLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

# -----------------------
# 6) Swin1D block and transformer
# -----------------------
class Swin1DBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, window_size: int = 4, shift_size: int = 2, mlp_hidden: int = 128, attn_dropout: float = 0.1):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size

        self.ln1 = nn.LayerNorm(dim)
        self.attn = Attention(dim=dim, num_heads=num_heads, attn_dropout=attn_dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim=dim, hidden_dim=mlp_hidden, dropout=attn_dropout)

    def forward(self, x):
        # x: (B, L, dim)
        if self.shift_size > 0:
            x = cyclic_shift_1d(x, self.shift_size)

        x_windows, pad_info = window_partition_1d(x, self.window_size)

        shortcut = x_windows
        x_windows = self.ln1(x_windows)
        x_windows = self.attn(x_windows)
        x_windows = shortcut + x_windows

        shortcut = x_windows
        x_windows = self.ln2(x_windows)
        x_windows = self.mlp(x_windows)
        x_windows = shortcut + x_windows

        x_merged = window_reverse_1d(x_windows, self.window_size, pad_info)
        if self.shift_size > 0:
            x_merged = cyclic_shift_back_1d(x_merged, self.shift_size)
        return x_merged

class Swin1DTransformer(nn.Module):
    def __init__(self, dim: int = 64, num_layers: int = 3, num_heads: int = 4, mlp_hidden: int = 128, window_size: int = 4, attn_dropout: float = 0.1):
        super().__init__()
        blocks = []
        for i in range(num_layers):
            shift = window_size // 2 if (i % 2 == 1) else 0
            blocks.append(Swin1DBlock(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=shift, mlp_hidden=mlp_hidden, attn_dropout=attn_dropout))
        self.blocks = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = x.mean(dim=1)  # average pool over sequence length
        return x

# -----------------------
# 7) Positional embeddings
# -----------------------
def create_positional_embedding(mode: str, seq_len: int, dim: int):
    if mode == 'none':
        return None
    elif mode == 'learnable':
        pe = nn.Parameter(torch.zeros(1, seq_len, dim))
        nn.init.trunc_normal_(pe, std=0.02)
        return pe
    elif mode == 'sine':
        pe_np = np.zeros((seq_len, dim))
        for pos in range(seq_len):
            for i in range(0, dim, 2):
                theta = pos / (10000 ** ((2 * i) / dim))
                pe_np[pos, i]   = np.sin(theta)
                if i+1 < dim:
                    pe_np[pos, i+1] = np.cos(theta)
        pe = torch.from_numpy(pe_np).float().unsqueeze(0)
        return nn.Parameter(pe, requires_grad=False)
    else:
        raise ValueError("Unsupported pos emb mode")

# -----------------------
# 8) Patch Embedding (dynamic for C and W)
# -----------------------
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int, input_time_len: int, emb_size: int = 40, pool_kernel: int = 75, pool_stride: int = 15):
        """
        Args:
            in_channels: number of EEG channels (e.g., 64)
            input_time_len: number of time samples per trial (e.g., 201)
            emb_size: output embedding dimension
            pool_kernel/pool_stride: controls temporal downsampling -> determines seq_len
        """
        super().__init__()
        # Temporal conv (across time) -> keep channel dim intact (still in channel dim)
        self.conv_temp = nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1), padding=(0, 0))
        # Spatial conv across electrodes: kernel height = in_channels -> collapses electrode dim to 1
        self.conv_spat = nn.Conv2d(40, 40, kernel_size=(in_channels, 1), stride=(1, 1), padding=(0, 0))
        self.bn = nn.BatchNorm2d(40)
        self.act = nn.ELU()
        self.pool = nn.AvgPool2d(kernel_size=(1, pool_kernel), stride=(1, pool_stride))
        self.drop = nn.Dropout(p=0.5)
        self.proj = nn.Conv2d(40, emb_size, kernel_size=(1,1), stride=(1,1))
        self.rearrange = Rearrange('b e h w -> b (h w) e')  # will be (B, seq_len, emb_size)

        # store params for computing seq_len
        self.input_time_len = input_time_len
        self.pool_kernel = pool_kernel
        self.pool_stride = pool_stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 1, C, T)
        x = self.conv_temp(x)     # -> (B, 40, C, T-24)
        x = self.conv_spat(x)     # -> (B, 40, 1, T-24)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)          # -> (B, 40, 1, Wp)
        x = self.drop(x)
        x = self.proj(x)          # -> (B, emb_size, 1, Wp)
        x = self.rearrange(x)     # -> (B, seq_len=Wp, emb_size)
        return x

    def compute_seq_len(self):
        # compute width after conv_temp (kernel 25, stride 1, no pad)
        W1 = self.input_time_len - 25 + 1  # (T - kernel + 1)
        # after pool: floor((W1 - pool_kernel)/pool_stride) + 1 if W1 >= pool_kernel else 0 or 1 (handle small sizes)
        if W1 < self.pool_kernel:
            # if smaller than kernel, pooling will return size 1 (PyTorch behavior: kernel > input -> kernel reduces to input length)
            # but to be safe, set seq_len = 1
            Wp = 1
        else:
            Wp = (W1 - self.pool_kernel) // self.pool_stride + 1
            Wp = max(1, Wp)
        return Wp

# -----------------------
# 9) Classification head & CCST
# -----------------------
class ClassificationHead(nn.Module):
    def __init__(self, input_dim: int = 64, num_classes: int = 4, dropout: float = 0.3):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        return self.linear(x)

class CCST(nn.Module):
    def __init__(
        self,
        in_channels: int = 64,
        input_time_len: int = 201,
        emb_size: int = 40,
        swin_embedding_dim: int = 64,
        num_swin_layers: int = 3,
        num_heads: int = 4,
        mlp_size: int = 128,
        pos_emb_mode: str = 'learnable',
        num_classes: int = 4
    ):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels=in_channels, input_time_len=input_time_len, emb_size=emb_size)
        seq_len = self.patch_embedding.compute_seq_len()
        self.embedding_projection = nn.Linear(emb_size, swin_embedding_dim)
        self.pos_encoding = create_positional_embedding(pos_emb_mode, seq_len=seq_len, dim=swin_embedding_dim)
        self.transformer = Swin1DTransformer(dim=swin_embedding_dim, num_layers=num_swin_layers, num_heads=num_heads, mlp_hidden=mlp_size, window_size=4, attn_dropout=0.1)
        self.classification_head = ClassificationHead(input_dim=swin_embedding_dim, num_classes=num_classes)

    def forward(self, x: torch.Tensor):
        # x: (B,1,C,T)
        x = self.patch_embedding(x)             # [B, seq_len, emb_size]
        x = self.embedding_projection(x)        # [B, seq_len, swin_embedding_dim]
        if self.pos_encoding is not None:
            x = x + self.pos_encoding[:, :x.size(1), :].to(x.device)
        features = self.transformer(x)          # [B, swin_embedding_dim]
        logits = self.classification_head(features)
        return logits

# -----------------------
# 10) Experiment wrapper & LOSO training
# -----------------------
class HybridExP():
    def __init__(self, pickle_path: str, device: str = 'cuda:0', batch_size: int = 64, n_epochs: int = 150):
        self.pickle_path = pickle_path
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.patience = 10
        self.lr = 3e-4
        self.betas = (0.9, 0.999)
        self.c_dim = 4

        # Data fields (will be filled later)
        self.X = None
        self.y = None
        self.subjects = None

        # Model will be constructed after data loaded to get dims
        self.model = None
        self.criterion_cls = nn.CrossEntropyLoss().to(self.device)

    def get_source_data(self):
        X, y, subjects = load_and_preprocess_data(self.pickle_path, target_h=64, target_w=641, l_freq=7.0, h_freq=30.0)
        self.X = X
        self.y = y
        self.subjects = subjects
        return X, y, subjects

    def build_model(self):
        # Build model with dims matching data
        _, _, C, T = self.X.shape
        # print(f"Building model for input channels={C}, time_len={T}")
        model = CCST(in_channels=C, input_time_len=T, emb_size=40, swin_embedding_dim=64, num_swin_layers=3, num_heads=4, mlp_size=128, pos_emb_mode='learnable', num_classes=self.c_dim)
        model = model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        self.model = model

    def train_fold(self, train_indices, val_indices, test_indices):
        X_train_full = self.X[train_indices]
        y_train_full = self.y[train_indices]
        X_val = self.X[val_indices]
        y_val = self.y[val_indices]
        X_test = self.X[test_indices]
        y_test = self.y[test_indices]

        # small stratified split already done in caller; if val empty, fallback
        train_ds = MultiSubjectBCIDataset(X_train_full, y_train_full, augment=False)
        val_ds = MultiSubjectBCIDataset(X_val, y_val, augment=False)
        test_ds = MultiSubjectBCIDataset(X_test, y_test, augment=False)

        train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=True, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False, num_workers=2, pin_memory=True)
        test_loader = DataLoader(test_ds, batch_size=self.batch_size, shuffle=False, num_workers=2, pin_memory=True)

        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=self.betas)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=False)

        best_val_acc = 0.0
        best_state = None
        patience_counter = 0

        for epoch in range(self.n_epochs):
            # Training
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for imgs, labels in train_loader:
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(imgs)
                loss = self.criterion_cls(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * imgs.size(0)
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

            train_loss = running_loss / max(1, total)
            train_acc = correct / max(1, total)

            # Validation
            self.model.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                for imgs, labels in val_loader:
                    imgs = imgs.to(self.device)
                    labels = labels.to(self.device)
                    outputs = self.model(imgs)
                    loss = self.criterion_cls(outputs, labels)
                    val_loss += loss.item() * imgs.size(0)
                    preds = outputs.argmax(dim=1)
                    val_correct += (preds == labels).sum().item()
                    val_total += labels.size(0)
            if val_total == 0:
                val_loss = 0.0
                val_acc = 0.0
            else:
                val_loss = val_loss / val_total
                val_acc = val_correct / val_total

            scheduler.step(val_loss if val_total>0 else train_loss)

            # Print progress every epoch
            print(f"Epoch {epoch+1}/{self.n_epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc*100:.2f}% | Val Loss: {val_loss:.4f} Acc: {val_acc*100:.2f}%")

            # Early stopping based on val_acc
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = {k:v.cpu() for k,v in self.model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # load best
        if best_state is not None:
            self.model.load_state_dict(best_state)

        # Test
        self.model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs = imgs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(imgs)
                loss = self.criterion_cls(outputs, labels)
                test_loss += loss.item() * imgs.size(0)
                preds = outputs.argmax(dim=1)
                test_correct += (preds == labels).sum().item()
                test_total += labels.size(0)
                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

        test_loss = test_loss / max(1, test_total)
        test_acc = test_correct / max(1, test_total)
        if len(all_preds) > 0:
            all_preds = np.concatenate(all_preds)
            all_labels = np.concatenate(all_labels)
        else:
            all_preds = np.array([])
            all_labels = np.array([])

        return test_acc, test_loss, all_labels, all_preds

    def train_loso(self, out_csv: str = "loso_results.csv"):
        X, y, subjects = self.get_source_data()
        self.build_model()

        logo = LeaveOneGroupOut()
        test_accuracies = []
        subject_ids = []

        fold = 0

        # print("\nStarting LOSO training ...")
        for train_idx, test_idx in logo.split(X, y, groups=subjects):
            fold += 1
            heldout_subj = subjects[test_idx[0]]
            print(f"\n--- Fold {fold}: Test subject {heldout_subj+1}")

            # reseed per fold
            set_seed(42 + fold)

            # further split train into train/val stratified
            X_train_full = X[train_idx]
            y_train_full = y[train_idx]

            unique_classes = len(np.unique(y_train_full))
            if unique_classes < 4:
                print(f"Warning: fewer than 4 classes ({unique_classes} classes) in training set for this fold; skipping")
                continue

            X_train, X_val, y_train, y_val = train_test_split(X_train_full, y_train_full, test_size=0.1, stratify=y_train_full, random_state=42+fold)

            test_indices = test_idx
            val_indices = np.arange(len(X))[:0] 

            # Build fresh model for this fold
            self.build_model()
            self.model.to(self.device)

            # Prepare arrays for this fold
            X_combined = np.concatenate([X_train, X_val, X[test_idx]], axis=0)
            y_combined = np.concatenate([y_train, y_val, y[test_idx]], axis=0)
            n_train = len(X_train)
            n_val = len(X_val)
            n_test = len(test_idx)

            # Get absolute indices
            train_indices = np.arange(0, n_train)
            val_indices = np.arange(n_train, n_train + n_val)
            test_indices_rel = np.arange(n_train + n_val, n_train + n_val + n_test)

            # Temporarily set self.X/y to combined arrays to reuse train_fold
            X_old, y_old = self.X, self.y
            self.X, self.y = X_combined, y_combined

            # Train fold
            test_acc, test_loss, all_labels, all_preds = self.train_fold(train_indices, val_indices, test_indices_rel)

            print(f"Fold {fold} | Subject {heldout_subj+1} Test Acc: {test_acc*100:.2f} % | Test Loss: {test_loss:.4f}")

            test_accuracies.append(test_acc)
            subject_ids.append(heldout_subj)

            # Restore X/y
            self.X, self.y = X_old, y_old

            # Append results to CSV
            with open(out_csv, mode='a', newline='') as f:
                writer = csv.writer(f)
                if f.tell() == 0:
                    writer.writerow(["subject_idx", "test_acc", "test_loss", "n_test_samples"])
                writer.writerow([int(heldout_subj), float(test_acc), float(test_loss), int(n_test)])

        # Summary
        if len(test_accuracies) > 0:
            avg_test_acc = np.mean(test_accuracies) * 100.0
        else:
            avg_test_acc = 0.0
        print("\nLOSO Summary:")
        for sid, acc in zip(subject_ids, test_accuracies):
            print(f"Subject {sid+1}: {acc*100:.2f} %")
        print(f"Average Test Accuracy: {avg_test_acc:.2f} %")
        # print(f"Results saved to {out_csv}")

# -----------------------
# 11) Main
# -----------------------
def main():
    pickle_path = '/kaggle/input/PhysioNet.pkl'  # user-provided
    out_csv = "loso_results.csv"
    if os.path.exists(out_csv):
        os.remove(out_csv)

    exp = HybridExP(pickle_path=pickle_path, device='cuda:0', batch_size=72, n_epochs=150)
    exp.train_loso(out_csv=out_csv)

if __name__ == "__main__":
    main()

Loading data from: /kaggle/input/PhysioNet.pkl
Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 265 samples (1.656 sec)

Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 