In [1]:
import mne
import numpy as np
import pandas as pd
import seaborn as sea
import matplotlib.pyplot as plt
import os
from pathlib import Path
import random
import joblib 
import optuna

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader

from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import GroupKFold

import warnings
warnings.filterwarnings('ignore')

In [2]:
# ---- Reproducibility & training schedule ----
SEED = 3126
EPOCHS = 40                   # maximum training epochs per fold
EARLY_STOPPING_PATIENCE = 20  # epochs without validation improvement before stopping

# ---- Topographic image construction ----
TIME_WINDOW = (0.1, 0.5)      # (t0, t1), compared against epochs.times
GRID_SIZE = 32                # interpolation grid is GRID_SIZE x GRID_SIZE
N_TIME_SLICES = 16            # time samples per trial -> CNN input channels
# ---- Optimiser ----
LEARNING_RATE = 7.5e-4
WEIGHT_DECAY = 1e-3
WEIGHT_DECAY = .001

In [3]:
def set_seed(seed):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and force deterministic cuDNN."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Deterministic kernels and no autotuning -> repeatable results between runs.
    cudnn.deterministic = True
    cudnn.benchmark = False

def worker_init_fn(worker_id):
    """DataLoader hook: reseed NumPy and `random` in each worker with SEED + worker_id."""
    worker_seed = SEED + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)


set_seed(SEED)  # seed the main process once when this cell runs

