In [11]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.functional as F
from torch.autograd import Function
import optuna
from modules import Trainer
from modules.competition_dataset import EEGDataset
from modules.utils import evaluate_model
import random
import numpy as np
from torch.utils.data import ConcatDataset, random_split, DataLoader

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device  # last expression of the cell: displays the selected device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cuda')

In [12]:
# --- Google Colab variant (kept for reference; uncomment to use) ---
# from google.colab import drive
# drive.mount('/content/drive')
# data_path = '/content/drive/MyDrive/ai_data/eeg_detection/data/mtcaic3'
# model_path = '/content/drive/MyDrive/ai_data/eeg_detection/checkpoints/ssvep/models/ssvep.pth'
# optuna_db_path = '/content/drive/MyDrive/ai_data/eeg_detection/checkpoints/ssvep/optuna/optuna_studies.db'

# --- Local paths used by the rest of the notebook ---
data_path = "./data/mtcaic3"
model_path = "./checkpoints/mi/models/the_honored_one.pth"
optuna_db_path = "./checkpoints/mi/optuna/the_honored_one.db"

# EEG channels passed to the datasets and the trainer.
# NOTE(review): the number after each channel is presumably a per-channel
# relevance score used to rank them -- TODO confirm where it came from.
eeg_channels = [
    "C3",  # 2296.15
    "PZ",  # 1744.43
    "C4",  # 1556.46
    # Currently disabled channels:
    # "OZ"   # 444.98
    # "PO7"  # 381.63
    # "PO8"  # 275.78
    # "CZ"   # 200.43
    # "FZ"   # 111.51
]

