# ECoG/EEG Motor Imagery Classification with ATCNet + LoRA

This notebook demonstrates:
1. Loading ECoG/EEG data from .fif files (MNE format)
2. Training ATCNet on base subjects
3. Fine-tuning using LoRA on target subjects

Dataset: High Gamma Motor Imagery dataset (MNE .fif format)

## Install Dependencies

In [None]:
# Install required packages if not available
# !pip install mne pytorch-lightning torchmetrics scipy scikit-learn

## Dataset Class for ECoG/EEG (.fif files)

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.preprocessing import StandardScaler

class MultiSubjectECoGDataset(Dataset):
    """In-memory PyTorch dataset of epoched ECoG/EEG trials read from .fif files.

    Every file in ``file_list`` is read with MNE, optionally restricted to
    good ECoG/EEG channels, optionally band-pass filtered, cut into epochs
    around the annotated events and optionally standardized per channel.
    All trials are then stacked into ``self.data`` with the matching
    zero-based class indices in ``self.labels``.
    """

    def __init__(self, file_list, freq_band=(1, 100), remove_bad=True, scale=True, transform=None, time_range=None):
        """
        Load and preprocess every file in ``file_list``.

        Parameters
        ----------
        file_list : list of str
            List of .fif file paths.
        freq_band : tuple or None, optional
            Bandpass filter frequency range in Hz, e.g., (1, 100).
            A falsy value (``None``) skips filtering.
        remove_bad : bool, optional
            Keep only ECoG/EEG channels and drop the ones marked bad.
        scale : bool, optional
            Standardize each channel to zero mean / unit variance, computed
            over all trials and time samples of the file being loaded.
        transform : callable, optional
            Optional transformation applied to each sample tensor in
            ``__getitem__``.
        time_range : tuple or None, optional
            Time window (start_sec, end_sec) relative to event onset to extract.
            Example: (0.5, 3). ``None`` falls back to (0, 4).

        Raises
        ------
        ValueError
            If the loaded trials do not all share one (n_channels, n_times)
            shape and therefore cannot be stacked into a single array.
        """
        self.transform = transform
        self.data = []
        self.labels = []
        self.scale = scale
        self.remove_bad = remove_bad
        self.freq_band = freq_band
        self.time_range = time_range

        for file in file_list:
            print(f'Loading and processing {file}')
            self._load_subject_data(file)

        # Trials can only be stacked when every file produced the same
        # (n_channels, n_times); fail with an explicit message otherwise
        # instead of a ragged-array error (or a silent object array).
        trial_shapes = {trial.shape for trial in self.data}
        if len(trial_shapes) > 1:
            raise ValueError(
                "Trials have inconsistent (n_channels, n_times) shapes across "
                f"files: {sorted(trial_shapes)}. Check channel selection "
                "(remove_bad) and sampling rates."
            )

        self.data = np.array(self.data)  # (n_trials, n_channels, n_times)
        self.labels = np.array(self.labels)
        print(f"Loaded {len(self.data)} trials from {len(file_list)} subjects.")

    @staticmethod
    def _standardize_channels(X):
        """Z-score each channel of ``X`` (n_trials, n_channels, n_times).

        Mean and standard deviation are taken over the trial and time axes, so
        each channel gets exactly one offset and one scale and the waveform
        inside a trial is preserved. Channels with zero variance are only
        centered (their scale is treated as 1) to avoid division by zero.
        """
        if X.size == 0:
            return X
        mean = X.mean(axis=(0, 2), keepdims=True)
        std = X.std(axis=(0, 2), keepdims=True)
        std[std == 0] = 1.0
        return (X - mean) / std

    def _load_subject_data(self, file_path):
        """Read one .fif file and append its trials/labels to the dataset."""
        # Load raw data
        raw = mne.io.read_raw_fif(file_path, preload=True, verbose=False)
        if self.remove_bad:
            raw.pick_types(ecog=True, eeg=True, exclude='bads')

        # Apply bandpass filter
        if self.freq_band:
            raw.filter(self.freq_band[0], self.freq_band[1], fir_design='firwin', verbose=False)

        # Get events and labels
        events, event_id = mne.events_from_annotations(raw, verbose=False)
        print(f"  Found {len(events)} events, event_id mapping: {event_id}")

        # All annotated events are kept. To drop REST events, filter here, e.g.
        # {name: code for name, code in event_id.items() if 'rest' not in name.lower()}
        mi_event_ids = dict(event_id)

        if not mi_event_ids:
            print("  No Motor Imagery events found in this file.")
            return

        print(f"  Using events: {list(mi_event_ids.keys())}")

        # Epoching
        if self.time_range is not None:
            tmin, tmax = self.time_range
        else:
            tmin, tmax = 0, 4  # Default 4s trial length

        epochs = mne.Epochs(
            raw, events, event_id=mi_event_ids,
            tmin=tmin, tmax=tmax,
            baseline=None, preload=True, verbose=False
        )

        X = epochs.get_data()  # (n_trials, n_channels, n_times)
        y = epochs.events[:, -1]  # Label indices

        # Optional scaling per channel (statistics over trials and time).
        if self.scale:
            X = self._standardize_channels(X)

        # Save data and labels
        for trial, label in zip(X, y):
            self.data.append(trial)
            # Event codes are shifted by one to obtain zero-based class indices.
            # NOTE(review): assumes the codes in event_id start at 1 and mean
            # the same class in every file — confirm for multi-file loads.
            self.labels.append(label - 1)

    def __len__(self):
        """Return the number of loaded trials."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for trial ``index``.

        ``sample`` is a float32 tensor of shape (1, n_channels, n_times) and
        ``label`` an int64 scalar tensor.
        """
        sample = self.data[index]  # (n_channels, n_times)
        label = self.labels[index]

        # Convert to PyTorch tensor and add channel dimension (1, n_channels, n_times)
        sample = torch.from_numpy(sample).float().unsqueeze(0)
        label = torch.tensor(label).long()

        if self.transform:
            sample = self.transform(sample)

        return sample, label