In [4]:
class MEGDataset(Dataset):
    """
    MEG trials -> stacks of 2-D topographic images.

    Each trial (channels x time) is robust-scaled per channel. At
    `n_time_slices` evenly spaced samples inside `time_window`, the channel
    values are linearly interpolated onto a `grid_size` x `grid_size` grid
    spanning the sensors' (x, y) bounding box. Grid points outside the sensor
    convex hull receive that slice's channel median. ``__getitem__`` returns a
    float32 tensor (n_time_slices, grid_size, grid_size) and a long label.

    The Delaunay triangulation and the outside-hull mask are computed once in
    ``__init__``. Each item then needs a single LinearNDInterpolator
    evaluation covering all time slices.

    Args:
        X: array (trials, channels, time).
        y: integer labels, one per trial.
        epochs: object exposing ``times`` and ``info['chs'][i]['loc']`` (an MNE
            Epochs object in this notebook); only its time axis and sensor
            layout are used.
        scaler: optional (ch_median, ch_iqr) tuple, e.g. a training dataset's
            ``.scaler``; computed from X when None.
        time_window: (t0, t1) in the units of ``epochs.times``.
        grid_size: side length of the interpolation grid.
        n_time_slices: number of time samples, i.e. image channels.
    """

    def __init__(self, X, y, epochs, scaler=None,
                 time_window=TIME_WINDOW, grid_size=GRID_SIZE, n_time_slices=N_TIME_SLICES):
        self.y = y
        self.times = epochs.times
        self.info = epochs.info
        self.grid_size = grid_size
        self.n_time_slices = n_time_slices

        # Robust per-channel scaling: median / IQR pooled over trials and time.
        if scaler is None:
            ch_median = np.median(X, axis=(0, 2))
            ch_iqr = np.subtract(*np.percentile(X, [75, 25], axis=(0, 2))) + 1e-6
            scaler = (ch_median, ch_iqr)
        self.scaler = scaler
        ch_median, ch_iqr = self.scaler
        self.X = (X - ch_median[None, :, None]) / ch_iqr[None, :, None]

        # Sample indices nearest to t0 and t1; slices span [start, end - 1].
        # NOTE(review): `end - 1` excludes the sample nearest t1 -- confirm intended.
        t0, t1 = time_window
        start = np.argmin(np.abs(self.times - t0))
        end = np.argmin(np.abs(self.times - t1))
        self.slice_times = np.linspace(start, end - 1, n_time_slices).astype(int)

        # 2-D sensor layout: drop z from each channel's 3-D location.
        positions_3d = np.array([ch['loc'][:3] for ch in self.info['chs']])
        self.pos2d = positions_3d[:, :2]

        # Regular grid over the sensors' bounding box (row-major: y rows, x columns).
        gx = np.linspace(self.pos2d[:, 0].min(), self.pos2d[:, 0].max(), grid_size)
        gy = np.linspace(self.pos2d[:, 1].min(), self.pos2d[:, 1].max(), grid_size)
        grid_x, grid_y = np.meshgrid(gx, gy)
        self.grid_points = np.c_[grid_x.ravel(), grid_y.ravel()].copy(order='C')

        self.triangulation = Delaunay(self.pos2d)
        # Grid points outside the hull, decided exactly as LinearNDInterpolator decides
        # it: interpolating a constant field yields NaN there and 1 everywhere else.
        probe = LinearNDInterpolator(self.triangulation, np.ones(len(self.pos2d)), fill_value=np.nan)
        self._outside_hull = np.isnan(probe(self.grid_points))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        trial = self.X[idx]                     # (channels, time)
        label = self.y[idx]

        values = trial[:, self.slice_times]     # (channels, n_time_slices)
        # One interpolator evaluates every slice at once -> (n_grid_points, n_time_slices).
        flat = LinearNDInterpolator(self.triangulation, values, fill_value=np.nan)(self.grid_points)
        # Outside the hull, each slice is filled with its own channel median.
        slice_medians = np.median(values, axis=0)
        flat = np.where(self._outside_hull[:, None], slice_medians[None, :], flat)

        # (n_time_slices, grid_size, grid_size); NaNs (only from NaN input) become 0.
        topo_imgs = np.nan_to_num(flat.T.reshape(-1, self.grid_size, self.grid_size))
        return torch.tensor(topo_imgs, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

In [5]:
# X, y, groups, first_epochs, _ = load_all_subjects_data("/kaggle/input/the-imagine-decoding-challenge/train/train", False, "localizer")
# data = MEGDataset(X,y,first_epochs)
# sample,_ = data[0]

# for k in range(N_TIME_SLICES):
#     sea.heatmap(sample[k])
#     # sea.heatmap((sample[k]-sample[k].mean(axis=0))/(sample[k].std(axis=0) + 1e-6))
#     plt.show()

In [6]:
class Model(nn.Module):
    """
    Small CNN classifier over stacked topographic maps.

    Input: (batch, in_channels, H, W). Three Conv-BN-ReLU stages (widths h, 2h,
    3h with h = hidden_channels, 2x max-pool after the 2nd and 3rd) are followed
    by global average pooling and an MLP head (3h -> 256 -> 64 -> n_classes)
    with BatchNorm, ReLU and dropout on the hidden layers. Output: raw logits.
    """

    def __init__(self, in_channels=N_TIME_SLICES, hidden_channels=32, dropout_p=.2, n_classes=10):
        super().__init__()
        widths = (hidden_channels, 2 * hidden_channels, 3 * hidden_channels)

        def make_conv(c_in, c_out):
            # Conv bias is redundant: the following BatchNorm re-centres activations.
            return nn.Sequential(nn.Conv2d(c_in, c_out, 3, padding=1, bias=False),
                                 nn.BatchNorm2d(c_out),
                                 nn.ReLU())

        def make_dense(d_in, d_out):
            return nn.Sequential(nn.Linear(d_in, d_out, bias=False),
                                 nn.BatchNorm1d(d_out),
                                 nn.ReLU(),
                                 nn.Dropout(dropout_p))

        self.conv = nn.Sequential(
            make_conv(in_channels, widths[0]),
            make_conv(widths[0], widths[1]),
            nn.MaxPool2d(2),
            make_conv(widths[1], widths[2]),
            nn.MaxPool2d(2),
        )
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            make_dense(widths[2], 256),
            make_dense(256, 64),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        pooled = self.global_pool(self.conv(x))           # (batch, C, 1, 1)
        features = pooled.view(pooled.size(0), -1)        # (batch, C)
        return self.classifier(features)                  # (batch, n_classes)

In [7]:
def init_model_weights(model):
    """
    Re-initialise `model` in place.

    Linear / Conv2d weights get Kaiming-uniform init (fan_in, ReLU gain) and any bias is
    zeroed. BatchNorm1d / BatchNorm2d get weight = 1, bias = 0.
    """
    weighted_layers = (nn.Linear, nn.Conv2d)
    norm_layers = (nn.BatchNorm2d, nn.BatchNorm1d)
    for layer in model.modules():
        if isinstance(layer, weighted_layers):
            nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        elif isinstance(layer, norm_layers):
            nn.init.ones_(layer.weight)
            nn.init.zeros_(layer.bias)

In [8]:
def _run_epoch(model, loader, loss_fn, device, optimizer=None, lr_scheduler=None):
    """
    Make one pass over `loader` and return (mean per-sample loss, accuracy).

    When `optimizer` is given this trains: model.train(), backward, optimizer step, and
    a per-batch scheduler step. Otherwise it evaluates in eval mode under
    torch.inference_mode().
    Assumes `loss_fn` returns the batch mean (the nn.CrossEntropyLoss default).

    Raises:
        ValueError: if the loader yields no samples (averages are undefined).
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.inference_mode(mode=not training):
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = loss_fn(logits, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()
            batch_size = y.size(0)
            # Weight the batch-mean loss by batch size so a short final batch
            # does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError("loader yielded no samples; cannot compute loss/accuracy")
    return loss_sum / n_seen, n_correct / n_seen


def train_evaluate(model, train_loader, val_loader, save_path, lr, weight_decay,
                   epochs=EPOCHS, early_stopping_patience=EARLY_STOPPING_PATIENCE, device="cuda"):
    """
    Train `model` with AdamW, validating after every epoch and checkpointing the best weights.

    The learning rate follows CosineAnnealingWarmRestarts stepped once per batch with
    T_0 = len(train_loader), i.e. one cosine cycle per epoch.

    Checkpointing: the state dict is written to `save_path` whenever validation accuracy
    beats the best so far by at least 0.002. The first epoch always counts as an
    improvement, so after >= 1 epoch `save_path` holds weights from this run.

    Early stopping: training stops after `early_stopping_patience` consecutive epochs
    without such an improvement.

    Args:
        model: nn.Module; moved to `device`.
        train_loader, val_loader: iterables of (inputs, integer labels) batches.
        save_path: where the best state dict is written with torch.save.
        lr, weight_decay: AdamW hyper-parameters.
        epochs: maximum number of epochs.
        early_stopping_patience: allowed epochs without improvement.
        device: torch device string.

    Returns:
        dict of per-epoch lists "train_losses", "val_losses", "train_accs", "val_accs".
        Losses are per-sample means over each epoch.
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader))
    history = {"train_losses": [], "val_losses": [], "train_accs": [], "val_accs": []}

    min_change = .002
    # Start at -inf (not 0.0) so epoch 1 always writes a checkpoint. Otherwise a run that
    # never reaches `min_change` accuracy leaves no file -- or a stale one -- at save_path.
    best_val_acc = float("-inf")
    epochs_not_improved = 0

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, loss_fn, device,
                                           optimizer=optimizer, lr_scheduler=lr_scheduler)
        val_loss, val_acc = _run_epoch(model, val_loader, loss_fn, device)

        history["train_losses"].append(train_loss)
        history["val_losses"].append(val_loss)
        history["train_accs"].append(train_acc)
        history["val_accs"].append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
              f"Train Acc: {train_acc:.3f} | Val Acc: {val_acc:.3f}")

        # ----------------- Early Stopping -----------------
        if val_acc - best_val_acc >= min_change:
            best_val_acc = val_acc
            epochs_not_improved = 0
            torch.save(model.state_dict(), save_path)
        else:
            epochs_not_improved += 1
            print(f"Early Stopping Counter: {epochs_not_improved}/{early_stopping_patience}")
        if epochs_not_improved >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    return history


