# **Select Device (GPU if available)**

In [None]:
import torch
# Prefer CUDA when present, otherwise fall back to CPU. (HybridExP below resolves
# its own device from its `device` argument; this cell only reports availability.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# **Hide Warnings**

In [None]:
import warnings
# Silence every warning category for the rest of the session.
warnings.simplefilter("ignore")

# **Install Libraries**

In [None]:
pip install torch torchvision torchaudio torchinfo einops scikit-learn pandas

In [None]:
!pip install -U mne==1.0.0 scipy==1.13.1 numpy==1.26.4

# **Import Libraries**

In [None]:
import os
import mne
import time
import math
import pickle
import random
import datetime
import scipy.io
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from einops import rearrange
from torchinfo import summary
import matplotlib.pyplot as plt
from sklearn.utils import class_weight
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from einops.layers.torch import Rearrange, Reduce
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import LeaveOneGroupOut, train_test_split

In [None]:
###############################################################################
# 1) Set seeds for reproducibility
###############################################################################
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with the same value.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

###############################################################################
# 2) Load & preprocess data from the pickle file
###############################################################################
def pad_and_resize_eeg(eeg_2d, target_h=3, target_w=321):
    """
    Crop and/or zero-pad a 2-D EEG array to exactly (target_h, target_w).

    Despite the name, nothing is resampled: rows/columns beyond the target are
    dropped and missing ones are filled with zeros (top-left aligned).

    Args:
        eeg_2d (np.ndarray): 2-D EEG data of shape (channels, time_samples)
        target_h (int): Target number of EEG channels
        target_w (int): Target number of time samples

    Returns:
        np.ndarray: Array of shape (target_h, target_w) with eeg_2d's dtype
    """
    n_rows, n_cols = eeg_2d.shape
    keep_rows = min(n_rows, target_h)
    keep_cols = min(n_cols, target_w)
    result = np.zeros((target_h, target_w), dtype=eeg_2d.dtype)
    result[:keep_rows, :keep_cols] = eeg_2d[:keep_rows, :keep_cols]
    return result

