In [1]:
import mne
import numpy as np
import pandas as pd
import seaborn as sea
import matplotlib.pyplot as plt
import os
from pathlib import Path
import random
import joblib 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader
from scipy.interpolate import griddata

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import GroupKFold
from mne.viz.topomap import _prepare_topomap

import warnings
warnings.filterwarnings('ignore')

In [2]:
SEED = 3126
EPOCHS = 20
GRID_SIZE = 32
TIME_WINDOW = (0.1, 0.5)
N_TIME_SLICES = 16
LEARNING_RATE = 7.5e-4

In [3]:
def set_seed(seed):
    random.seed(seed)               
    np.random.seed(seed)            
    torch.manual_seed(seed)         
    torch.cuda.manual_seed(seed)    
    torch.cuda.manual_seed_all(seed) 

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
    np.random.seed(SEED + worker_id)
    random.seed(SEED + worker_id)

set_seed(SEED)

In [4]:
class MEGDataset(Dataset):
    """
    Converts MEG epochs into interpolated 2D topographic images.
    Each trial => tensor of shape (n_time_slices, grid_size, grid_size)

    Inputs:
        X: ndarray (M_trials, C_channels, T_timepoints)
        y: labels (M_trials,)
        epochs: MNE epochs object (contains time info and sensor positions)
        time_window: start and end of the period you care about (t0 and t1, in seconds)
        grid_size: The 2D topomap's size (H, W)
        n_time_slices: Number of snapshots to take between t0 and t1
    """
    def __init__(self, X, y, epochs, time_window=(0.1, 0.4), grid_size=32, n_time_slices=5):
        self.X = X
        self.y = y
        self.times = epochs.times
        self.info  = epochs.info
        self.grid_size     = grid_size
        self.n_time_slices = n_time_slices
        
        # Get indices for starting and ending time of interested interval
        t0, t1     = time_window
        self.start = np.argmin(np.abs(self.times - t0))
        self.end   = np.argmin(np.abs(self.times - t1))
        
        # Get orientations of sensors in 3D space and use first two components only (x,y)
        positions_3d = np.array([ch['loc'][:3] for ch in self.info['chs']])
        self.pos2d = positions_3d[:, :2]  
        
        # Prepare grid for interpolated topomap
        self.grid_x, self.grid_y = np.meshgrid(
            np.linspace(self.pos2d[:,0].min(), self.pos2d[:,0].max(), self.grid_size),
            np.linspace(self.pos2d[:,1].min(), self.pos2d[:,1].max(), self.grid_size)
        )
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        trial = self.X[idx]  # (C_channels, T_timepoints)
        label = self.y[idx]
        
        # Select evenly-spaced time slices
        slice_times = np.linspace(self.start, self.end-1, self.n_time_slices).astype(int)
        topo_imgs = []
        
        for slice_t in slice_times:
            values = trial[:, slice_t]  # sensor readings at time slice_t for C_channels (C_channels,)
            topo = griddata(
                self.pos2d, values,
                (self.grid_x, self.grid_y),
                method='cubic',
                fill_value=0.0
            )
            topo_imgs.append(topo)
        
        topo_imgs = np.stack(topo_imgs, axis=0)  # (n_time_slices, H, W)
        return torch.tensor(topo_imgs, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

In [5]:
class Model(nn.Module):
    def __init__(self, in_channels=N_TIME_SLICES, hidden_channels=32, dropout_p=.2, n_classes=10):
        super(Model, self).__init__()
        
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU())
            
        def linear_block(in_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, out_dim, bias=False),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
                nn.Dropout(dropout_p))
            
        self.conv = nn.Sequential(
            conv_block(in_channels, hidden_channels),
            conv_block(hidden_channels, 2*hidden_channels),
            nn.MaxPool2d(2),
            conv_block(2*hidden_channels, 3*hidden_channels),
            nn.MaxPool2d(2)
        )

        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, GRID_SIZE, GRID_SIZE)
            out = self.conv(dummy)
            self.flat_size = out.view(1, -1).size(1)
            
        self.classifier = nn.Sequential(
            linear_block(self.flat_size, 256),
            linear_block(256, 64),
            nn.Linear(64, n_classes)
        )
        
    def forward(self, x):
        o = self.conv(x)
        o = o.view(o.size(0), -1)
        return self.classifier(o)

In [6]:
def init_model_weights(model):
    for module in model.modules():
        ## Linear and Convolution 
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        ## Batchnorm 
        if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