def predict(model, loader, device="cuda"):
    """
    Return argmax class predictions for every sample in `loader`.

    The model is put in eval mode and is assumed to already live on `device`; only the
    inputs are moved. Labels yielded by the loader are ignored.

    Returns:
        1-D numpy int64 array, one prediction per sample in loader order. An empty
        loader yields an empty array instead of raising.
    """
    model.eval()
    batch_preds = []
    with torch.inference_mode():
        for X, _ in loader:
            batch_preds.append(model(X.to(device)).argmax(dim=1).cpu())
    if not batch_preds:
        return np.empty(0, dtype=np.int64)
    return torch.cat(batch_preds).numpy()

In [9]:
def load_subject_data(data_path, subject_id, need_label_map=True, data_type='localizer'):
    """
    Read one subject's epochs file (fully preloaded, MNE logging silenced).

    The file read is ``<data_path>/<subject_id>/<subject_id>_<data_type>-epo.fif``.

    Returns:
        X: ndarray (M_trials, C_channels, T_timepoints) from ``epochs.get_data()``
        y: labels (M_trials,) = event codes ``epochs.events[:, 2]`` minus one
           (codes 1..10 become 0..9)
        epochs: the loaded MNE epochs object
        label_map: dict event name -> shifted code, or None if need_label_map is False
    """
    subject_dir = Path(data_path) / subject_id
    epochs = mne.read_epochs(subject_dir / f"{subject_id}_{data_type}-epo.fif",
                             preload=True, verbose=False)
    data = epochs.get_data()
    labels = epochs.events[:, 2] - 1   # 1-based event codes -> 0-based class indices
    if not need_label_map:
        return data, labels, epochs, None
    shifted_map = {name: code - 1 for name, code in epochs.event_id.items()}
    return data, labels, epochs, shifted_map