def load_and_preprocess_data(pickle_path, target_h=3, target_w=321, l_freq=7, h_freq=30):
    """
    Load per-subject MNE Epochs from a pickle, band-pass filter, crop/pad and standardize.

    Args:
        pickle_path (str): Path to the pickle file (a sequence with one MNE Epochs object
            per subject, indexed by position)
        target_h (int): Target number of EEG channels
        target_w (int): Target number of time samples
        l_freq (float): Low cutoff frequency for band-pass filter
        h_freq (float): High cutoff frequency for band-pass filter

    Returns:
        all_X (np.ndarray): float32 data of shape (trials, 1, target_h, target_w)
        all_y (np.ndarray): Labels of shape (trials,): event code 1 -> 0, any other code -> 1
        all_subjects (np.ndarray): 0-based subject index (position in the pickle) per trial

    NOTE(review): the StandardScaler below is fit on the trials of *all* subjects at once,
    so under leave-one-subject-out evaluation the held-out subject contributes to the
    normalisation statistics. Consider fitting per fold on the training split only.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(pickle_path, 'rb') as fh:
        subject_epochs = pickle.load(fh)

    trial_blocks = []
    label_blocks = []
    subject_ids = []

    for subject, epochs in enumerate(subject_epochs):
        # Band-pass filter a copy so the unpickled objects are left untouched
        filtered = epochs.copy().filter(l_freq=l_freq, h_freq=h_freq)
        trials = filtered.get_data()  # (trials, channels, time_samples)

        # Left Hand (event code 1) = 0, every other code = 1
        labels = np.where(filtered.events[:, 2] == 1, 0, 1)

        fixed = np.array([pad_and_resize_eeg(t, target_h, target_w) for t in trials])
        trial_blocks.append(fixed[:, np.newaxis])  # add the singleton "image channel" axis
        label_blocks.append(labels)
        subject_ids.extend([subject] * len(labels))

    stacked = np.concatenate(trial_blocks, axis=0)  # (total_trials, 1, target_h, target_w)
    all_y = np.concatenate(label_blocks, axis=0)
    all_subjects = np.array(subject_ids)

    # Standardize each flattened (channel, time) feature across all trials
    n_trials = stacked.shape[0]
    scaled = StandardScaler().fit_transform(stacked.reshape(n_trials, -1))
    all_X = scaled.reshape(n_trials, 1, target_h, target_w).astype(np.float32)

    return all_X, all_y, all_subjects

###############################################################################
# 3) Dataset with optional augmentation
###############################################################################
class MultiSubjectBCIDataset(Dataset):
    """Torch dataset over pre-shaped EEG trials with optional on-the-fly augmentation.

    Augmentation (training only) applies, in order: a random zero-filled time shift,
    additive Gaussian noise and random electrode dropout. The stored ``X`` array is
    never modified.
    """

    def __init__(self, X, y, augment=False):
        """
        Args:
            X (np.ndarray): EEG data of shape (trials, 1, channels, time_samples)
            y (np.ndarray): Labels of shape (trials,)
            augment (bool): Whether to apply data augmentation
        """
        self.X = X
        self.y = y
        self.augment = augment

        self.max_shift = 10     # max |shift| in time samples
        self.noise_amp = 0.01   # std of the additive Gaussian noise
        # Fraction of electrodes zeroed. NOTE: with 3 electrodes int(3 * 0.05) == 0,
        # so at this rate no electrode is actually dropped.
        self.dropout_rate = 0.05

    def __len__(self):
        return len(self.X)

    def _augment(self, x_t):
        """Return an augmented copy of ``x_t`` (shape (depth, channels, time)); never in place."""
        depth, n_channels, n_times = x_t.shape

        # Random shift in time, filling vacated samples with zeros. Clamp so a shift
        # larger than the sequence cannot change the output length.
        shift = random.randint(-self.max_shift, self.max_shift)
        shift = max(-n_times, min(n_times, shift))
        if shift > 0:
            pad = x_t.new_zeros((depth, n_channels, shift))
            x_t = torch.cat([x_t[:, :, shift:], pad], dim=2)
        elif shift < 0:
            pad = x_t.new_zeros((depth, n_channels, -shift))
            x_t = torch.cat([pad, x_t[:, :, :shift]], dim=2)

        # Add small Gaussian noise. Out-of-place on purpose: when shift == 0, x_t still
        # shares memory with self.X (torch.from_numpy), and an in-place add would
        # permanently corrupt the stored training data.
        x_t = x_t + torch.randn_like(x_t) * self.noise_amp

        # Random electrode dropout (x_t is now a fresh tensor, so in-place is safe)
        num_drop = int(n_channels * self.dropout_rate)
        drop_indices = torch.randperm(n_channels)[:num_drop]
        x_t[:, drop_indices, :] = 0
        return x_t

    def __getitem__(self, idx):
        """Return ``(x, y)``: float tensor (1, channels, time_samples) and int64 label."""
        # from_numpy shares memory with self.X; only _augment (which copies) may alter values
        x_t = torch.from_numpy(self.X[idx]).float()
        y = self.y[idx]

        if self.augment:
            x_t = self._augment(x_t)

        return x_t, torch.tensor(y).long()

###############################################################################
# 4) Window-partitioning & shift for Swin-1D
###############################################################################
def pad_sequence_1d(x, window_size):
    """
    Zero-pad a (B, L, C) sequence at the end of dim 1 up to a multiple of window_size.

    Returns:
        (x_padded, L, pad_len): the input tensor itself (no copy) and pad_len == 0 when
        L is already a multiple; otherwise a new tensor of length L + pad_len.
    """
    batch, length, channels = x.shape
    pad_len = (-length) % window_size
    if not pad_len:
        return x, length, 0
    zeros = x.new_zeros((batch, pad_len, channels))
    return torch.cat((x, zeros), dim=1), length, pad_len

def window_partition_1d(x, window_size):
    """
    Split a (B, L, C) sequence into non-overlapping windows of length window_size.

    The sequence is first zero-padded at the end (pad_sequence_1d) so its length is a
    multiple of window_size.

    Returns:
        windows: tensor of shape (B * num_windows, window_size, C)
        pad_info: (orig_L, pad_len, num_windows), as consumed by window_reverse_1d
    """
    padded, orig_len, pad_len = pad_sequence_1d(x, window_size)
    batch, padded_len, channels = padded.shape
    n_windows = padded_len // window_size
    grouped = padded.view(batch, n_windows, window_size, channels)
    windows = grouped.reshape(batch * n_windows, window_size, channels)
    return windows, (orig_len, pad_len, n_windows)

def window_reverse_1d(x_windows, window_size, pad_info):
    """
    Merge (B*nW, window_size, C) windows back into a (B, L, C) sequence.

    pad_info is the (orig_L, pad_len, num_windows) tuple from window_partition_1d; any
    trailing padding is cropped off. The window length is read from x_windows itself
    (the window_size argument is not used).
    """
    orig_len, pad_len, n_windows = pad_info
    total, win, channels = x_windows.shape
    batch = total // n_windows
    merged = x_windows.view(batch, n_windows, win, channels).reshape(batch, n_windows * win, channels)
    return merged[:, :orig_len, :] if pad_len > 0 else merged

def cyclic_shift_1d(x, shift_size):
    """
    Cyclically roll x left by shift_size positions along dim 1 (the sequence axis).
    """
    return x.roll(-shift_size, 1)

def cyclic_shift_back_1d(x, shift_size):
    """
    Undo cyclic_shift_1d: cyclically roll x right by shift_size positions along dim 1.
    """
    return x.roll(shift_size, 1)

###############################################################################
# 5) Custom QKV Attention & MLP
###############################################################################
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (B, L, dim) sequence.

    ``dim`` must be divisible by ``num_heads`` (forward reshapes to dim // num_heads).
    Dropout with rate ``attn_dropout`` is applied to the attention weights and again
    after the output projection.
    """

    def __init__(self, dim, num_heads, attn_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5  # 1 / sqrt(head_dim)

        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(attn_dropout)

    def forward(self, x):
        batch, length, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, L, 3*C) -> (3, B, heads, L, head_dim) -> three (B, heads, L, head_dim) tensors
        projected = self.qkv(x).reshape(batch, length, 3, self.num_heads, head_dim)
        query, key, value = projected.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale  # (B, heads, L, L)
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = torch.matmul(weights, value).transpose(1, 2).reshape(batch, length, channels)
        return self.proj_drop(self.proj(mixed))

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps (..., dim) -> (..., dim) through a hidden width of ``hidden_dim``.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        contract = nn.Linear(hidden_dim, dim)
        output_drop = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, hidden_drop, contract, output_drop)

    def forward(self, x):
        return self.net(x)