In [7]:
def train_evaluate(model, train_loader, val_loader, epochs=20, lr=1e-3, 
                   weight_decay=.001, early_stopping_patience=10, device="cuda"):
    model.to(device)
    optimizer    = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn      = nn.CrossEntropyLoss()
    lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader))
    history = {"train_losses":[], "val_losses":[], "train_accs":[], "val_accs":[]}

    min_change          = .002
    best_val_acc        = 0.0
    epochs_not_improved = 0
    
    for epoch in range(epochs):
        # -------------------- Model training -------------------- 
        model.train()
        train_loss = 0.0
        train_acc  = 0.0
        total      = 0.0
        for X, y in train_loader:
            X, y   = X.to(device), y.to(device)
            logits = model(X)
            loss   = loss_fn(logits, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            
            ## Accumulates loss and accuracy
            train_loss += loss.item()
            preds       = logits.argmax(dim=1)
            train_acc  += (preds==y).sum().item()
            total      += y.size(0)
        train_loss /= len(train_loader)
        train_acc  /= total
        
        # -------------------- Model Validation -------------------- 
        model.eval()
        val_loss = 0.0
        val_acc  = 0.0
        total    = 0.0
        with torch.inference_mode():
            for X, y in val_loader:
                X, y  = X.to(device), y.to(device)

                ## Compute loss
                logits    = model(X)
                loss      = loss_fn(logits, y)
                val_loss += loss.item()
                ## Tracks validation accuracy
                preds    = logits.argmax(dim=1)
                val_acc += (preds==y).sum().item()
                total   += y.size(0)
        val_loss /= len(val_loader)
        val_acc  /= total

        ## -------------------- Logging -------------------- 
        history["train_losses"].append(train_loss)
        history["val_losses"].append(val_loss)
        history["train_accs"].append(train_acc)
        history["val_accs"].append(val_acc)
        
        print(f"Epoch {epoch+1}/{epochs} | ",
              f"Train Loss: {train_loss:.4f} | ",
              f"Val Loss: {val_loss:.4f} | ",
              f"Train Acc: {train_acc:.3f} | ",
              f"Val Acc: {val_acc:.3f}")

        # ----------------- Early Stopping -----------------
        if val_acc - best_val_acc >= min_change:
            best_val_acc        = val_acc
            epochs_not_improved = 0
            torch.save(model.state_dict(), "best_model.pth")
        else:
            epochs_not_improved += 1
            print(f"Early Stopping Counter: {epochs_not_improved}/{early_stopping_patience}")
        if epochs_not_improved >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    return history


def predict(model, loader, device="cuda"):
    model.eval()
    all_preds = []
    with torch.inference_mode():
        for X,_ in loader:
            X = X.to(device)
            all_preds.append(model(X).argmax(dim=1))
    return torch.cat(all_preds).cpu().numpy()

In [8]:
def load_subject_data(data_path, subject_id, need_label_map=True, data_type='localizer'):
    """
    Load data for a single subject.
    Returns:
        X: ndarray (M_trials, C_channels, T_timepoints)
        y: labels  (M_trials,)
        epochs: MNE epochs object
        label_map: dict mapping event names to codes
    """
    file_path = Path(data_path) / subject_id / f"{subject_id}_{data_type}-epo.fif"
    epochs    = mne.read_epochs(file_path, preload=True, verbose=False)
    X         = epochs.get_data()
    y         = epochs.events[:,2]-1  # ranges from [1, 10], subtracts 1 to become [0,9]
    label_map = None
    if need_label_map: 
        label_map = {key:value-1 for key,value in epochs.event_id.items()} # shift values down to be in range [0,9]
    return X, y, epochs, label_map



def load_all_subjects_data(data_path, need_label_map=True, data_type='localizer'):
    """
    Load data for all subjects.
    Returns:
        X: ndarray (M_trials * num_subjects, C_channels, T_timepoints)
        y: labels  (M_trials * num_subjects,)
        groups: ndarray
        label_map: dict
    """
    subject_ids = os.listdir(data_path)
    all_X, all_y, all_groups, first_epochs = [], [], [], None
    
    for idx, subject_id in enumerate(subject_ids):
        X, y, epochs, label_map = load_subject_data(data_path, subject_id, need_label_map, data_type)
        if first_epochs is None: first_epochs=epochs
        all_X.append(X)
        all_y.append(y)
        all_groups.append(np.full(len(y), idx))
    
    X = np.concatenate(all_X, axis=0) 
    y = np.concatenate(all_y, axis=0) 
    groups = np.concatenate(all_groups, axis=0)
    return X, y, groups, first_epochs, label_map

In [9]:
def cross_validate(data_path, n_splits=5, 
                   dataset_params=None, model_params=None, training_params=None, device='cuda'):
    
    if dataset_params is None:
        dataset_params  = {'time_window': (0.1, 0.4), 'grid_size': 32, 'n_time_slices': 5}
    if model_params is None:
        model_params    = {'in_channels': dataset_params['n_time_slices'], 'n_classes': 10}
    if training_params is None:
        training_params = {'epochs': 20, 'lr': 1e-3}
    
    # LOAD ALL SUBJECTS
    X, y, groups, first_epochs, _ = load_all_subjects_data(data_path, False, "localizer")
    print(f"Total trials: {len(y)} | Subjects: {len(os.listdir(data_path))} | X dimension: {X.shape}")
    
    # --- GroupKFold CV ---
    gkf = GroupKFold(n_splits=n_splits)
    fold_scores = []
    
    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        print(f"\nFold {fold_idx+1}/{n_splits}")

        ## Create train and validation datasets
        X_train, y_train = X[train_idx], y[train_idx]
        X_val,   y_val   = X[val_idx],   y[val_idx]
        train_data = MEGDataset(X_train, y_train, first_epochs, **dataset_params)
        val_data   = MEGDataset(X_val, y_val, first_epochs, **dataset_params)
        ## Create dataloaders
        train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn)
        val_loader   = DataLoader(val_data,   batch_size=32, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

        ## Trains model
        model   = Model(**model_params).to(device)
        init_model_weights(model)
        history = train_evaluate(model, train_loader, val_loader, **training_params, device=device)
        
        # Evaluates model
        model = Model(**model_params).to(device)
        model.load_state_dict(torch.load("/kaggle/working/best_model.pth"))
        val_preds = predict(model, val_loader, device=device)
        val_acc   = (val_preds==y_val).mean()
        fold_scores.append(val_acc)
        print(f"Fold {fold_idx+1} Validation Accuracy: {val_acc:.4f}\n\n")
    
    mean_acc, std_acc = np.mean(fold_scores), np.std(fold_scores)
    print(f"\nMean Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    return fold_scores, mean_acc, std_acc

In [10]:
def train_and_predict(train_path, test_path, dataset_params=None, model_params=None,
                         training_params=None, output_file='submission.csv', device='cuda'):
    
    # --- Load all training data and trains model---
    X_train, y_train, _, first_epochs, label_map = load_all_subjects_data(train_path, True, "localizer")
    train_data   = MEGDataset(X_train, y_train, first_epochs, **dataset_params)
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn)
    
    model = Model(**model_params).to(device)
    history = train_evaluate(model, train_loader, train_loader, **training_params, device=device)
    
    # --- Predict test subjects ---
    test_subjects = os.listdir(test_path)
    results = []
    
    inverse_map = {v:k for k,v in label_map.items()}
    
    for subject in test_subjects:
        X_test, _, imagine_epochs, _ = load_subject_data(test_path, subject, False, 'imagine')
        test_data   = MEGTopographicDataset(X_test, np.zeros(len(X_test)), imagine_epochs, **dataset_params)
        test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
        
        preds = predict(model, test_loader, device=device)
        pred_labels = [inverse_map[pred] for pred in preds]
        
        for i, label in enumerate(pred_labels, start=1):
            results.append({'ID': f"{sub}_{i}", 'label': label})
    
    submission_df = pd.DataFrame(results)
    submission_df.to_csv(output_file, index=False)
    print(f"Saved submission with {len(submission_df)} rows.")
    return submission_df

In [11]:
dataset_params  = {'time_window':TIME_WINDOW, 'grid_size':GRID_SIZE, 'n_time_slices':N_TIME_SLICES}
model_params    = {'in_channels':dataset_params['n_time_slices'], 'n_classes':10}
training_params = {'epochs':EPOCHS, 'lr':LEARNING_RATE}

# Cross-validation
cv_scores, mean_acc, std_acc = cross_validate(
    '/kaggle/input/the-imagine-decoding-challenge/train/train', 
    n_splits=5, 
    dataset_params=dataset_params,
    model_params=model_params,
    training_params=training_params,
    device='cuda'
)

# Train on all & predict test
# submission = train_and_predict(
#     '/kaggle/input/the-imagine-decoding-challenge/train/train', 
#     '/kaggle/input/the-imagine-decoding-challenge/test/test', 
#     dataset_params=dataset_params,
#     model_params=model_params,
#     training_params=training_params,
#     output_file='submission.csv',
#     device='cpu'
# )

Total trials: 7200 | Subjects: 15 | X dimension: (7200, 309, 121)

Fold 1/5
Epoch 1/20 |  Train Loss: 2.5115 |  Val Loss: 2.3445 |  Train Acc: 0.099 |  Val Acc: 0.107
Epoch 2/20 |  Train Loss: 2.4082 |  Val Loss: 2.3289 |  Train Acc: 0.105 |  Val Acc: 0.088
Early Stopping Counter: 1/10
Epoch 3/20 |  Train Loss: 2.3839 |  Val Loss: 2.3216 |  Train Acc: 0.101 |  Val Acc: 0.107
Early Stopping Counter: 2/10
Epoch 4/20 |  Train Loss: 2.3630 |  Val Loss: 2.3255 |  Train Acc: 0.105 |  Val Acc: 0.098
Early Stopping Counter: 3/10
Epoch 5/20 |  Train Loss: 2.3506 |  Val Loss: 2.3105 |  Train Acc: 0.110 |  Val Acc: 0.114
Epoch 6/20 |  Train Loss: 2.3433 |  Val Loss: 2.3130 |  Train Acc: 0.117 |  Val Acc: 0.113
Early Stopping Counter: 1/10
Epoch 7/20 |  Train Loss: 2.3404 |  Val Loss: 2.3086 |  Train Acc: 0.118 |  Val Acc: 0.107
Early Stopping Counter: 2/10
Epoch 8/20 |  Train Loss: 2.3324 |  Val Loss: 2.3076 |  Train Acc: 0.119 |  Val Acc: 0.111
Early Stopping Counter: 3/10
Epoch 9/20 |  Train Lo