In [1]:
%load_ext autoreload
%autoreload 2

import random
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
import optuna
from modules.utils import evaluate_model
import matplotlib.pyplot as plt

# dataset related
from modules import CompetitionDataset, load_combined_moabb_data
from torch.utils.data import DataLoader, TensorDataset
from moabb.datasets import BNCI2014_001, PhysionetMI, Cho2017, Weibo2014 # 250 hz

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cpu')

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')
# data_path = '/content/drive/MyDrive/ai_data/eeg_detection/data/mtcaic3'
# model_path = '/content/drive/MyDrive/ai_data/eeg_detection/checkpoints/ssvep/models/ssvep.pth'
# optuna_db_path = '/content/drive/MyDrive/ai_data/eeg_detection/checkpoints/ssvep/optuna/optuna_studies.db'
data_path = './data/mtcaic3'
model_path = './checkpoints/mi/models/the_honored_one.pth'
optuna_db_path = './checkpoints/mi/optuna/the_honored_one.db'

In [3]:
batch_size = 64
# Add this at the beginning of your notebook, after imports
def set_random_seeds(seed=42):
    """Set random seeds for reproducibility"""

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Call this function before creating datasets and models
set_random_seeds(42)

In [4]:
moabb_train_datasets = [
    PhysionetMI(imagined=True),  # 109 subjects
    Weibo2014(),  # 10 subjects, 64 channels
    CompetitionDataset(),
]
train_val = [CompetitionDataset(split="validation")]

train_datasets = [
    PhysionetMI(imagined=True),  # 109 subjects
    Weibo2014(),  # 10 subjects, 64 channels
    BNCI2014_001(),  # 9 subjects
    CompetitionDataset(),
]
train_val = [CompetitionDataset(split="validation")]

eeg_channels = ["Fz", "C3", "Cz", "C4", "Pz"]
X_train, class_labels_train, domain_labels_train, info_train = load_combined_moabb_data(
    datasets=moabb_train_datasets,
    paradigm_config={
        "channels": eeg_channels,
        "tmin": 1.0,
        "tmax": 4.0,
        "resample": 250,
    },
    subjects_per_dataset={
        "PhysionetMI": list(range(1, 21)),
        "Weibo2014": list(range(1, 11)),
        "CompetitionDataset": list(range(1, 21)),
    },
)


# Load combined data
X_val, class_labels_val, domain_labels_val, info_val = load_combined_moabb_data(
    datasets=train_val,
    paradigm_config={
        "channels": eeg_channels,
        "tmin": 1.0,
        "tmax": 4.0,
        "resample": 250,
    },
)

y_train = np.column_stack([class_labels_train, domain_labels_train])
y_val = np.column_stack([class_labels_val, domain_labels_val])

print(X_train.shape, X_val.shape, y_train.shape, y_val.shape)

X_train_t = torch.from_numpy(X_train).float()  # FloatTensor of shape (N, C, T)
y_train_t = torch.from_numpy(y_train).long()  # LongTensor of shape (N, 2)
y_train_t[:, 1] -= 1

train_dataset = TensorDataset(X_train_t, y_train_t)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

X_val_t = torch.from_numpy(X_val).float()
y_val_t = torch.from_numpy(y_val).long()
y_val_t[:, 1] -= 1

val_dataset = TensorDataset(X_val_t, y_val_t)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


Processing dataset: PhysionetMI
Original subject range: 1 to 20
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']


 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")


Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']
Used Annotations descriptions: ['feet', 'hands', 'rest']


 'left_hand': 7
 'right_hand': 8>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")


Adjusted subject range: (1, 20)
Number of trials: 900
Number of subjects: 20

Processing dataset: Weibo2014
Original subject range: 1 to 10


Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 80
 'right_hand': 80>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 80
 'right_hand': 80>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 80
 'right_hand': 80>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 80
 'right_hand': 80>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 80
 'right_hand': 80>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 70
 'right_hand': 70>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and concatenated with a buffer to create cont data
 'left_hand': 80
 'right_hand': 80>
  warn(f"warnEpochs {epochs}")
Trial data de-meaned and co

Adjusted subject range: (21, 30)
Number of trials: 1580
Number of subjects: 10