def load_all_subjects_data(data_path, need_label_map=True, data_type='localizer'):
    """
    Load and concatenate every subject found under `data_path`.

    Subjects are the sub-directories of `data_path`, visited in sorted order.
    ``os.listdir`` order is arbitrary, so sorting keeps group indices,
    `first_epochs` and downstream CV folds reproducible. Plain files such as
    ``.DS_Store`` are skipped.

    Returns:
        X: ndarray (total_trials, C_channels, T_timepoints)
        y: labels (total_trials,)
        groups: ndarray (total_trials,) of subject indices (position in the sorted list)
        first_epochs: epochs object of the first subject (layout / time reference)
        label_map: label map returned for the last subject loaded
                   (None when need_label_map is False)

    Raises:
        FileNotFoundError: if `data_path` contains no subject directories.
    """
    subject_ids = sorted(entry.name for entry in Path(data_path).iterdir() if entry.is_dir())
    if not subject_ids:
        raise FileNotFoundError(f"No subject directories found in {data_path}")

    all_X, all_y, all_groups = [], [], []
    first_epochs, label_map = None, None
    for idx, subject_id in enumerate(subject_ids):
        X, y, epochs, label_map = load_subject_data(data_path, subject_id, need_label_map, data_type)
        if first_epochs is None:
            first_epochs = epochs
        all_X.append(X)
        all_y.append(y)
        all_groups.append(np.full(len(y), idx))

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    groups = np.concatenate(all_groups, axis=0)
    return X, y, groups, first_epochs, label_map