## Test Dataset Loading

In [None]:
# Test with a subset of files
# Smoke test: load a single recording and inspect one batch.
test_file_list = [
    "/kaggle/input/high-gamma-dataset-fif/Subject 1/0/0-raw.fif",
]

# Create test dataset
# Filtering, bad-channel removal and scaling are all disabled here, so the
# batch below shows the epoched signal as stored in the file.
test_dataset = MultiSubjectECoGDataset(
    file_list=test_file_list,
    freq_band=None,  # or (2, 40) for bandpass
    time_range=(0.0, 4.0),
    remove_bad=False,
    scale=False
)

# Create test dataloader
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

# Check batch shape
# Only the first batch is inspected; the loop exits right after printing.
for samples, labels in test_loader:
    print("Sample shape:", samples.shape)  # (batch_size, 1, n_channels, n_times)
    print("Labels:", labels)
    print("Value range:", samples.min().item(), "to", samples.max().item())
    break

## Visualize Data Quality

In [None]:
import matplotlib.pyplot as plt
from scipy.signal import welch

def compute_mean_psd(batch, fs=250, nperseg=256):
    """
    Compute the mean Welch PSD for each channel across the batch.

    Parameters
    ----------
    batch : torch.Tensor
        Tensor of shape (batch_size, 1, n_channels, n_times); only index 0
        of the second axis is used.
    fs : float, optional
        Sampling frequency in Hz passed to ``scipy.signal.welch``.
    nperseg : int, optional
        Welch segment length in samples.

    Returns
    -------
    freqs : numpy.ndarray
        Frequency bins of the PSD estimate.
    mean_psd_per_channel : list of numpy.ndarray
        One PSD curve per channel, averaged over the batch axis.

    Raises
    ------
    ValueError
        If the batch contains no samples.
    """
    # Unpacking enforces the expected 4-D layout.
    batch_size, _, num_channels, segment_length = batch.shape
    if batch_size == 0:
        raise ValueError("compute_mean_psd() received an empty batch")

    # (batch_size, n_channels, n_times); detach so tensors that track
    # gradients can be converted to NumPy as well.
    signals = batch[:, 0].detach().cpu().numpy()

    # One vectorized Welch call along the time axis replaces the per-sample,
    # per-channel Python loops.
    freqs, pxx = welch(signals, fs=fs, nperseg=nperseg, axis=-1)

    mean_psd = pxx.mean(axis=0)  # (n_channels, n_freqs)
    return freqs, list(mean_psd)

def plot_mean_psd(freqs, mean_psd_per_channel):
    """
    Draw every channel's mean PSD on one semilog-y figure.

    The first ten curves carry a 'Ch N' label; the legend itself is only
    drawn when there are at most ten channels in total.
    """
    n_curves = len(mean_psd_per_channel)
    max_labelled = 10

    plt.figure(figsize=(12, 6))
    for idx in range(n_curves):
        curve_label = f'Ch {idx+1}' if idx < max_labelled else None
        plt.semilogy(freqs, mean_psd_per_channel[idx], alpha=0.6, label=curve_label)

    plt.xlabel('Frequency (Hz)')
    plt.ylabel('PSD (V^2/Hz)')
    plt.title('Mean Power Spectral Density per Channel')

    if n_curves <= max_labelled:
        plt.legend(loc='upper right')

    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Compute and plot PSD
# Only the first batch of the (shuffled) test loader is analysed.
# NOTE(review): fs=250 is hard-coded — confirm it matches the sampling rate
# of the loaded recordings, otherwise the frequency axis is mis-scaled.
for batch in test_loader:
    samples, labels = batch
    freqs, mean_psd = compute_mean_psd(samples, fs=250, nperseg=256)
    plot_mean_psd(freqs, mean_psd)
    break