Processing dataset: CompetitionDataset
Original subject range: 1 to 20


No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
No stim channel nor annotations found, skipping setting annotations.
 'left_hand': 6
 'right_hand': 4>
  warn(f"warnEpochs {epochs}")
 'left_hand': 6
 'right_hand': 4>
  warn(f"warnEpochs {epochs}")
 'left_hand': 5
 'right_hand': 5>
  warn(f"warnEpochs {epochs}")
 'left_hand': 6
 'right_hand': 4>
  warn(f"warnEpochs {epochs}")
 'left_hand': 5
 'right_hand': 5>
  warn(f"warnEpochs {epochs}")
 'left_hand': 3
 'right_hand': 7>
  warn(f"warnEpochs {epochs}")
 'left_hand': 8
 'right_hand': 2>
  warn(f"warnEpochs {epo

Adjusted subject range: (31, 50)
Number of trials: 1600
Number of subjects: 20

=== COMBINED DATASET SUMMARY ===
Total trials: 4080
Feature shape: (4080, 5, 750)
Class distribution: [2039 2041]
Subject range: 1 to 50
Total unique subjects: 50

Processing dataset: CompetitionDataset
Original subject range: 1 to 30


 'left_hand': 5
 'right_hand': 5>
  warn(f"warnEpochs {epochs}")
No stim channel nor annotations found, skipping setting annotations.
 'left_hand': 6
 'right_hand': 4>
  warn(f"warnEpochs {epochs}")
No stim channel nor annotations found, skipping setting annotations.
 'left_hand': 4
 'right_hand': 6>
  warn(f"warnEpochs {epochs}")
No stim channel nor annotations found, skipping setting annotations.
 'left_hand': 6
 'right_hand': 4>
  warn(f"warnEpochs {epochs}")
























































































































































Adjusted subject range: (1, 5)
Number of trials: 50
Number of subjects: 5

=== COMBINED DATASET SUMMARY ===
Total trials: 50
Feature shape: (50, 5, 750)
Class distribution: [28 22]
Subject range: 1 to 5
Total unique subjects: 5
(4080, 5, 750) (50, 5, 750) (4080, 2) (50, 2)


In [5]:
# ---------------- Gradient Reversal Layer ---------------- #
class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


class GradientReversal(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.alpha)


# ---------------- LSTM Head ---------------- #
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h0=None, c0=None):
        # x: B x seq_len x feat_dim
        if h0 is None or c0 is None:
            h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
            c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        last_time_step = out[:, -1, :]
        return self.fc(last_time_step), last_time_step