In [10]:
def cross_validate(data_path, n_splits=5,
                   dataset_params=None, model_params=None, training_params=None, device='cuda',
                   save_dir="/kaggle/working", batch_size=32, num_workers=4):
    """
    Subject-wise cross-validation: GroupKFold over subjects, one fresh model per fold.

    For each fold:
      1. Build a MEGDataset on the training subjects.
      2. Scale the validation MEGDataset with the training fold's scaler.
      3. Train a freshly initialised Model with `train_evaluate`.
      4. Reload the best checkpoint ``<save_dir>/best_model_fold_<k>.pth`` and score it
         on the validation subjects.

    Args:
        data_path: directory containing one sub-directory per subject.
        n_splits: number of GroupKFold folds (at most the number of subjects).
        dataset_params / model_params / training_params: keyword arguments forwarded to
            MEGDataset, Model and train_evaluate. None is treated as an empty dict;
            train_evaluate still requires `lr` and `weight_decay`.
        device: torch device string for training and evaluation.
        save_dir: directory for the per-fold checkpoints.
        batch_size, num_workers: DataLoader settings used for both loaders.

    Returns:
        (fold_scores, mean_acc, std_acc)

    Raises:
        RuntimeError: if training produced no checkpoint for a fold (e.g. epochs=0).
    """
    # `**None` would raise TypeError; treat missing params as "no overrides".
    dataset_params = dict(dataset_params or {})
    model_params = dict(model_params or {})
    training_params = dict(training_params or {})

    # LOAD ALL SUBJECTS
    X, y, groups, first_epochs, _ = load_all_subjects_data(data_path, False, "localizer")
    # Count subjects from the loaded groups, not a fresh directory listing that may
    # include non-subject files.
    print(f"Total trials: {len(y)} | Subjects: {len(np.unique(groups))} | X dimension: {X.shape}")

    # --- GroupKFold CV ---
    gkf = GroupKFold(n_splits=n_splits)
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, worker_init_fn=worker_init_fn)
    fold_scores = []

    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        fold = fold_idx + 1
        print(f"\nFold {fold}/{n_splits}")

        ## Create train and validation datasets and loaders
        X_train, y_train = X[train_idx], y[train_idx]
        X_val, y_val = X[val_idx], y[val_idx]
        train_data = MEGDataset(X_train, y_train, first_epochs, **dataset_params)
        # Validation data reuses the training fold's scaling statistics (no leakage).
        val_data = MEGDataset(X_val, y_val, first_epochs, scaler=train_data.scaler, **dataset_params)
        train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)

        ## Trains model
        save_path = os.path.join(save_dir, f"best_model_fold_{fold}.pth")
        # Remove any checkpoint left by an earlier run so it can never be evaluated by mistake.
        if os.path.exists(save_path):
            os.remove(save_path)
        model = Model(**model_params).to(device)
        init_model_weights(model)
        train_evaluate(model, train_loader, val_loader, save_path, **training_params, device=device)
        if not os.path.exists(save_path):
            raise RuntimeError(f"No checkpoint was written to {save_path} for fold {fold}")

        # Evaluates the best checkpoint (not the last-epoch weights)
        model = Model(**model_params).to(device)
        model.load_state_dict(torch.load(save_path, map_location=device))
        val_preds = predict(model, val_loader, device=device)
        val_acc = (val_preds == y_val).mean()
        fold_scores.append(val_acc)
        print(f"Fold {fold} Validation Accuracy: {val_acc:.4f}\n\n")

    mean_acc, std_acc = np.mean(fold_scores), np.std(fold_scores)
    print(f"\nMean Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    return fold_scores, mean_acc, std_acc

In [11]:
# Keyword arguments forwarded to MEGDataset, Model and train_evaluate respectively.
# Model's in_channels must equal the dataset's n_time_slices (one image per time slice).
dataset_params = dict(time_window=TIME_WINDOW, grid_size=GRID_SIZE, n_time_slices=N_TIME_SLICES)
model_params = dict(in_channels=N_TIME_SLICES, n_classes=10)
training_params = dict(epochs=EPOCHS, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Subject-wise cross-validation with 15 GroupKFold splits.
cv_results = cross_validate(
    '/kaggle/input/the-imagine-decoding-challenge/train/train',
    n_splits=15,
    dataset_params=dataset_params,
    model_params=model_params,
    training_params=training_params,
    device='cuda',
)
cv_scores, mean_acc, std_acc = cv_results

# # Train on all & predict test
# # submission = train_and_predict(
# #     '/kaggle/input/the-imagine-decoding-challenge/train/train', 
# #     '/kaggle/input/the-imagine-decoding-challenge/test/test', 
# #     dataset_params=dataset_params,
# #     model_params=model_params,
# #     training_params=training_params,
# #     output_file='submission.csv',
# #     device='cpu'
# # )

Total trials: 7200 | Subjects: 15 | X dimension: (7200, 309, 121)

Fold 1/15
Epoch 1/40 |  Train Loss: 2.4960 |  Val Loss: 2.3122 |  Train Acc: 0.103 |  Val Acc: 0.115
Epoch 2/40 |  Train Loss: 2.4097 |  Val Loss: 2.3235 |  Train Acc: 0.106 |  Val Acc: 0.104
Early Stopping Counter: 1/20
Epoch 3/40 |  Train Loss: 2.3795 |  Val Loss: 2.3127 |  Train Acc: 0.102 |  Val Acc: 0.098
Early Stopping Counter: 2/20
Epoch 4/40 |  Train Loss: 2.3678 |  Val Loss: 2.3097 |  Train Acc: 0.110 |  Val Acc: 0.090
Early Stopping Counter: 3/20
Epoch 5/40 |  Train Loss: 2.3522 |  Val Loss: 2.3109 |  Train Acc: 0.108 |  Val Acc: 0.102
Early Stopping Counter: 4/20
Epoch 6/40 |  Train Loss: 2.3406 |  Val Loss: 2.3041 |  Train Acc: 0.112 |  Val Acc: 0.087
Early Stopping Counter: 5/20
Epoch 7/40 |  Train Loss: 2.3347 |  Val Loss: 2.3061 |  Train Acc: 0.113 |  Val Acc: 0.106
Early Stopping Counter: 6/20
Epoch 8/40 |  Train Loss: 2.3430 |  Val Loss: 2.3121 |  Train Acc: 0.111 |  Val Acc: 0.121
Epoch 9/40 |  Train L

In [12]:
# def objective(trial):
#     t0 = trial.suggest_float("time_window_t0", 0.0, 0.25)
#     t1 = trial.suggest_float("time_window_t1", 0.5, 0.85)
    
#     dataset_params  = {
#         'time_window'  : (t0, t1), 
#         'grid_size'    : trial.suggest_categorical("grid_size", [32, 48, 56, 64, 72]), 
#         'n_time_slices': trial.suggest_categorical("n_time_slices", [8, 12, 16, 32, 48])
#     }

#     model_params = {
#         'in_channels'       :dataset_params["n_time_slices"], 
#         'hidden_channels'   :trial.suggest_int("hidden_channels", dataset_params["n_time_slices"], 4*dataset_params["n_time_slices"]),
#         'conv_dropout'      :trial.suggest_float("conv_dropout", 0.1, 0.5),
#         'classifier_dropout':trial.suggest_float("classifier_dropout", 0.1, 0.5),
#         'conv_activation'   :trial.suggest_categorical("conv_activation", ["relu", "silu", "gelu"]),
#     }
#     model_params["reduction"] = trial.suggest_int("reduction", 8, model_params["hidden_channels"])
    
#     training_params = {
#         'lr':trial.suggest_float("lr", .0001, .01, log=True), 
#         "weight_decay":trial.suggest_float("weight_decay", .0001, .01, log=True)
#     }
    
#     cv_scores, mean_acc, std_acc = cross_validate(
#         '/kaggle/input/the-imagine-decoding-challenge/train/train', 
#         n_splits=3, 
#         dataset_params=dataset_params,
#         model_params=model_params,
#         training_params=training_params,
#         device='cuda'
#     )
#     return mean_acc,std_acc

# study = optuna.create_study(directions=["maximize", "minimize"])  
# study.optimize(objective, n_trials=25)