In [None]:
# Channel statistics
# Plot the per-channel mean with the standard deviation as error bars,
# computed from the first batch of the test loader only.
for batch in test_loader:
    samples, labels = batch
    samples_squeezed = samples.squeeze(1)  # (batch_size, n_channels, n_times)

    # Reduce over the batch (0) and time (2) axes -> one value per channel.
    channel_means = samples_squeezed.mean(dim=(0, 2)).numpy()
    channel_stds = samples_squeezed.std(dim=(0, 2)).numpy()

    # 1-based channel numbers for the x axis.
    channels = np.arange(1, samples_squeezed.shape[1] + 1)

    plt.figure(figsize=(12, 5))
    plt.errorbar(channels, channel_means, yerr=channel_stds, fmt='o', capsize=5, color='steelblue')
    plt.title("Mean ± STD per Channel (Averaged over Batch & Time)")
    plt.xlabel("Channel")
    plt.ylabel("Amplitude")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    break

## ATCNet Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class ATCNet(nn.Module):
    """Attention temporal convolutional network for EEG/ECoG classification.

    A convolutional front-end reduces the input of shape
    (batch, in_channels, num_electrodes, chunk_size) to a short sequence of
    feature vectors. ``num_windows`` overlapping slices of that sequence are
    each passed through self-attention, a stack of temporal-convolution
    blocks and a linear classifier; the per-window logits are averaged.

    The attention/TCN/dense sub-modules are registered dynamically by
    ``__build_model`` under the names ``msa{i}``, ``msa_drop{i}``,
    ``tcn{k}``, ``re{i}`` and ``dense{i}``.

    Parameters
    ----------
    in_channels : int
        Size of the second input axis (1 for a single "image" channel).
    num_classes : int
        Number of output logits.
    num_windows : int
        Number of sliding windows over the temporal feature sequence.
    num_electrodes : int
        Number of recording channels (third input axis).
    conv_pool_size : int
        Width of the second average-pooling layer of the front-end.
    F1 : int
        Number of temporal filters of the first convolution.
    D : int
        Depth multiplier; the front-end outputs ``F1 * D`` feature maps.
    tcn_kernel_size : int
        Kernel size of the temporal-convolution blocks.
    tcn_depth : int
        Number of temporal-convolution blocks per window.
    chunk_size : int
        Number of time samples per trial (fourth input axis).

    Raises
    ------
    ValueError
        If the front-end yields fewer temporal steps than ``num_windows``.
    """

    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 4,
                 num_windows: int = 3,
                 num_electrodes: int = 22,
                 conv_pool_size: int = 7,
                 F1: int = 16,
                 D: int = 2,
                 tcn_kernel_size: int = 4,
                 tcn_depth: int = 2,
                 chunk_size: int = 1125):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.num_windows = num_windows
        self.num_electrodes = num_electrodes
        self.pool_size = conv_pool_size
        self.F1 = F1
        self.D = D
        self.tcn_kernel_size = tcn_kernel_size
        self.tcn_depth = tcn_depth
        self.chunk_size = chunk_size
        F2 = F1*D

        # NOTE(review): the second positional argument of BatchNorm2d is
        # `eps`, so `False` sets eps to 0 here. Kept as-is so that numerics
        # stay compatible with existing checkpoints — confirm the intent.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, F1, (1, int(chunk_size/2+1)),
                      stride=1, padding='same', bias=False),
            nn.BatchNorm2d(F1, False),
            nn.Conv2d(F1, F2, (num_electrodes, 1), padding=0, groups=F1),
            nn.BatchNorm2d(F2, False),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout2d(0.1),
            nn.Conv2d(F2, F2, (1, 16), bias=False, padding='same'),
            nn.BatchNorm2d(F2, False),
            nn.ELU(),
            nn.AvgPool2d((1, self.pool_size)),
            nn.Dropout2d(0.1)
        )
        self.__build_model()

    def __build_model(self):
        """Trace a dummy batch to size and register the per-window modules."""
        with torch.no_grad():
            x = torch.zeros(2, self.in_channels,
                            self.num_electrodes, self.chunk_size)
            x = self.conv_block(x)
            x = x[:, :, -1, :]
            x = x.permute(0, 2, 1)  # (batch, steps, features)
            self.__chan_dim, self.__embed_dim = x.shape[1:]
            self.win_len = self.__chan_dim - self.num_windows + 1
            if self.win_len < 1:
                raise ValueError(
                    f"chunk_size={self.chunk_size} yields only "
                    f"{self.__chan_dim} temporal steps, fewer than "
                    f"num_windows={self.num_windows}"
                )

            for i in range(self.num_windows):
                st = i
                end = x.shape[1] - self.num_windows+i+1
                x2 = x[:, st:end, :]

                self.__add_msa(i)
                x2_ = self.get_submodule("msa"+str(i))(x2, x2, x2)[0]
                self.__add_msa_drop(i)
                # Dropout acts on the attention output (see forward()).
                x2_ = self.get_submodule("msa_drop"+str(i))(x2_)
                x2 = torch.add(x2, x2_)

                # NOTE: the index (i+1)*j is 0 for j == 0 in every window, so
                # 'tcn0' is re-registered on each pass and ends up shared by
                # all windows. Left unchanged to keep state_dict keys stable.
                for j in range(self.tcn_depth):
                    self.__add_tcn((i+1)*j, x2.shape[1])
                    out = self.get_submodule("tcn"+str((i+1)*j))(x2)
                    if x2.shape[1] != out.shape[1]:
                        # Project the residual branch to the TCN width.
                        self.__add_recov(i)
                        x2 = self.get_submodule("re"+str(i))(x2)
                    x2 = torch.add(x2, out)
                    x2 = F.elu(x2)
                x2 = x2[:, -1, :]
                self.__dense_dim = x2.shape[-1]
                self.__add_dense(i)
                x2 = self.get_submodule("dense"+str(i))(x2)

    def __add_msa(self, index: int):
        """Register the 2-head self-attention module of window ``index``."""
        self.add_module('msa'+str(index), nn.MultiheadAttention(
            embed_dim=self.__embed_dim,
            num_heads=2,
            batch_first=True))

    def __add_msa_drop(self, index):
        """Register the dropout applied to the attention output."""
        self.add_module('msa_drop'+str(index), nn.Dropout(0.3))

    def __add_tcn(self, index: int, num_electrodes: int):
        """Register TCN block ``index`` taking ``num_electrodes`` input channels."""
        self.add_module('tcn'+str(index),
                        nn.Sequential(
            nn.Conv1d(num_electrodes, 32,
                      self.tcn_kernel_size, padding='same'),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Conv1d(32, 32, self.tcn_kernel_size, padding='same'),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.3))
        )

    def __add_recov(self, index: int):
        """Register the conv that maps the residual branch to 32 channels."""
        self.add_module('re'+str(index),
                        nn.Conv1d(self.win_len, 32, 4, padding='same'))

    def __add_dense(self, index: int):
        """Register the linear classifier of window ``index``."""
        self.add_module('dense'+str(index),
                        nn.Linear(self.__dense_dim, self.num_classes))

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` must have shape (batch, in_channels, num_electrodes, chunk_size).
        """
        x = self.conv_block(x)
        x = x[:, :, -1, :]
        x = x.permute(0, 2, 1)  # (batch, steps, features)

        sw_concat = None
        for i in range(self.num_windows):
            st = i
            end = x.shape[1] - self.num_windows+i+1
            x2 = x[:, st:end, :]

            # Residual self-attention. Dropout must be applied to the
            # attention output; applying it to the block input would discard
            # the attention result entirely.
            x2_ = self.get_submodule("msa"+str(i))(x2, x2, x2)[0]
            x2_ = self.get_submodule("msa_drop"+str(i))(x2_)
            x2 = torch.add(x2, x2_)

            for j in range(self.tcn_depth):
                out = self.get_submodule("tcn"+str((i+1)*j))(x2)
                if x2.shape[1] != out.shape[1]:
                    x2 = self.get_submodule("re"+str(i))(x2)
                x2 = torch.add(x2, out)
                x2 = F.elu(x2)
            x2 = x2[:, -1, :]
            x2 = self.get_submodule("dense"+str(i))(x2)

            sw_concat = x2 if sw_concat is None else sw_concat.add(x2)

        # Average the logits of all windows.
        x = sw_concat/self.num_windows
        return x

## LoRA Implementation

In [None]:
import math

class LoRALayer(nn.Module):
    """Low-rank additive update ``(x @ A^T @ B^T) * (alpha / rank)``.

    ``lora_A`` (rank x in_dim) is Kaiming-uniform initialised and ``lora_B``
    (out_dim x rank) starts at zero, so a fresh layer outputs zeros.
    """

    def __init__(self, in_dim, out_dim, rank=8, alpha=16):
        super().__init__()
        self.rank, self.alpha = rank, alpha
        self.scaling = alpha / rank

        # Down-projection: random start.
        down_init = torch.zeros(rank, in_dim)
        nn.init.kaiming_uniform_(down_init, a=math.sqrt(5))
        self.lora_A = nn.Parameter(down_init)

        # Up-projection: all zeros, i.e. no change to the wrapped layer yet.
        self.lora_B = nn.Parameter(torch.zeros(out_dim, rank))

    def forward(self, x):
        """Project ``x`` down to ``rank`` dims, back up, and scale."""
        low_rank = torch.matmul(x, self.lora_A.T)
        update = torch.matmul(low_rank, self.lora_B.T)
        return update * self.scaling


class LinearWithLoRA(nn.Module):
    """Frozen ``nn.Linear`` plus a trainable LoRA correction.

    The wrapped layer's parameters are excluded from gradient computation;
    only the adapter in ``self.lora`` remains trainable.
    """

    def __init__(self, linear_layer, rank=8, alpha=16):
        super().__init__()
        self.linear = linear_layer
        n_in = linear_layer.in_features
        n_out = linear_layer.out_features
        self.lora = LoRALayer(n_in, n_out, rank=rank, alpha=alpha)

        # Freeze the original weight and bias.
        self.linear.requires_grad_(False)

    def forward(self, x):
        """Return the frozen layer's output plus the low-rank update."""
        base_out = self.linear(x)
        delta = self.lora(x)
        return base_out + delta


def add_lora_to_model(model, rank=8, alpha=16, target_modules=('dense',)):
    """Freeze ``model`` and wrap the selected ``nn.Linear`` layers with LoRA.

    All existing parameters get ``requires_grad = False``. Every
    ``nn.Linear`` whose qualified module name contains one of
    ``target_modules`` is then replaced in place by a ``LinearWithLoRA``
    wrapper, whose adapter parameters are the only trainable ones.

    Parameters
    ----------
    model : nn.Module
        Model to modify in place.
    rank : int
        Rank of the LoRA matrices.
    alpha : float
        LoRA scaling numerator (the update is scaled by ``alpha / rank``).
    target_modules : str or iterable of str
        Substrings matched against module names. A single string is treated
        as one pattern.

    Returns
    -------
    nn.Module
        The same ``model`` object, for chaining.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(target_modules, str):
        targets = (target_modules,)
    else:
        targets = tuple(target_modules)

    for param in model.parameters():
        param.requires_grad = False

    # Snapshot the module list first: replacing children while the
    # named_modules() generator is still running is fragile.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not any(target in name for target in targets):
            continue

        parent_name, _, child_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model

        if isinstance(parent, LinearWithLoRA):
            # Already wrapped by an earlier call: do not nest a second
            # adapter, just re-enable the existing one frozen above.
            for param in parent.lora.parameters():
                param.requires_grad = True
            continue

        lora_layer = LinearWithLoRA(module, rank=rank, alpha=alpha)
        setattr(parent, child_name, lora_layer)
        print(f"Added LoRA to: {name}")

    return model


def count_parameters(model):
    """Return ``(trainable, total)`` parameter element counts of ``model``."""
    trainable = 0
    total = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return trainable, total

## Trainer Class

In [None]:
import logging
from typing import Any, Dict, List, Tuple, Union
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import MetricCollection

_EVALUATE_OUTPUT = List[Dict[str, float]]
log = logging.getLogger('ecog_training')

def classification_metrics(metric_list: List[str], num_classes: int):
    """Build a ``MetricCollection`` holding the requested multiclass metrics.

    Supported names are 'precision', 'recall', 'f1score', 'accuracy',
    'matthews', 'auroc' and 'kappa'. The first unsupported name found in
    ``metric_list`` raises ``ValueError``.
    """
    supported = ('precision', 'recall', 'f1score', 'accuracy', 'matthews', 'auroc', 'kappa')
    for requested in metric_list:
        if requested in supported:
            continue
        raise ValueError(f"{requested} is not allowed.")

    # Keyword sets shared by the constructors below.
    multiclass = dict(task='multiclass', num_classes=num_classes)
    macro = dict(average='macro', **multiclass)

    registry = dict(
        accuracy=torchmetrics.Accuracy(top_k=1, **multiclass),
        precision=torchmetrics.Precision(**macro),
        recall=torchmetrics.Recall(**macro),
        f1score=torchmetrics.F1Score(**macro),
        matthews=torchmetrics.MatthewsCorrCoef(**multiclass),
        auroc=torchmetrics.AUROC(**multiclass),
        kappa=torchmetrics.CohenKappa(**multiclass),
    )

    selected = []
    for requested in metric_list:
        selected.append(registry[requested])
    return MetricCollection(selected)


class ClassifierTrainer(pl.LightningModule):
    """Lightning wrapper that trains/evaluates ``model`` with cross-entropy.

    Running loss and the requested classification metrics are tracked
    separately for the train, validation and test stages. ``fit`` and
    ``test`` each create their own ``pl.Trainer`` from the stored
    ``devices`` / ``accelerator`` settings. Logging only happens when
    ``verbose`` is true.
    """

    def __init__(self, model: nn.Module, num_classes: int, lr: float = 1e-3,
                 weight_decay: float = 0.0, devices: int = 1, accelerator: str = "cpu",
                 verbose: bool = True, metrics: List[str] = ["accuracy"]):
        super().__init__()
        self.model = model
        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.devices = devices
        self.accelerator = accelerator
        self.metrics = metrics
        self.ce_fn = nn.CrossEntropyLoss()
        self.verbose = verbose
        self.init_metrics(metrics, num_classes)

    def init_metrics(self, metrics: List[str], num_classes: int) -> None:
        """Create ``<stage>_loss`` and ``<stage>_metrics`` for every stage."""
        stages = ("train", "val", "test")
        for stage in stages:
            setattr(self, f"{stage}_loss", torchmetrics.MeanMetric())
        for stage in stages:
            setattr(self, f"{stage}_metrics",
                    classification_metrics(metrics, num_classes))

    def fit(self, train_loader: DataLoader, val_loader: DataLoader,
            max_epochs: int = 300, *args, **kwargs) -> Any:
        """Train for ``max_epochs``; extra arguments go to ``pl.Trainer``."""
        trainer = pl.Trainer(*args,
                             devices=self.devices,
                             accelerator=self.accelerator,
                             max_epochs=max_epochs,
                             **kwargs)
        return trainer.fit(self, train_loader, val_loader)

    def test(self, test_loader: DataLoader, *args, **kwargs) -> _EVALUATE_OUTPUT:
        """Evaluate on ``test_loader``; extra arguments go to ``pl.Trainer``."""
        trainer = pl.Trainer(*args,
                             devices=self.devices,
                             accelerator=self.accelerator,
                             **kwargs)
        return trainer.test(self, test_loader)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Delegate to the wrapped model."""
        return self.model(x, *args, **kwargs)

    def _predict_and_loss(self, batch: Tuple[torch.Tensor]):
        """Run the model on ``batch``; return (logits, targets, CE loss)."""
        inputs, targets = batch
        logits = self(inputs)
        return logits, targets, self.ce_fn(logits, targets)

    def _finish_epoch(self, stage: str) -> None:
        """Log the epoch aggregates of ``stage`` (if verbose), then reset."""
        loss_tracker = getattr(self, f"{stage}_loss")
        collection = getattr(self, f"{stage}_metrics")
        if self.verbose:
            epoch_opts = dict(prog_bar=False, on_epoch=True, on_step=False, logger=True)
            self.log(f"{stage}_loss", loss_tracker.compute(), **epoch_opts)
            for metric_name, metric in zip(self.metrics, collection.values()):
                self.log(f"{stage}_{metric_name}", metric.compute(), **epoch_opts)
        loss_tracker.reset()
        collection.reset()

    def training_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor:
        logits, targets, loss = self._predict_and_loss(batch)
        if not self.verbose:
            return loss

        # Calling a metric updates its state and returns the batch value,
        # which is shown on the progress bar only.
        step_opts = dict(prog_bar=True, on_epoch=False, logger=False, on_step=True)
        self.log("train_loss", self.train_loss(loss), **step_opts)
        for metric_name, metric in zip(self.metrics, self.train_metrics.values()):
            self.log(f"train_{metric_name}", metric(logits, targets), **step_opts)
        return loss

    def on_train_epoch_end(self) -> None:
        self._finish_epoch("train")

    def validation_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor:
        logits, targets, loss = self._predict_and_loss(batch)
        self.val_loss.update(loss)
        self.val_metrics.update(logits, targets)
        return loss

    def on_validation_epoch_end(self) -> None:
        self._finish_epoch("val")

    def test_step(self, batch: Tuple[torch.Tensor], batch_idx: int) -> torch.Tensor:
        logits, targets, loss = self._predict_and_loss(batch)
        self.test_loss.update(loss)
        self.test_metrics.update(logits, targets)
        return loss

    def on_test_epoch_end(self) -> None:
        self._finish_epoch("test")

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters that require gradients."""
        trainable_parameters = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        return torch.optim.Adam(trainable_parameters,
                                lr=self.lr,
                                weight_decay=self.weight_decay)

## Phase 1: Base Model Training

Train on base subjects (e.g., Subject 2-14)

In [None]:
# Dataset configuration
DATA_PREFIX = "/kaggle/input/high-gamma-dataset-fif/"

# Base training subjects (all except Subject 1)
# Only Subjects 2 and 3 (two recordings each) are listed so far.
base_train_files = [
    f"{DATA_PREFIX}Subject 2/0/0-raw.fif",
    f"{DATA_PREFIX}Subject 2/1/1-raw.fif",
    f"{DATA_PREFIX}Subject 3/0/0-raw.fif",
    f"{DATA_PREFIX}Subject 3/1/1-raw.fif",
    # Add more base subjects...
]

# Validation: Subject 1, session 1
# NOTE: the same recording is reused as the validation and final evaluation
# set in the LoRA and evaluation cells further down.
base_val_files = [
    f"{DATA_PREFIX}Subject 1/1/1-raw.fif",
]

print("="*60)
print("PHASE 1: BASE MODEL TRAINING")
print("="*60)

# Create datasets with preprocessing
base_train_dataset = MultiSubjectECoGDataset(
    file_list=base_train_files,
    freq_band=(2, 40),  # Bandpass filter
    time_range=(0.0, 4.0),
    remove_bad=True,
    scale=True  # StandardScaler per channel
)

# Same preprocessing settings as the training set; scaling statistics are
# computed separately for each loaded file.
base_val_dataset = MultiSubjectECoGDataset(
    file_list=base_val_files,
    freq_band=(2, 40),
    time_range=(0.0, 4.0),
    remove_bad=True,
    scale=True
)

# Get dataset info
# Each sample has shape (1, n_channels, n_times).
sample_shape = base_train_dataset[0][0].shape
num_channels = sample_shape[1]
chunk_size = sample_shape[2]
# Number of distinct labels seen in the training set.
num_classes = len(np.unique(base_train_dataset.labels))

print(f"\nDataset info:")
print(f"  Number of channels: {num_channels}")
print(f"  Chunk size: {chunk_size}")
print(f"  Number of classes: {num_classes}")

# Create dataloaders
batch_size = 32
base_train_loader = DataLoader(base_train_dataset, batch_size=batch_size, shuffle=True)
base_val_loader = DataLoader(base_val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Initialize base model
# Electrode count, trial length and class count come from the training
# dataset inspected in the previous cell.
base_model = ATCNet(
    in_channels=1,
    num_classes=num_classes,
    num_windows=3,
    num_electrodes=num_channels,
    chunk_size=chunk_size,
    F1=16,
    D=2
)

# Report the parameter budget before training.
trainable, total = count_parameters(base_model)
print(f"\nBase model parameters:")
print(f"  Trainable: {trainable:,}")
print(f"  Total: {total:,}")
print(f"  Percentage trainable: {100*trainable/total:.2f}%")

In [None]:
# Train base model
# `device` is passed as the Lightning accelerator name and is reused by the
# fine-tuning and evaluation cells below.
device = 'gpu' if torch.cuda.is_available() else 'cpu'

base_trainer = ClassifierTrainer(
    base_model,
    num_classes=num_classes,
    lr=1e-3,
    metrics=["accuracy"],
    accelerator=device
)

print("\nTraining base model...")
base_trainer.fit(base_train_loader, base_val_loader, max_epochs=50)

In [None]:
# Save base model
# Only the state dict (weights and buffers) is written; later cells rebuild
# ATCNet with the same hyper-parameters before loading it.
torch.save(base_model.state_dict(), 'ecog_atcnet_base_model.pt')
print("\nBase model saved to 'ecog_atcnet_base_model.pt'")

## Phase 2: LoRA Fine-tuning

Fine-tune on Subject 1, Session 0

In [None]:
print("\n" + "="*60)
print("PHASE 2: LoRA FINE-TUNING")
print("="*60)

# Load base model
# The architecture arguments must match the base model, otherwise the saved
# state dict does not fit.
lora_model = ATCNet(
    in_channels=1,
    num_classes=num_classes,
    num_windows=3,
    num_electrodes=num_channels,
    chunk_size=chunk_size,
    F1=16,
    D=2
)

lora_model.load_state_dict(torch.load('ecog_atcnet_base_model.pt'))
print("\nLoaded pre-trained base model")

# Add LoRA adapters
# All existing weights are frozen; only the adapters on modules whose name
# contains 'dense' stay trainable.
lora_rank = 8
lora_alpha = 16
print(f"\nAdding LoRA adapters (rank={lora_rank}, alpha={lora_alpha})...")
lora_model = add_lora_to_model(lora_model, rank=lora_rank, alpha=lora_alpha, target_modules=['dense'])

# Report how many parameters remain trainable after freezing.
trainable, total = count_parameters(lora_model)
print(f"\nLoRA model parameters:")
print(f"  Trainable: {trainable:,}")
print(f"  Total: {total:,}")
print(f"  Percentage trainable: {100*trainable/total:.2f}%")
print(f"  Parameter reduction: {100*(1-trainable/total):.2f}%")

In [None]:
# Fine-tuning data: Subject 1, Session 0
lora_train_files = [
    f"{DATA_PREFIX}Subject 1/0/0-raw.fif",
]

# Validation: Subject 1, Session 1
# Same recording that served as the validation set during base training.
lora_val_files = [
    f"{DATA_PREFIX}Subject 1/1/1-raw.fif",
]

# Preprocessing settings are identical to the base-training datasets.
lora_train_dataset = MultiSubjectECoGDataset(
    file_list=lora_train_files,
    freq_band=(2, 40),
    time_range=(0.0, 4.0),
    remove_bad=True,
    scale=True
)

lora_val_dataset = MultiSubjectECoGDataset(
    file_list=lora_val_files,
    freq_band=(2, 40),
    time_range=(0.0, 4.0),
    remove_bad=True,
    scale=True
)

lora_train_loader = DataLoader(lora_train_dataset, batch_size=32, shuffle=True)
lora_val_loader = DataLoader(lora_val_dataset, batch_size=32, shuffle=False)

In [None]:
# Fine-tune with LoRA
# Only parameters with requires_grad=True (the LoRA adapters) are handed to
# the optimizer by ClassifierTrainer.configure_optimizers.
lora_trainer = ClassifierTrainer(
    lora_model,
    num_classes=num_classes,
    lr=5e-4,  # Half the 1e-3 learning rate used for base training
    metrics=["accuracy"],
    accelerator=device
)

print("\nFine-tuning with LoRA on Subject 1...")
lora_trainer.fit(lora_train_loader, lora_val_loader, max_epochs=30)

In [None]:
# Save LoRA model
# The state dict holds the frozen base weights together with the adapter
# weights, so loading it requires a model that was wrapped with LoRA first.
torch.save(lora_model.state_dict(), 'ecog_atcnet_lora_subject1.pt')
print("\nLoRA fine-tuned model saved to 'ecog_atcnet_lora_subject1.pt'")

## Evaluation & Comparison

In [None]:
print("\n" + "="*60)
print("EVALUATION")
print("="*60)

# Test base model (without fine-tuning)
# A fresh ATCNet is built and the saved base weights are loaded into it, so
# the comparison does not depend on the in-memory training objects.
print("\n[Base Model - No Fine-tuning on Subject 1]")
base_test_model = ATCNet(
    in_channels=1,
    num_classes=num_classes,
    num_windows=3,
    num_electrodes=num_channels,
    chunk_size=chunk_size,
    F1=16,
    D=2
)
base_test_model.load_state_dict(torch.load('ecog_atcnet_base_model.pt'))

base_test_trainer = ClassifierTrainer(
    base_test_model,
    num_classes=num_classes,
    metrics=["accuracy"],
    accelerator=device
)
# NOTE: lora_val_loader (Subject 1, Session 1) was also the validation set
# during both training phases, so it is not an untouched hold-out set.
base_results = base_test_trainer.test(lora_val_loader)

# Test LoRA fine-tuned model
print("\n[LoRA Fine-tuned Model]")
lora_results = lora_trainer.test(lora_val_loader)

## Visualize Results

In [None]:
# Parameter comparison visualization
# Left panel: frozen vs. trainable parameters of the base and LoRA models.
# Right panel: LoRA's trainable-parameter count relative to full fine-tuning.
base_trainable, base_total = count_parameters(base_test_model)
lora_trainable, lora_total = count_parameters(lora_model)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Parameter comparison
models = ['Base Model\n(Full Fine-tuning)', 'LoRA Model\n(Adapter Fine-tuning)']
trainable_params = [base_trainable, lora_trainable]
frozen_params = [base_total - base_trainable, lora_total - lora_trainable]

x = range(len(models))
width = 0.5

ax1.bar(x, frozen_params, width, label='Frozen', color='lightgray')
ax1.bar(x, trainable_params, width, bottom=frozen_params, label='Trainable', color='cornflowerblue')
ax1.set_ylabel('Number of Parameters', fontsize=12)
ax1.set_title('Parameter Comparison', fontsize=14, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(models)
ax1.legend(fontsize=10)
ax1.grid(axis='y', alpha=0.3)

# Annotate each bar with its trainable share. Loop-local names are used so
# the notebook-level `trainable` / `total` variables are not overwritten.
for i, (n_train, n_total) in enumerate([(base_trainable, base_total), (lora_trainable, lora_total)]):
    share = 100 * n_train / n_total if n_total else 0.0
    ax1.text(i, n_total + max(trainable_params)*0.02, f'{share:.1f}%\ntrainable',
             ha='center', va='bottom', fontsize=11, fontweight='bold')

# Training efficiency
# `reduction` is a percentage; `param_ratio` is the matching "N times fewer"
# factor. Both are guarded against empty models.
reduction = 100 * (1 - lora_trainable / base_trainable) if base_trainable else 0.0
param_ratio = base_trainable / lora_trainable if lora_trainable else float('inf')
bars = ax2.bar(['Full Fine-tuning', 'LoRA Fine-tuning'], [100, 100 - reduction],
               color=['coral', 'seagreen'], width=0.6)
ax2.set_ylabel('Relative Training Cost (%)', fontsize=12)
ax2.set_title('Training Efficiency Gain', fontsize=14, fontweight='bold')
ax2.set_ylim([0, 120])
ax2.grid(axis='y', alpha=0.3)

ax2.text(0, 105, '100%', ha='center', fontsize=12, fontweight='bold')
ax2.text(1, 100 - reduction + 5, f'{100-reduction:.1f}%', ha='center', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('ecog_lora_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n{'='*60}")
print("SUMMARY")
print(f"{'='*60}")
print(f"LoRA reduces trainable parameters by {reduction:.2f}%")
print(f"Base model: {base_trainable:,} trainable params")
print(f"LoRA model: {lora_trainable:,} trainable params")
print(f"\nMemory saving: ~{reduction:.1f}% during fine-tuning")
# The previous message printed reduction/100 (a fraction <= 1) as a speed-up
# factor; report the actual trainable-parameter ratio instead.
print(f"Trainable parameters: ~{param_ratio:.1f}x fewer than full fine-tuning (fewer gradients to compute)")

## Summary

This notebook demonstrated:

1. **ECoG/EEG Data Loading**: Used MNE to load .fif files with proper preprocessing
2. **Base Training**: Trained ATCNet on multiple subjects (Subject 2-N)
3. **LoRA Adaptation**: Added low-rank adapters to classification layers
4. **Subject-Specific Fine-tuning**: Fine-tuned only LoRA parameters on Subject 1
5. **Evaluation**: Compared base vs. LoRA fine-tuned performance

**Key Benefits for ECoG/EEG:**
- Efficient subject-specific adaptation
- Preserves general features learned from multiple subjects
- Large reduction in trainable parameters during fine-tuning (the exact percentage is printed in the summary cell above)
- Can store multiple subject-specific adapters with minimal storage
- Useful for clinical applications with limited subject data

**Next Steps:**
- Fine-tune LoRA adapters for each subject separately
- Experiment with different frequency bands
- Target different layers (TCN, attention) with LoRA
- Compare with full fine-tuning on more subjects