###############################################################################
# 6) Swin1DBlock & Swin1DTransformer
###############################################################################
class Swin1DBlock(nn.Module):
    """One (optionally shifted) window-attention transformer block over (B, L, dim).

    Per window, a pre-norm residual layout is used: ``x + Attn(LN(x))`` followed by
    ``x + MLP(LN(x))``. When ``shift_size > 0`` the sequence is cyclically rolled left
    before windowing and rolled back afterwards.

    NOTE(review): no attention mask is passed to ``self.attn``. As a result, zero tokens
    appended as padding by window_partition_1d are attended to by the real tokens in the
    last window. After a cyclic shift, tokens wrapped from the sequence end also share a
    window with tokens from the start. Confirm that this is intended.
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=4,
        shift_size=2,
        mlp_hidden=128,
        attn_dropout=0.1
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = shift_size

        # Attention sub-layer
        self.ln1 = nn.LayerNorm(dim)
        self.attn = Attention(dim=dim, num_heads=num_heads, attn_dropout=attn_dropout)

        # Feed-forward sub-layer (shares the attention dropout rate)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim=dim, hidden_dim=mlp_hidden, dropout=attn_dropout)

    def forward(self, x):
        shifted = self.shift_size > 0
        if shifted:
            x = cyclic_shift_1d(x, self.shift_size)

        windows, pad_info = window_partition_1d(x, self.window_size)
        windows = windows + self.attn(self.ln1(windows))
        windows = windows + self.mlp(self.ln2(windows))

        merged = window_reverse_1d(windows, self.window_size, pad_info)
        return cyclic_shift_back_1d(merged, self.shift_size) if shifted else merged

class Swin1DTransformer(nn.Module):
    """Stack of Swin1DBlock layers followed by LayerNorm and mean pooling.

    Odd-indexed layers (1, 3, ...) use a cyclic shift of ``window_size // 2``, and
    even-indexed layers use none. Input is (B, L, dim), and the output feature vector is
    (B, dim). ``fc_dropout`` is accepted but not used by this module.
    """

    def __init__(
        self,
        dim=64,           # Embedding dimension
        num_layers=3,     # Number of Swin1DBlock layers
        num_heads=4,      # Number of attention heads
        mlp_hidden=128,   # Hidden dimension in MLP
        window_size=4,    # Window size for attention
        attn_dropout=0.1, # Dropout rate for attention
        fc_dropout=0.3    # Accepted for API compatibility; unused here
    ):
        super().__init__()
        half_window = window_size // 2
        self.blocks = nn.ModuleList([
            Swin1DBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=half_window if layer_idx % 2 else 0,
                mlp_hidden=mlp_hidden,
                attn_dropout=attn_dropout
            )
            for layer_idx in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        # Normalise, then average over the sequence axis -> (B, dim)
        return self.norm(x).mean(dim=1)

###############################################################################
# 7) Positional Embeddings: learnable, sine, or none
###############################################################################
def create_positional_embedding(mode, seq_len, dim):
    """
    Creates positional embeddings

    Args:
        mode (str): 'learnable', 'sine', or 'none'
        seq_len (int): Sequence length
        dim (int): Embedding dimension

    Returns:
        nn.Parameter or None: Tensor of shape (1, seq_len, dim), or None for 'none'.
        'learnable' is trainable (truncated-normal init, std=0.02); 'sine' is the fixed
        sinusoidal table with requires_grad=False.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == 'none':
        return None
    if mode == 'learnable':
        pe = nn.Parameter(torch.zeros(1, seq_len, dim))
        nn.init.trunc_normal_(pe, std=0.02)
        return pe
    if mode == 'sine':
        # Classic sinusoidal table:
        #   PE[pos, 2k]   = sin(pos / 10000^(2k / dim))
        #   PE[pos, 2k+1] = cos(pos / 10000^(2k / dim))
        # The even column index i already equals 2k, so the exponent is i / dim
        # (the former (2 * i) / dim doubled it).
        positions = np.arange(seq_len, dtype=np.float64)[:, None]
        even_idx = np.arange(0, dim, 2, dtype=np.float64)
        angles = positions / (10000.0 ** (even_idx / dim))  # (seq_len, ceil(dim / 2))
        pe_np = np.zeros((seq_len, dim))
        pe_np[:, 0::2] = np.sin(angles)
        # For odd dim the last sin column has no cos partner
        pe_np[:, 1::2] = np.cos(angles[:, : dim // 2])
        pe = torch.from_numpy(pe_np).float().unsqueeze(0)  # shape (1, seq_len, dim)
        return nn.Parameter(pe, requires_grad=False)
    raise ValueError(f"Unsupported positional embedding mode: {mode}")

###############################################################################
# 8) Patch Embedding
###############################################################################
class PatchEmbedding(nn.Module):
    """Shallow convolutional front-end turning (B, 1, H, W) EEG into a token sequence.

    Temporal conv (1x25), spatial conv (3x1), BatchNorm, ELU, average pooling
    (1x75, stride 15), dropout and a 1x1 projection to ``emb_size``. The output is
    rearranged to (B, h*w, emb_size). With H=3 the spatial conv leaves h=1. With W=321
    the time axis becomes 321-25+1 = 297 and then floor((297-75)/15)+1 = 15 tokens.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        layers = [
            nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1)),        # temporal filtering
            nn.Conv2d(40, 40, kernel_size=(3, 1), stride=(1, 1)),        # spatial filtering across electrodes
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 75), stride=(1, 15)),           # temporal downsampling
            nn.Dropout(p=0.5),
            nn.Conv2d(40, emb_size, kernel_size=(1, 1), stride=(1, 1)),  # 1x1 projection to emb_size
        ]
        self.shallownet = nn.Sequential(*layers)
        # (B, emb_size, h, w) -> (B, h*w, emb_size)
        self.rearrange = Rearrange('b e (h) (w) -> b (h w) e')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.rearrange(self.shallownet(x))

###############################################################################
# 9) Classification Head
###############################################################################
class ClassificationHead(nn.Module):
    """Dropout followed by a linear layer mapping (B, input_dim) features to (B, num_classes) logits.

    Args:
        input_dim (int): Size of the incoming feature vector.
        num_classes (int): Number of output logits.
        dropout (float): Dropout probability applied before the linear layer
            (defaults to the previously hard-coded 0.3).
    """

    def __init__(self, input_dim=64, num_classes=2, dropout=0.3):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        x = self.dropout(x)            # [B, dim]
        logits = self.linear(x)        # [B, num_classes]
        return logits


###############################################################################
# 10) CCST
###############################################################################
class CCST(nn.Module):
    """Conv front-end + 1-D Swin transformer + linear head for EEG trial classification.

    Pipeline: PatchEmbedding -> Linear(emb_size -> swin_embedding_dim) -> optional
    positional embedding -> Swin1DTransformer (window_size=4, attn_dropout=0.1) ->
    ClassificationHead. ``fc_dropout`` is accepted but not forwarded to any submodule.

    NOTE(review): the positional table is built for seq_len=15, which is the token count
    PatchEmbedding yields for a (B, 1, 3, 321) input. Longer inputs would produce more
    tokens than the table covers.
    """

    def __init__(
        self,
        emb_size=40,
        swin_embedding_dim=64,
        num_swin_layers=3,
        num_heads=4,
        mlp_size=128,
        fc_dropout=0.3,
        pos_emb_mode='learnable',
        num_classes=2
    ):
        super().__init__()
        self.patch_embedding = PatchEmbedding(emb_size=emb_size)
        self.embedding_projection = nn.Linear(emb_size, swin_embedding_dim)
        self.pos_encoding = create_positional_embedding(pos_emb_mode, seq_len=15, dim=swin_embedding_dim)
        self.transformer = Swin1DTransformer(
            dim=swin_embedding_dim,
            num_layers=num_swin_layers,
            num_heads=num_heads,
            mlp_hidden=mlp_size,
            window_size=4,
            attn_dropout=0.1
        )
        self.classification_head = ClassificationHead(input_dim=swin_embedding_dim, num_classes=num_classes)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, 1, 3, 321)

        Returns:
            torch.Tensor: Logits of shape (B, num_classes)
        """
        tokens = self.embedding_projection(self.patch_embedding(x))  # [B, seq_len, swin_embedding_dim]
        if self.pos_encoding is not None:
            tokens = tokens + self.pos_encoding[:, :tokens.size(1), :]
        features = self.transformer(tokens)                          # [B, swin_embedding_dim]
        return self.classification_head(features)

###############################################################################
# 11) Training and Evaluation
###############################################################################
class HybridExP:
    """Training / evaluation driver for CCST with leave-one-subject-out (LOSO) validation.

    Args:
        nsub (int): Stored as ``self.nSub``. The LOSO loop does not use it, since it holds
            out every subject in turn.
        pickle_path (str): Pickle consumed by ``load_and_preprocess_data``.
        device (str): Preferred torch device; CPU is used when CUDA is unavailable.
    """

    def __init__(self, nsub, pickle_path='/kaggle/input/bcic-iv/2b.pickle', device='cuda:0'):
        self.batch_size = 72
        self.n_epochs = 100
        self.patience = 10  # currently unused: early stopping is disabled (all epochs run)
        self.c_dim = 2  # Number of classes (Left Hand, Right Hand)
        self.lr = 3e-4
        self.betas = (0.9, 0.999)
        self.nSub = nsub
        self.pickle_path = pickle_path
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        self.criterion_cls = nn.CrossEntropyLoss().to(self.device)
        self.model = self._build_model()

    # Data augmentation method
    def interaug(self, timg, label):
        """
        Placeholder for additional data augmentation if required
        """
        return timg, label

    # Load data
    def get_source_data(self):
        """Load and preprocess the whole dataset; returns (all_X, all_y, all_subjects)."""
        all_X, all_y, all_subjects = load_and_preprocess_data(
            pickle_path=self.pickle_path,
            target_h=3,
            target_w=321,
            l_freq=7,
            h_freq=30
        )

        return all_X, all_y, all_subjects

    # ------------------------------------------------------------------ helpers
    def _build_model(self):
        """Create a freshly initialised CCST wrapped in DataParallel on ``self.device``."""
        model = CCST(num_classes=self.c_dim).to(self.device)
        return nn.DataParallel(model).to(self.device)

    def _make_loaders(self, X_train, y_train, X_val, y_val, X_test, y_test):
        """Wrap the three splits in datasets/loaders (augmentation and shuffling for train only)."""
        train_loader = DataLoader(MultiSubjectBCIDataset(X_train, y_train, augment=True),
                                  batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(MultiSubjectBCIDataset(X_val, y_val, augment=False),
                                batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(MultiSubjectBCIDataset(X_test, y_test, augment=False),
                                 batch_size=self.batch_size, shuffle=False)
        return train_loader, val_loader, test_loader

    def _train_one_epoch(self, loader, optimizer):
        """Run one optimisation pass over ``loader``; return (mean loss, accuracy)."""
        self.model.train()
        loss_sum, correct, total = 0.0, 0, 0
        for img, label in loader:
            img = img.to(self.device)      # (B, 1, 3, 321)
            label = label.to(self.device)  # (B,)

            optimizer.zero_grad()
            outputs = self.model(img)      # (B, num_classes)
            loss = self.criterion_cls(outputs, label)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item() * img.size(0)
            correct += (outputs.argmax(dim=1) == label).sum().item()
            total += label.size(0)
        return loss_sum / total, correct / total

    def _evaluate(self, loader):
        """Evaluate without gradients; return (mean loss, accuracy, labels, predictions)."""
        self.model.eval()
        loss_sum, correct, total = 0.0, 0, 0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for img, label in loader:
                img = img.to(self.device)
                label = label.to(self.device)

                outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)
                preds = outputs.argmax(dim=1)

                loss_sum += loss.item() * img.size(0)
                correct += (preds == label).sum().item()
                total += label.size(0)

                all_preds.append(preds.cpu().numpy())
                all_labels.append(label.cpu().numpy())
        return loss_sum / total, correct / total, np.concatenate(all_labels), np.concatenate(all_preds)

    def _snapshot_state(self):
        """Return a detached copy of the current weights.

        ``state_dict()`` returns references to the live tensors, so without cloning
        later optimizer steps would overwrite the saved "best" checkpoint.
        """
        return {k: v.detach().clone() for k, v in self.model.state_dict().items()}

    def _log_epoch_short(self, epoch, train_loss, train_acc, val_loss, val_acc):
        """Per-epoch progress line used by ``train_fold``."""
        print(f'Epoch {epoch+1}/{self.n_epochs} | '
              f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | '
              f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%')

    def _log_epoch_long(self, epoch, train_loss, train_acc, val_loss, val_acc):
        """Per-epoch progress line used by ``train_loso``."""
        print(f'Epoch {epoch+1}/{self.n_epochs} | '
              f'Training Loss : {train_loss:.4f} | Training Accuracy : {train_acc*100:.2f} % | '
              f'Validation Loss : {val_loss:.4f} | Validation Accuracy : {val_acc*100:.2f} %')

    def _fit(self, train_loader, val_loader, log_epoch):
        """Train ``self.model`` for ``self.n_epochs`` epochs, then restore the best-val-accuracy weights.

        A new Adam optimizer and ReduceLROnPlateau scheduler (stepped on val loss) are
        created per call. Early stopping is intentionally disabled.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=self.betas)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         factor=0.5, patience=5, verbose=True)

        best_val_acc = 0.0
        best_model_state = None

        for epoch in range(self.n_epochs):
            train_loss, train_acc = self._train_one_epoch(train_loader, optimizer)
            val_loss, val_acc, _, _ = self._evaluate(val_loader)
            scheduler.step(val_loss)
            log_epoch(epoch, train_loss, train_acc, val_loss, val_acc)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_state = self._snapshot_state()

        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

    # ----------------------------------------------------------- public API
    def train_fold(self, train_indices, val_indices, test_indices, train_labels, test_labels):
        """Train on ``train_indices``, select on ``val_indices``, report on ``test_indices``.

        Indices address ``self.X`` / ``self.y``. The data is loaded on first use if
        ``train_loso`` has not done so already. ``train_labels`` and ``test_labels`` are
        unused and are kept only for signature compatibility. The current ``self.model``
        is trained in place and is not re-initialised here.

        Returns:
            tuple: (test_acc, all_labels, all_preds)
        """
        if not hasattr(self, 'X'):
            self.X, self.y, self.subjects = self.get_source_data()

        train_loader, val_loader, test_loader = self._make_loaders(
            self.X[train_indices], self.y[train_indices],
            self.X[val_indices], self.y[val_indices],
            self.X[test_indices], self.y[test_indices],
        )
        self._fit(train_loader, val_loader, self._log_epoch_short)

        test_loss, test_acc, all_labels, all_preds = self._evaluate(test_loader)
        print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%')
        return test_acc, all_labels, all_preds

    def train_loso(self):
        """Run leave-one-subject-out cross-validation and print a per-subject summary.

        Each fold:
        1. Reseeds with ``42 + fold`` and builds a fresh model, so no fold starts from
           weights that were trained on its held-out subject.
        2. Carves a stratified 10 % validation split out of the training subjects.
        3. Trains with ``_fit`` and evaluates on the held-out subject.

        NOTE(review): ``load_and_preprocess_data`` fits its StandardScaler on all subjects
        at once, including each fold's held-out subject. Consider fitting it per fold on
        the training split only.

        Returns:
            tuple: (0-based held-out subject ids, their test accuracies)
        """
        self.X, self.y, self.subjects = self.get_source_data()
        logo = LeaveOneGroupOut()

        test_accuracies = []
        subject_ids = []

        print("\n####################")
        print("Training Started")
        print("####################")

        folds = logo.split(self.X, self.y, groups=self.subjects)
        for fold_count, (train_idx, test_idx) in enumerate(folds, start=1):
            heldout_subj = self.subjects[test_idx[0]]
            training_subjects = np.unique(self.subjects[train_idx])
            # Convert to 1-based indexing for display
            training_subjects_1based = [s + 1 for s in training_subjects]
            test_subj_1based = heldout_subj + 1
            seed = 42 + fold_count
            print(f"\n===== Fold {fold_count}")
            print(f"===== Seed : {seed}")
            print(f"===== Training Subject : {', '.join(map(str, training_subjects_1based))}")
            print(f"===== Test Subject : {test_subj_1based}\n")

            # Re-seed per fold for reproducibility, then start from fresh weights. Reusing
            # the previous fold's model would leak this fold's test subject, which was
            # part of that fold's training set.
            set_seed(seed)
            self.model = self._build_model()

            X_train_full, X_test = self.X[train_idx], self.X[test_idx]
            y_train_full, y_test = self.y[train_idx], self.y[test_idx]

            # Further split training data into training and validation (90-10, stratified)
            X_train, X_val, y_train, y_val = train_test_split(
                X_train_full, y_train_full, test_size=0.1,
                stratify=y_train_full, random_state=seed
            )

            train_loader, val_loader, test_loader = self._make_loaders(
                X_train, y_train, X_val, y_val, X_test, y_test
            )
            self._fit(train_loader, val_loader, self._log_epoch_long)

            test_loss, test_acc, _, _ = self._evaluate(test_loader)
            print(f'Test Subject : {test_subj_1based} | Test Loss : {test_loss:.4f} | Test Accuracy : {test_acc*100:.2f} %')
            print("\n################################")

            test_accuracies.append(test_acc)
            subject_ids.append(heldout_subj)

        # Summary of LOSO
        avg_test_acc = np.mean(test_accuracies) * 100

        print("\n================")
        print("LOSO Summary")
        print("================")
        # Report 1-based subject numbers, matching the per-fold headers above
        for sid, acc in zip(subject_ids, test_accuracies):
            print(f"Subject {sid + 1} -> Test Accuracy : {acc*100:.2f} %")
        print("\n-----------------------------")
        print(f"Average Test Accuracy : {avg_test_acc:.2f} %")
        print("-----------------------------")

        return subject_ids, test_accuracies

###############################################################################
# 12) Main Execution
###############################################################################
def main():
    """Entry point: run LOSO training/evaluation of CCST on the BCIC-IV 2b pickle."""
    # Path to the pickle file
    pickle_path = '/kaggle/input/bcic-iv/2b.pickle'

    # nsub is stored on the experiment but not used by train_loso, which holds out
    # every subject in turn.
    exp = HybridExP(nsub=1, pickle_path=pickle_path, device='cuda:0')
    exp.train_loso()

if __name__ == "__main__":
    main()