# ---------------- EEG Feature Extractor ---------------- #
class EEGFeatureExtractor(nn.Module):
    def __init__(self, n_electrodes, kernLength, F1, D, F2, dropout):
        super().__init__()
        # For input B x C x T, apply 2D convs on [C x T]
        self.block = nn.Sequential(
            # Temporal conv across time
            nn.Conv2d(1, F1, (1, kernLength), padding=(0, kernLength // 2), bias=False),
            nn.BatchNorm2d(F1),
            # Depthwise spatial conv across electrodes
            nn.Conv2d(F1, F1 * D, (n_electrodes, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1 * D),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(dropout),
            # Separable conv
            nn.Conv2d(F1 * D, F1 * D, (1, 16), padding=(0, 8), groups=F1 * D, bias=False),
            nn.Conv2d(F1 * D, F2, 1, bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: B x C x T
        x = x.unsqueeze(1)  # B x 1 x C x T
        x = self.block(x)  # B x F2 x 1 x T_sub
        x = x.squeeze(2)  # B x F2 x T_sub
        x = x.permute(0, 2, 1)  # B x T_sub x F2
        return x

# ---------------- DANN SSVEP Classifier ---------------- #
class DANN_SSVEPClassifier(nn.Module):
    def __init__(
        self,
        n_electrodes=16,
        out_dim=4,
        dropout=0.25,
        kernLength=256,
        F1=96,
        D=1,
        hidden_dim=256,
        layer_dim=1,
        grl_alpha=0.0,
        domain_lstm_div=3,
        domain_classes=30,
    ):
        super().__init__()
        self.grl_alpha = grl_alpha
        F2 = F1 * D
        self.feature_extractor = EEGFeatureExtractor(n_electrodes, kernLength, F1, D, F2, dropout)

        self.label_lstm = LSTMModel(F2, hidden_dim, layer_dim, out_dim)

        self.grl_layer = GradientReversal(alpha=grl_alpha)
        self.domain_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim  // 2, domain_classes)
        )

    def forward(self, x):
        # x: B×C×F×T  (from your CWT)
        seq = self.feature_extractor(x)  # B×T_sub×feat_dim
        class_out, last_time_step = self.label_lstm(seq)  # B×out_dim, Bxhidden_dim

        domain_out = self.domain_classifier(self.grl_layer(last_time_step))
        return class_out, domain_out
        
    def set_grl_alpha(self, new_grl):
        self.grl_layer = GradientReversal(alpha=new_grl)
        

n_electrodes = 5
dummy_x = torch.randn(5, n_electrodes, 10001).to(device)
model = DANN_SSVEPClassifier(
    dropout=0.26211635308091535,
    n_electrodes=n_electrodes,
    out_dim=2,
    domain_classes=50,
    kernLength=8,
    F1=8,
    D=2,
    hidden_dim=256,
    layer_dim=2,
    grl_alpha=0,
    domain_lstm_div=2,
).to(device)

print("worked dude")
model(dummy_x)

worked dude


(tensor([[-0.0085,  0.0561],
         [-0.0099,  0.0590],
         [-0.0142,  0.0615],
         [-0.0067,  0.0608],
         [-0.0022,  0.0637]], grad_fn=<AddmmBackward0>),
 tensor([[-2.2459e-02,  2.3087e-02, -2.4518e-03, -7.6806e-02,  8.6641e-02,
           4.4249e-02, -7.9637e-02, -2.9990e-02, -5.1053e-02,  3.6215e-02,
          -8.1206e-02, -9.2535e-02, -2.2008e-02, -6.0937e-02, -2.5211e-02,
           7.1984e-04, -3.7208e-02,  2.5291e-03, -6.2704e-02,  1.1368e-01,
          -6.9377e-02, -2.5615e-02,  8.4006e-02,  1.7576e-02,  2.6153e-02,
          -2.5731e-02,  2.6348e-02, -6.0475e-02, -2.4734e-02,  3.4250e-02,
          -7.9001e-02,  3.7512e-02, -1.1996e-02,  2.7506e-02, -2.4240e-02,
           8.1834e-02, -2.8964e-02,  3.0461e-02,  3.2779e-02, -3.3053e-02,
          -6.6285e-03, -6.4261e-02, -3.8096e-02, -7.0476e-03,  5.3363e-02,
           8.2495e-02, -3.7276e-02,  5.3448e-02, -5.5734e-02,  7.9097e-02],
         [-2.1824e-02,  2.3650e-02, -5.3392e-04, -7.8088e-02,  8.6703e-02,
 

In [6]:
avg_losses_label = []
avg_losses_domain = []
val_label_accuracies = []
val_domain_accuracies = []
train_label_accuracies = []
train_domain_accuracies = []

In [None]:
try:
    model.load_state_dict(torch.load(model_path, weights_only=True))
except Exception:
    print("skipping model loading...")


opt = torch.optim.Adam(model.parameters(), lr=0.0003746351873334935)
criterion = nn.CrossEntropyLoss()
epochs = 300
domain_loss_weight = 0.5
i = 0

for epoch in range(epochs):
    avg_loss_label = 0
    avg_loss_domain = 0
    correct_label = 0
    correct_domain = 0
    total = 0
    
    new_grl_alpha = 2.0 / (1.0 + np.exp(-10 * (epoch/epochs))) - 1.0
    model.set_grl_alpha(new_grl_alpha * 0)
    model.train()

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device).to(torch.int64) # shape: [Bx2], 0: label, 1: domain
        y_labels = y[:, 0]
        y_subj = y[:, 1]

        y_pred_labels, y_pred_domain = model(x)

        loss_label = criterion(y_pred_labels, y_labels)
        loss_domain = criterion(y_pred_domain, y_subj)
        loss = loss_label + domain_loss_weight * loss_domain
        
        opt.zero_grad()
        loss.backward()
        opt.step()

        avg_loss_label += loss.item()
        avg_loss_domain += loss_domain.item()

        # Accuracy calculation
        _, pred_labels = torch.max(y_pred_labels, 1)
        _, pred_domains = torch.max(y_pred_domain, 1)
        correct_label += (pred_labels == y_labels).sum().item()
        correct_domain += (pred_domains == y_subj).sum().item()
        total += y_labels.size(0)

        avg_loss_label /= len(train_loader)
        avg_loss_domain /= len(train_loader)
        avg_losses_label.append(avg_loss_label)
        avg_losses_domain.append(avg_loss_domain)
        train_label_acc = 100.0 * correct_label / total
        train_domain_acc = 100.0 * correct_domain / total
        train_label_accuracies.append(train_label_acc)
        train_domain_accuracies.append(train_domain_acc)

        i += 1
        if i % 20 == 0:
            i = 0
            label_evaluation, domain_evaluation = evaluate_model(model, val_loader, device)
            val_label_accuracies.append(label_evaluation)
            val_domain_accuracies.append(domain_evaluation)
            model.cpu()
            torch.save(model.state_dict(), model_path)
            model.to(device)
            print(
                f"Epoch {epoch+1:2d}/{epochs} | "
                f"Label Loss: {avg_loss_label:.4f} | "
                f"Domain Loss: {avg_loss_domain:.4f} | "
                f"Train Label Acc: {train_label_acc:.2f}% | "
                f"Train Domain Acc: {train_domain_acc:.2f}% | "
                f"Val Label Acc: {label_evaluation*100:.2f}% | "
                f"Val Domain Acc: {domain_evaluation*100:.2f}% | "
                f"LR: {opt.param_groups[0]['lr']:.6f} | "
                f"GRL: {new_grl_alpha:.6f}"
            )
            avg_loss_label = 0
            avg_loss_domain = 0
            correct_label = 0
            correct_domain = 0
            total = 0


skipping model loading...
Epoch  1/300 | Label Loss: 0.0426 | Domain Loss: 0.0628 | Train Label Acc: 51.09% | Train Domain Acc: 5.08% | Val Label Acc: 44.00% | Val Domain Acc: 0.00% | LR: 0.000375 | GRL: 0.000000
Epoch  1/300 | Label Loss: 0.0422 | Domain Loss: 0.0624 | Train Label Acc: 52.89% | Train Domain Acc: 3.67% | Val Label Acc: 60.00% | Val Domain Acc: 0.00% | LR: 0.000375 | GRL: 0.000000
Epoch  1/300 | Label Loss: 0.0425 | Domain Loss: 0.0628 | Train Label Acc: 48.36% | Train Domain Acc: 3.52% | Val Label Acc: 44.00% | Val Domain Acc: 0.00% | LR: 0.000375 | GRL: 0.000000
Epoch  2/300 | Label Loss: 0.0422 | Domain Loss: 0.0619 | Train Label Acc: 47.43% | Train Domain Acc: 4.78% | Val Label Acc: 56.00% | Val Domain Acc: 0.00% | LR: 0.000375 | GRL: 0.016665


In [1]:
epochs = range(1, len(avg_losses_label) + 1)
val_epochs = range(0, len(val_label_accuracies) * 5, 5)  # if you log every 5 epochs

plt.figure(figsize=(16, 10))

# 1. Label Loss
plt.subplot(2, 2, 1)
plt.plot(epochs, avg_losses_label, label='Train Label Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Label Loss')
plt.legend()

# 2. Domain Loss
plt.subplot(2, 2, 2)
plt.plot(epochs, avg_losses_domain, label='Train Domain Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Domain Loss')
plt.legend()

# 3. Label Accuracy
plt.subplot(2, 2, 3)
plt.plot(epochs, train_label_accuracies, label='Train Label Acc')
plt.plot(val_epochs, [v*100 for v in val_label_accuracies], label='Val Label Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Label Accuracy')
plt.legend()

# 4. Domain Accuracy
plt.subplot(2, 2, 4)
plt.plot(epochs, train_domain_accuracies, label='Train Domain Acc')
plt.plot(val_epochs, [v*100 for v in val_domain_accuracies], label='Val Domain Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Domain Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

NameError: name 'avg_losses_label' is not defined

In [None]:
class CustomTrainer(Trainer):
    # Called by _objective during an Optuna trial
    def prepare_trial_run(self):
        assert isinstance(self.trial, optuna.Trial), "Trial not set!"

        # 1) Hyperparameter search space
        # Data params
        window_length = self.trial.suggest_categorical("window_length", [128, 256, 640])
        batch_size    = self.trial.suggest_categorical("batch_size", [32, 64])

        # Model extractor params (based on EEG3D+MDD)
        kernLength = self.trial.suggest_categorical("kernLength", [8, 16, 32, 64, 128])
        F1         = self.trial.suggest_categorical("F1", [8, 16, 32, 64])
        D          = self.trial.suggest_categorical("D", [1, 2, 4])
        F2         = self.trial.suggest_categorical("F2", [16, 32, 64, 128])
        dropout    = self.trial.suggest_float("dropout", 0.1, 0.5)
        # MDD head params
        hidden_dim = self.trial.suggest_categorical("hidden_dim", [64, 128, 256])
        layer_dim  = self.trial.suggest_int("layer_dim", 1, 3)
        # MDD alignment weight
        lambda_mdd = self.trial.suggest_float("lambda_mdd", 0.1, 1.0)

        # Optimizer
        lr = self.trial.suggest_float("lr", 1e-5, 1e-3, log=True)

        # 2) Prepare data
        super()._prepare_data(is_trial=True,
                              batch_size=batch_size,
                              window_length=window_length)

        # 3) Build MDD model
        extractor_kwargs = dict(
            n_electrodes=self.data.num_channels,
            kernLength=kernLength,
            F1=F1,
            D=D,
            F2=F2,
            dropout=dropout,
        )
        lstm_kwargs = dict(
            input_dim=F2 * (self.data.freq_bins // D),  # adjust if freq_bins variable
            hidden_dim=hidden_dim,
            layer_dim=layer_dim,
            output_dim=self.data.num_classes,
        )
        self.model = MDD_SSVEPClassifier(
            extractor_kwargs=extractor_kwargs,
            lstm_kwargs=lstm_kwargs
        ).to(self.device)
        self.lambda_mdd = lambda_mdd

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    # Called by train() for final run
    def prepare_final_run(self):
        study      = self._get_study()
        best       = study.best_params
        # Data
        super()._prepare_data(is_trial=False)
        # Build final model
        extractor_kwargs = dict(
            n_electrodes=self.data.num_channels,
            kernLength=best["kernLength"],
            F1=best["F1"],
            D=best["D"],
            F2=best["F2"],
            dropout=best["dropout"],
        )
        lstm_kwargs = dict(
            input_dim=best["F2"] * (self.data.freq_bins // best["D"]),
            hidden_dim=best["hidden_dim"],
            layer_dim=best["layer_dim"],
            output_dim=self.data.num_classes,
        )
        self.model = MDD_SSVEPClassifier(
            extractor_kwargs=extractor_kwargs,
            lstm_kwargs=lstm_kwargs
        ).to(self.device)
        # load weights?
        try:
            self.model.load_state_dict(torch.load(self.model_path))
            print(f"Loaded weights from {self.model_path}")
        except:
            print("No checkpoint found, training from scratch.")

        lr = best["lr"]
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max', factor=0.5, patience=20,
            threshold=1e-4, threshold_mode='rel',
            cooldown=0, min_lr=1e-6
        )

trainer = CustomTrainer(
        data_path=data_path,
        optuna_db_path=optuna_db_path,
        model_path=model_path,
        train_epochs=500, # Final training epochs
        tune_epochs=50,   # Epochs per trial
        optuna_n_trials=50,
        task="mi",
        eeg_channels=eeg_channels,
        data_fraction=0.4
    )

In [None]:
delete_existing = False
trainer.optimize(delete_existing)

In [None]:
trainer.train()

In [None]:
trainer._prepare_training(False)
trainer.model.eval()
f"test accuracy: {evaluate_model(trainer.model, trainer.eval_loader, device)}"