In [13]:
# Reproducibility helper; defined right after the imports/config cells.
def set_random_seeds(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, the CUDA generators are seeded too and cuDNN is
    put into deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed shared by all generators (default 42).

    Returns:
        None.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Seed once, before any dataset or model is created.
set_random_seeds(42)

In [20]:
# Windowing and batching settings shared by the train and validation datasets.
window_length = 2 * 128  # keep this divisible by 64 (the kernel size)
stride = window_length // 3
batch_size = 64

# Training dataset (no `split` argument is passed, so EEGDataset's default is used).
dataset_train = EEGDataset(
    data_path,
    task="mi",
    domain="time",
    eeg_channels=eeg_channels,
    window_length=window_length,
    stride=stride,
    data_fraction=0.3,
    hardcoded_mean=True,
)

# Validation dataset: labelled, all of the data.
dataset_val = EEGDataset(
    data_path=data_path,
    split='validation',
    task="mi",
    eeg_channels=eeg_channels,
    window_length=window_length,
    stride=stride,
    data_fraction=1,
    hardcoded_mean=True,
    read_labels=True,
)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = DataLoader(dataset_train, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset_val, batch_size=batch_size)

task: mi, split: train, domain: time, data_fraction: 0.3
Using 30.0% of data: 720/720 samples
skipped: 0/720
data shape: (12662, 3, 256), mean shape: (1, 3, 1)
task: mi, split: validation, domain: time, data_fraction: 1
skipped: 0/50
data shape: (857, 3, 256), mean shape: (1, 3, 1)


In [21]:
# ---------------- Gradient Reversal Layer ---------------- #
class GradientReversalFunction(Function):
    """Autograd function that is the identity going forward and flips the
    sign of the gradient (scaled by ``alpha``) going backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        """Return a view of ``x`` with unchanged values; stash ``alpha`` for backward."""
        ctx.alpha = alpha
        return x.view(x.shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and ``None`` for ``alpha``."""
        reversed_grad = -grad_output * ctx.alpha
        return reversed_grad, None

class GradientReversal(nn.Module):
    """``nn.Module`` wrapper that applies ``GradientReversalFunction`` with a
    fixed ``alpha`` chosen at construction time."""

    def __init__(self, alpha=1.0):
        super().__init__()
        # Scale applied to the reversed gradient in the backward pass.
        self.alpha = alpha

    def forward(self, x):
        """Pass ``x`` through unchanged; gradients flowing back are reversed."""
        reverse = GradientReversalFunction.apply
        return reverse(x, self.alpha)

# ---------------- LSTM Head ---------------- #
class LSTMModel(nn.Module):
    """LSTM followed by a linear layer applied to the last time step.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Size of the output produced by the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h0=None, c0=None):
        """Run the LSTM over ``x`` and project the last time step.

        Args:
            x: Input of shape ``B x seq_len x feat_dim``.
            h0: Optional initial hidden state of shape
                ``layer_dim x B x hidden_dim``; zeros when omitted.
            c0: Optional initial cell state of the same shape; zeros when
                omitted.

        Returns:
            Tensor of shape ``B x output_dim``.
        """
        state_shape = (self.layer_dim, x.size(0), self.hidden_dim)
        # `new_zeros` inherits both device AND dtype from `x`, so the default
        # state also works after `.double()` / `.half()`. Each state is filled
        # in independently so a single caller-supplied state is not discarded.
        if h0 is None:
            h0 = x.new_zeros(state_shape)
        if c0 is None:
            c0 = x.new_zeros(state_shape)

        out, _ = self.lstm(x, (h0, c0))
        # Only the output at the final time step feeds the linear layer.
        return self.fc(out[:, -1, :])

# ---------------- EEG Feature Extractor ---------------- #
class EEGFeatureExtractor(nn.Module):
    """Convolutional front end that turns ``B x C x T`` EEG windows into a
    feature sequence of shape ``B x T_sub x F2``.

    Args:
        n_electrodes: Number of input channels ``C`` (height of the spatial kernel).
        kernLength: Width of the first temporal convolution kernel.
        F1: Number of temporal filters.
        D: Depth multiplier of the depthwise spatial convolution.
        F2: Number of output feature maps.
        dropout: Dropout probability used after each pooling stage.
    """

    def __init__(self, n_electrodes, kernLength, F1, D, F2, dropout):
        super().__init__()
        depth_channels = F1 * D

        # Layers are created in the same order (and end up at the same
        # Sequential indices) as before, so state_dict keys are unchanged.
        temporal_spatial = [
            # Temporal conv across time
            nn.Conv2d(1, F1, (1, kernLength), padding=(0, kernLength // 2), bias=False),
            nn.BatchNorm2d(F1),
            # Depthwise spatial conv across electrodes
            nn.Conv2d(F1, depth_channels, (n_electrodes, 1), groups=F1, bias=False),
            nn.BatchNorm2d(depth_channels),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(dropout),
        ]
        separable = [
            # Separable conv: depthwise over time, then 1x1 pointwise
            nn.Conv2d(depth_channels, depth_channels, (1, 16), padding=(0, 8),
                      groups=depth_channels, bias=False),
            nn.Conv2d(depth_channels, F2, 1, bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(dropout),
        ]
        # For input B x C x T, the 2D convs operate on the [C x T] plane.
        self.block = nn.Sequential(*temporal_spatial, *separable)

    def forward(self, x):
        """Map ``x`` (``B x C x T``) to features of shape ``B x T_sub x F2``."""
        maps = self.block(x[:, None])      # B x 1 x C x T -> B x F2 x 1 x T_sub
        return maps.squeeze(2).transpose(1, 2)  # B x F2 x T_sub -> B x T_sub x F2

# ---------------- DANN SSVEP Classifier ---------------- #
class DANN_SSVEPClassifier(nn.Module):
    """Two-headed network on top of a shared ``EEGFeatureExtractor``.

    The label head is an ``LSTMModel`` over the feature sequence. The domain
    head is an MLP fed with time-averaged features that first pass through a
    ``GradientReversal`` layer.

    Args:
        n_electrodes: Number of EEG channels in the input.
        out_dim: Number of label classes.
        dropout: Dropout probability inside the feature extractor.
        kernLength, F1, D, F2: Forwarded to ``EEGFeatureExtractor``.
        hidden_dim: Hidden size of the LSTM and of the domain MLP.
        layer_dim: Number of LSTM layers.
        grl_alpha: ``alpha`` given to the gradient reversal layer.
        domain_classes: Number of domain classes.
    """

    def __init__(self,
                 n_electrodes=16,
                 out_dim=4,
                 dropout=0.25,
                 kernLength=256,
                 F1=96,
                 D=1,
                 F2=96,
                 hidden_dim=100,
                 layer_dim=1,
                 grl_alpha=1.0,
                 domain_classes=2):
        super().__init__()
        feat_dim = F2  # the extractor emits F2 features per time step

        self.feature_extractor = EEGFeatureExtractor(
            n_electrodes=n_electrodes,
            kernLength=kernLength,
            F1=F1,
            D=D,
            F2=F2,
            dropout=dropout,
        )
        # Label predictor (LSTM over time).
        self.label_lstm = LSTMModel(
            input_dim=feat_dim,
            hidden_dim=hidden_dim,
            layer_dim=layer_dim,
            output_dim=out_dim,
        )
        # Domain predictor, placed behind the gradient reversal layer.
        self.grl = GradientReversal(alpha=grl_alpha)
        domain_layers = [
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, domain_classes),
        ]
        self.domain_fc = nn.Sequential(*domain_layers)

    def forward(self, x):
        """Return ``(class_out, domain_out)`` for an input ``x`` of shape ``B x C x T``."""
        features = self.feature_extractor(x)   # B x T_sub x feat_dim
        label_logits = self.label_lstm(features)
        # The domain head sees one vector per sample: the mean over time.
        pooled = features.mean(dim=1)          # B x feat_dim
        domain_logits = self.domain_fc(self.grl(pooled))
        return label_logits, domain_logits


# Smoke test: build the model used for training and push one random batch
# (B=5, C=3, T=256) through it to check that the shapes line up.
dummy_x = torch.randn(5, 3, 256).to(device)

model = DANN_SSVEPClassifier(
    n_electrodes=3,
    out_dim=2,
    domain_classes=30,
    kernLength=256,
    F1=32,
    D=3,
    F2=32,
    hidden_dim=256,
    layer_dim=3,
    dropout=0.26211635308091535,
)
model = model.to(device)

model(dummy_x)  # displays the (class_out, domain_out) pair

(tensor([[-0.0656, -0.0231],
         [-0.0680, -0.0245],
         [-0.0684, -0.0269],
         [-0.0686, -0.0246],
         [-0.0668, -0.0260]], device='cuda:0', grad_fn=<AddmmBackward0>),
 tensor([[-1.8820e-01,  1.5262e-01,  3.2191e-01, -5.9028e-02, -2.9754e-01,
          -2.4035e-02,  4.4117e-02,  2.7417e-01, -1.2048e-01,  2.5109e-02,
          -6.9529e-02, -3.7081e-02,  1.7412e-01,  1.2067e-01, -5.8845e-02,
           5.6947e-01,  6.5318e-02, -2.8768e-02,  4.3703e-04,  1.2804e-01,
          -4.7969e-01,  1.1280e-01, -3.5237e-02, -2.3722e-01, -2.2489e-01,
          -1.3670e-01, -2.1485e-01, -2.5705e-01,  8.1293e-02,  1.4516e-01],
         [-9.8887e-02,  8.7126e-02,  1.8121e-01, -2.8861e-01, -9.2050e-02,
          -8.3485e-02, -2.1213e-01,  1.6412e-01, -1.4923e-01, -2.7404e-01,
          -1.1194e-01, -2.2734e-01, -5.7766e-04,  2.1102e-01,  2.2843e-01,
           6.4041e-01, -2.7889e-02,  5.1864e-02,  4.2334e-01,  1.8291e-01,
           3.2912e-02, -6.5786e-02, -2.6322e-02, -1.1465e-0

In [None]:
# Training history, appended to by the training loop.
# Re-running this cell clears the recorded history.
avg_losses_label, avg_losses_domain = [], []
val_label_accuracies, val_domain_accuracies = [], []
train_label_accuracies, train_domain_accuracies = [], []

In [23]:
# Resume from an existing checkpoint when possible; otherwise keep the current
# weights. The failure reason is printed so a missing file can be told apart
# from, for example, an architecture mismatch. `map_location="cpu"` lets the
# checkpoint load regardless of the device it was saved from;
# `load_state_dict` then copies the values onto the model's own device.
try:
    model.load_state_dict(
        torch.load(model_path, map_location="cpu", weights_only=True)
    )
except Exception as exc:
    print(f"skipping model loading... ({type(exc).__name__}: {exc})")


opt = torch.optim.Adam(model.parameters(), lr=0.0003746351873334935)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode="min",  # reduce lr when the monitored quantity has stopped decreasing
    factor=0.5,  # new_lr = lr * factor
    patience=20,  # number of epochs with no improvement after which lr will be reduced
    threshold=1e-4,  # threshold for measuring the new optimum, to only focus on significant changes
    threshold_mode="rel",  # `'rel'` means compare change relative to best value. Could use `'abs'`.
    cooldown=0,  # epochs to wait before resuming normal operation after lr has been reduced
    min_lr=1e-6,  # lower bound on the lr
)

epochs = 4000

for epoch in range(epochs):
    # Per-epoch accumulators.
    avg_loss_label = 0
    avg_loss_domain = 0
    correct_label = 0
    correct_domain = 0
    total = 0
    model.train()

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device).to(torch.int64)  # shape: [B x 2]; column 0: label, column 1: domain
        y_labels = y[:, 0]
        y_subj = y[:, 1]

        y_pred_labels, y_pred_domain = model(x)

        # Both heads are trained with cross-entropy and summed with equal weight.
        loss_label = criterion(y_pred_labels, y_labels)
        loss_domain = criterion(y_pred_domain, y_subj)
        loss = loss_label + loss_domain

        opt.zero_grad()
        loss.backward()
        opt.step()

        avg_loss_label += loss_label.item()
        avg_loss_domain += loss_domain.item()

        # Accuracy bookkeeping (computed on the train-mode outputs of this batch).
        pred_labels = y_pred_labels.argmax(dim=1)
        pred_domains = y_pred_domain.argmax(dim=1)
        correct_label += (pred_labels == y_labels).sum().item()
        correct_domain += (pred_domains == y_subj).sum().item()
        total += y_labels.size(0)

    # Losses are averaged per batch; accuracies are percentages over all samples.
    avg_loss_label /= len(train_loader)
    avg_loss_domain /= len(train_loader)
    avg_losses_label.append(avg_loss_label)
    avg_losses_domain.append(avg_loss_domain)
    train_label_acc = 100.0 * correct_label / total
    train_domain_acc = 100.0 * correct_domain / total
    train_label_accuracies.append(train_label_acc)
    train_domain_accuracies.append(train_domain_acc)
    # The scheduler step is disabled, so the lr printed below stays constant.
    # scheduler.step(avg_loss_label)

    if epoch % 5 == 0:
        label_evaluation, domain_evaluation = evaluate_model(model, val_loader, device)
        val_label_accuracies.append(label_evaluation)
        val_domain_accuracies.append(domain_evaluation)
        # Save CPU copies of the weights without moving the live model (and
        # every parameter the optimizer tracks) off the training device.
        cpu_state = {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()}
        torch.save(cpu_state, model_path)
        print(f"{epoch}: avg_loss_label: {avg_loss_label:.3F}, avg_loss_domain: {avg_loss_domain:.3F}, train_label_acc: {train_label_acc:.2f}%, train_domain_acc: {train_domain_acc:.2f}%, label_evaluation: {(label_evaluation*100):.2f}, domain_evaluation: {(domain_evaluation*100):.2f}, lr: {opt.param_groups[0]['lr']}")

0: avg_loss_label: 0.687, avg_loss_domain: 3.393, train_label_acc: 53.59%, train_domain_acc: 4.26%, label_evaluation: 44.22, domain_evaluation: 0.00, lr: 0.0003746351873334935
5: avg_loss_label: 0.679, avg_loss_domain: 3.383, train_label_acc: 56.49%, train_domain_acc: 4.45%, label_evaluation: 46.44, domain_evaluation: 0.00, lr: 0.0003746351873334935
10: avg_loss_label: 0.675, avg_loss_domain: 3.383, train_label_acc: 56.87%, train_domain_acc: 4.48%, label_evaluation: 44.34, domain_evaluation: 0.00, lr: 0.0003746351873334935
15: avg_loss_label: 0.673, avg_loss_domain: 3.382, train_label_acc: 56.79%, train_domain_acc: 4.48%, label_evaluation: 44.92, domain_evaluation: 0.00, lr: 0.0003746351873334935
20: avg_loss_label: 0.672, avg_loss_domain: 3.379, train_label_acc: 57.12%, train_domain_acc: 4.98%, label_evaluation: 40.61, domain_evaluation: 0.00, lr: 0.0003746351873334935
25: avg_loss_label: 0.668, avg_loss_domain: 3.379, train_label_acc: 58.90%, train_domain_acc: 4.72%, label_evaluation

KeyboardInterrupt: 

In [None]:
class CustomTrainer(Trainer):
    """Trainer subclass that builds the model and optimizer for Optuna trials
    and for the final training run.

    NOTE(review): `SSVEPClassifier` is not defined or imported anywhere in the
    visible notebook (only `DANN_SSVEPClassifier` is). Confirm where it comes
    from before running these cells on a fresh kernel.
    """

    # This method is called by _objective during an Optuna trial
    def prepare_trial_run(self):
        """Sample hyperparameters from the current trial, prepare the data and
        build ``self.model`` / ``self.optimizer`` for that trial."""
        assert isinstance(self.trial, optuna.Trial), "Trial not set!"

        # 1. Define Hyperparameters for this trial
        #    a. Data/Loader parameters
        window_length = self.trial.suggest_categorical("window_length", [128, 256, 640]) # e.g. 64*2, 64*4, 64*10
        batch_size = self.trial.suggest_categorical("batch_size", [32, 64])

        #    b. Model architecture parameters
        kernLength = self.trial.suggest_categorical("kernLength", [64, 128, 256])
        F1 = self.trial.suggest_categorical("F1", [8, 16, 32])
        D = self.trial.suggest_categorical("D", [1, 2, 3])
        # F2 is derived rather than sampled: the model always receives F1 * D,
        # so a separate "F2" suggestion would only add an unused search dimension.
        F2 = F1 * D
        hidden_dim = self.trial.suggest_categorical("hidden_dim", [64, 128, 256])
        layer_dim = self.trial.suggest_categorical("layer_dim", [1, 2, 3])
        dropout = self.trial.suggest_float("dropout", 0.1, 0.6)

        #    c. Optimizer parameters
        lr = self.trial.suggest_float("lr", 1e-4, 1e-2, log=True)

        # 2. Prepare the data using these parameters
        super()._prepare_data(is_trial=True, batch_size=batch_size, window_length=window_length)

        assert self.dataset is not None, "Dataset was not created correctly"
        n_electrodes = self.dataset.datasets[0].data[0].shape[0] # Get shape from underlying dataset

        # 3. Build the model and optimizer
        self.model = SSVEPClassifier(
            n_electrodes=n_electrodes, # Use value from data
            out_dim=2,  # same output size as the final run, so tuning matches training
            dropout=dropout,
            kernLength=kernLength,
            F1=F1,
            D=D,
            F2=F2,
            hidden_dim=hidden_dim,
            layer_dim=layer_dim,
        ).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    # This method is called by train() for the final run
    def prepare_final_run(self):
        """Build the final model from the study's best hyperparameters, try to
        resume from ``self.model_path``, and set up optimizer and scheduler."""
        # 1. Get the best hyperparameters from the completed study
        study = self._get_study()
        best_params = study.best_params

        # 2. Prepare data using the best params
        super()._prepare_data(is_trial=False) # is_trial=False handles getting params from study

        assert self.dataset is not None, "Dataset was not created correctly"
        n_electrodes = self.dataset.datasets[0].data[0].shape[0]

        # 3. Build the final model and optimizer
        self.model = SSVEPClassifier(
            n_electrodes=n_electrodes,
            out_dim=2,
            dropout=best_params["dropout"],
            kernLength=best_params["kernLength"],
            F1=best_params["F1"],
            D=best_params["D"],
            F2=best_params["F1"] * best_params["D"],
            hidden_dim=best_params["hidden_dim"],
            layer_dim=best_params["layer_dim"],
        ).to(self.device)

        # Optional: resume from pre-existing weights. A missing file is the
        # expected "fresh start" case; any other failure (e.g. a checkpoint
        # from a different architecture) is reported with its reason instead
        # of being mislabelled as "not found".
        try:
            self.model.load_state_dict(torch.load(self.model_path, weights_only=True))
            print(f"Loaded existing model weights from {self.model_path}")
        except FileNotFoundError:
            print(f"No existing model weights found at {self.model_path}. Training from scratch.")
        except Exception as exc:
            print(
                f"Could not load model weights from {self.model_path} "
                f"({type(exc).__name__}: {exc}). Training from scratch."
            )

        # The tuned learning rate is deliberately overridden by a fixed value here.
        lr = 0.00018182233882257615 # best_params["lr"]
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',        # reduce lr when the monitored quantity has stopped increasing
            factor=0.5,        # new_lr = lr * factor
            patience=20,        # number of epochs with no improvement after which lr will be reduced
            threshold=1e-4,    # threshold for measuring the new optimum, to only focus on significant changes
            threshold_mode='rel', # `'rel'` means compare change relative to best value. Could use `'abs'`.
            cooldown=0,        # epochs to wait before resuming normal operation after lr has been reduced
            min_lr=1e-6,       # lower bound on the lr
        )

# Trainer that drives both the Optuna search and the final training run.
trainer = CustomTrainer(
    task="mi",
    data_path=data_path,
    model_path=model_path,
    optuna_db_path=optuna_db_path,
    eeg_channels=eeg_channels,
    data_fraction=0.4,
    optuna_n_trials=50,
    tune_epochs=50,    # epochs per trial
    train_epochs=500,  # final training epochs
)

In [None]:
# Start the hyperparameter search. `delete_existing` is passed straight to
# `trainer.optimize`.
# NOTE(review): presumably True discards a previously stored study -- confirm
# against Trainer.optimize before flipping it.
delete_existing = False
trainer.optimize(delete_existing)

In [None]:
# Final training run; per the comment on CustomTrainer.prepare_final_run,
# train() calls that method to build the model from the study's best params.
trainer.train()

In [None]:
# Evaluate the trainer's model in eval mode on `trainer.eval_loader`.
# NOTE(review): `_prepare_training` is a private Trainer method; it presumably
# rebuilds the model and loaders for the final (non-trial) run -- confirm.
# NOTE(review): the training cell unpacks evaluate_model's result into two
# values (label, domain), so the string below likely shows a pair rather than
# a single accuracy -- confirm.
trainer._prepare_training(False)
trainer.model.eval()
f"test accuracy: {evaluate_model(trainer.model, trainer.eval_loader, device)}"