In [1]:
import mne
import numpy as np
import pandas as pd
import seaborn as sea
import matplotlib.pyplot as plt
import os
from pathlib import Path
import random
import joblib 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import Dataset, DataLoader
from scipy.interpolate import griddata

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import GroupKFold
from mne.viz.topomap import _prepare_topomap

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Notebook-wide configuration. Every value below is read by later cells as a
# default argument or passed explicitly, so change them here rather than inline.

# Base random seed: passed to set_seed() and offset per worker in worker_init_fn().
SEED = 3126
# Default maximum number of training epochs for train_evaluate().
EPOCHS = 50
# Default number of PCA components kept by PCADataset; also the default
# number of input channels of Model.
PCA_COMPONENTS = 25
# Default AdamW learning rate used by train_evaluate().
LEARNING_RATE = 7.5e-4
# Default AdamW weight decay used by train_evaluate().
WEIGHT_DECAY = 7.5e-3
# Default number of consecutive non-improving epochs before train_evaluate() stops.
EARLY_STOPPING_PATIENCE = 10

In [3]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and CUDA),
    then switches cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        seed: integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def worker_init_fn(worker_id):
    """Seed NumPy's and Python's RNGs for one DataLoader worker.

    Each worker gets the seed `SEED + worker_id`, so workers draw different
    but reproducible random streams.

    Args:
        worker_id: integer id of the worker, supplied by the DataLoader.
    """
    worker_seed = SEED + worker_id
    for seeder in (np.random.seed, random.seed):
        seeder(worker_seed)

# Seed all RNGs once, before any model or dataset is created.
set_seed(SEED)

In [4]:
class PCADataset(Dataset):
    """Dataset that standardises channels and projects them onto PCA components.

    The input array is flattened to (trials * timepoints, channels), passed
    through a scaler and a PCA, and reshaped to (trials, components, timepoints).
    The transformed array is computed once in the constructor and cached in
    `self.X_new`; `__getitem__` only indexes that cache (plus optional noise).
    """

    def __init__(self, X, y, scaler=None, pca=None, augment=False, pca_components=PCA_COMPONENTS,
                 n_timepoints=121):
        """
        X: ndarray (trials, channels, timepoints)
        y: labels, indexable by trial
        scaler: a fitted scaler object (optional). If None (for training), a
                StandardScaler is fit on X.
        pca: a fitted PCA object (optional). If None (for training), a PCA with
             `pca_components` components is fit on the scaled X.
        augment: if True, add Gaussian noise to each trial when it is fetched.
        pca_components: number of components for a newly created PCA.
        n_timepoints: target length of the time axis; longer inputs are
                      max-pooled down to this length (shorter ones are kept as is).
        """
        self.X = X
        self.y = y
        self.augment = augment
        self.pca_components = pca_components
        self.n_timepoints = n_timepoints
        self.pool = nn.AdaptiveMaxPool1d(n_timepoints)
        # Use the provided transformers if given, otherwise create fresh ones.
        self.scaler = scaler if scaler is not None else StandardScaler()
        self.pca = pca if pca is not None else PCA(n_components=pca_components, random_state=3126)
        # Only fit what was created here: a transformer handed in by the caller
        # is treated as already fitted and is never refit (which would silently
        # overwrite the caller's object).
        self.X_new = self.make_X_new(fit=(pca is None), fit_scaler=(scaler is None))

    def make_X_new(self, fit=True, fit_scaler=None):
        """Return the scaled, PCA-projected data as (trials, components, timepoints).

        fit: if True the PCA is fit on the data before transforming.
        fit_scaler: if True the scaler is fit before transforming; defaults to
                    the value of `fit`.
        """
        if fit_scaler is None:
            fit_scaler = fit

        m_trials, c_channels, t_time = self.X.shape
        # (trials, channels, time) -> (trials*time, channels): every time sample is one row.
        flat = self.X.transpose(0, 2, 1).reshape(-1, c_channels)

        flat = self.scaler.fit_transform(flat) if fit_scaler else self.scaler.transform(flat)
        projected = self.pca.fit_transform(flat) if fit else self.pca.transform(flat)

        # Reshape with the width the PCA actually produced, so a pre-fitted PCA
        # whose component count differs from `pca_components` still works.
        n_components = projected.shape[-1]
        X_new = projected.reshape(m_trials, t_time, n_components).transpose(0, 2, 1)

        # Trials longer than the target length are max-pooled down to it.
        if X_new.shape[-1] > self.n_timepoints:
            X_new = self.pool(torch.tensor(X_new).float()).numpy()
        return X_new

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        trial = self.X_new[idx]
        label = self.y[idx]

        if self.augment:
            noise_std = np.random.uniform(0.95, 1.35)
            # Out-of-place add: `trial` is a view into self.X_new, so `+=` would
            # write the noise back into the cached features and let it
            # accumulate every time the same trial is fetched.
            trial = trial + np.random.normal(0, noise_std, trial.shape)
        return torch.tensor(trial, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

In [5]:
class Model(nn.Module):
    """1-D CNN classifier over (batch, in_channels, n_timepoints) inputs.

    Three Conv1d-BatchNorm-ReLU blocks (the last two followed by MaxPool1d(2))
    feed a stack of Linear-BatchNorm-ReLU-Dropout blocks and a final Linear
    layer that outputs `n_classes` logits.
    """

    def __init__(self, in_channels=PCA_COMPONENTS, hidden_channels=64, dropout_p=.25, n_classes=10,
                 n_timepoints=121):
        """
        in_channels: number of input channels (PCA components).
        hidden_channels: width of the first conv block; later blocks use 2x and 4x.
        dropout_p: dropout probability used throughout the network.
        n_classes: number of output logits.
        n_timepoints: length of the input time axis, used to size the first
                      linear layer.
        """
        super(Model, self).__init__()

        def conv_block(in_channels, out_channels):
            # Bias is omitted because the following BatchNorm has its own shift.
            return nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm1d(out_channels),
                nn.ReLU())

        def linear_block(in_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, out_dim, bias=False),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(),
                nn.Dropout(dropout_p))

        self.conv = nn.Sequential(
            conv_block(in_channels, hidden_channels),
            nn.Dropout(dropout_p),
            conv_block(hidden_channels, 2*hidden_channels),
            nn.MaxPool1d(2),
            nn.Dropout(dropout_p),
            conv_block(2*hidden_channels, 4*hidden_channels),
            nn.MaxPool1d(2),
            nn.Dropout(dropout_p)
        )

        # The convolutions keep the time length (kernel 3, padding 1); each
        # MaxPool1d(2) floor-halves it. Computing the flattened size here,
        # instead of hard-coding 256*30, keeps `hidden_channels` and the input
        # length configurable. With the defaults this is 256 * 30.
        pooled_len = n_timepoints // 2 // 2
        flat_size = 4 * hidden_channels * pooled_len

        self.classifier = nn.Sequential(
            linear_block(flat_size, 256),
            linear_block(256, 64),
            linear_block(64, 16),
            nn.Linear(16, n_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        o = self.conv(x)
        o = o.view(o.size(0), -1)  # flatten (channels, time) per sample
        return self.classifier(o)

In [6]:
def init_model_weights(model):
    """Re-initialise the parameters of `model` in place.

    Linear and convolution layers get Kaiming-normal weights (fan-in, ReLU
    gain) and zero biases where a bias exists; batch-norm layers get unit
    weights and zero biases. Other modules are left untouched.
    """
    weighted_types = (nn.Linear, nn.Conv1d, nn.Conv2d)
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d)

    for layer in model.modules():
        if isinstance(layer, weighted_types):
            nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
            bias = layer.bias
            if bias is not None:
                nn.init.zeros_(bias)
        if isinstance(layer, norm_types):
            nn.init.ones_(layer.weight)
            nn.init.zeros_(layer.bias)

In [7]:
def train_evaluate(model, train_loader, val_loader, save_path, epochs=EPOCHS, lr=LEARNING_RATE, 
                   weight_decay=WEIGHT_DECAY, early_stopping_patience=EARLY_STOPPING_PATIENCE, device="cuda"):
    """Train `model` with AdamW and early stopping on validation accuracy.

    After every epoch the model is evaluated on `val_loader`. Whenever the
    validation accuracy beats the best value so far by at least 0.002 the
    weights are saved to `save_path`; the first epoch always counts as an
    improvement, so a checkpoint exists whenever at least one epoch ran.

    Args:
        model: the nn.Module to train (moved to `device`).
        train_loader: iterable of (X, y) training batches; must support len().
        val_loader: iterable of (X, y) validation batches; must support len().
        save_path: file the best state_dict is written to.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        early_stopping_patience: stop after this many consecutive epochs
            without a sufficient improvement.
        device: device used for the model and the batches.

    Returns:
        dict with per-epoch lists "train_losses", "val_losses", "train_accs"
        and "val_accs".

    Raises:
        ValueError: if `val_loader` yields no batches.
    """
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; cannot compute validation metrics")

    model.to(device)
    optimizer    = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn      = nn.CrossEntropyLoss()
    # The scheduler is stepped once per batch and T_0 equals the number of
    # batches per epoch, so the cosine cycle restarts at every epoch.
    lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader))
    history = {"train_losses":[], "val_losses":[], "train_accs":[], "val_accs":[]}

    min_change          = .002
    # None until the first checkpoint is written. Starting from None (rather
    # than 0.0) guarantees the first epoch is saved, so callers that load
    # `save_path` afterwards never hit a missing or stale file.
    best_val_acc        = None
    epochs_not_improved = 0

    for epoch in range(epochs):
        # -------------------- Model training -------------------- 
        model.train()
        train_loss = 0.0
        train_acc  = 0.0
        total      = 0.0
        for X, y in train_loader:
            X, y   = X.to(device), y.to(device)
            logits = model(X)
            loss   = loss_fn(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            ## Accumulates loss and accuracy
            train_loss += loss.item()
            preds       = logits.argmax(dim=1)
            train_acc  += (preds==y).sum().item()
            total      += y.size(0)
        train_loss /= len(train_loader)   # mean of per-batch losses
        train_acc  /= total               # fraction of correct samples

        # -------------------- Model Validation -------------------- 
        model.eval()
        val_loss = 0.0
        val_acc  = 0.0
        total    = 0.0
        with torch.inference_mode():
            for X, y in val_loader:
                X, y  = X.to(device), y.to(device)

                ## Compute loss
                logits    = model(X)
                loss      = loss_fn(logits, y)
                val_loss += loss.item()
                ## Tracks validation accuracy
                preds    = logits.argmax(dim=1)
                val_acc += (preds==y).sum().item()
                total   += y.size(0)
        val_loss /= len(val_loader)
        val_acc  /= total

        ## -------------------- Logging -------------------- 
        history["train_losses"].append(train_loss)
        history["val_losses"].append(val_loss)
        history["train_accs"].append(train_acc)
        history["val_accs"].append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} | ",
              f"Train Loss: {train_loss:.4f} | ",
              f"Val Loss: {val_loss:.4f} | ",
              f"Train Acc: {train_acc:.3f} | ",
              f"Val Acc: {val_acc:.3f}")

        # ----------------- Early Stopping -----------------
        improved = best_val_acc is None or val_acc - best_val_acc >= min_change
        if improved:
            best_val_acc        = val_acc
            epochs_not_improved = 0
            torch.save(model.state_dict(), save_path)
        else:
            epochs_not_improved += 1
            print(f"Early Stopping Counter: {epochs_not_improved}/{early_stopping_patience}")
        if epochs_not_improved >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    return history


def predict(model, loader, device="cuda"):
    """Return the predicted class index for every sample in `loader`.

    The model is put in eval mode and run without gradient tracking. Each
    batch from `loader` is an (inputs, labels) pair; the labels are ignored.

    Returns:
        1-D NumPy array of argmax class indices, in loader order.
    """
    model.eval()
    with torch.inference_mode():
        batch_preds = [model(inputs.to(device)).argmax(dim=1) for inputs, _ in loader]
    stacked = torch.cat(batch_preds)
    return stacked.cpu().numpy()

In [8]:
def load_subject_data(data_path, subject_id, need_label_map=True, data_type='localizer'):
    """
    Read one subject's epochs file and return its data and zero-based labels.

    The file is expected at <data_path>/<subject_id>/<subject_id>_<data_type>-epo.fif.

    Returns:
        X: array returned by epochs.get_data(), (M_trials, C_channels, T_timepoints)
        y: event codes shifted down by one, (M_trials,)
        epochs: the MNE epochs object that was read
        label_map: dict of event name -> shifted code, or None when
                   need_label_map is False
    """
    subject_dir = Path(data_path) / subject_id
    fif_path = subject_dir / f"{subject_id}_{data_type}-epo.fif"
    epochs = mne.read_epochs(fif_path, preload=True, verbose=False)

    # Event codes are shifted by -1 (the original author notes they run 1..10)
    # so labels and label_map values are zero-based.
    if need_label_map:
        label_map = {name: code - 1 for name, code in epochs.event_id.items()}
    else:
        label_map = None

    data = epochs.get_data()
    labels = epochs.events[:, 2] - 1
    return data, labels, epochs, label_map



def load_all_subjects_data(data_path, need_label_map=True, data_type='localizer'):
    """
    Load and concatenate the data of every subject found in `data_path`.

    Every entry of `data_path` is treated as a subject id and loaded with
    load_subject_data(). Subjects are processed in sorted order so the group
    ids and the order of the concatenated trials are the same on every run
    (os.listdir returns entries in arbitrary order).

    Returns:
        X: ndarray, all subjects' trials concatenated along axis 0
        y: labels, concatenated in the same order
        groups: ndarray giving, for each trial, the index of its subject
        first_epochs: the MNE epochs object of the first subject loaded
        label_map: label map of the last subject loaded (None when
                   need_label_map is False)

    Raises:
        ValueError: if `data_path` contains no entries.
    """
    subject_ids = sorted(os.listdir(data_path))
    if not subject_ids:
        raise ValueError(f"No subject folders found in {data_path!r}")

    all_X, all_y, all_groups = [], [], []
    first_epochs, label_map = None, None

    for idx, subject_id in enumerate(subject_ids):
        X, y, epochs, label_map = load_subject_data(data_path, subject_id, need_label_map, data_type)
        if first_epochs is None:
            first_epochs = epochs
        all_X.append(X)
        all_y.append(y)
        # One group id per trial so GroupKFold can keep a subject's trials together.
        all_groups.append(np.full(len(y), idx))

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    groups = np.concatenate(all_groups, axis=0)
    return X, y, groups, first_epochs, label_map

In [9]:
def cross_validate(data_path, n_splits=5, 
                   dataset_params=None, model_params=None, training_params=None, device='cuda',
                   save_dir="/kaggle/working"):
    """Subject-wise K-fold cross-validation on the localizer data.

    For each GroupKFold split (groups = subjects) a scaler and PCA are fit on
    the training subjects only, a fresh Model is trained with train_evaluate(),
    and the best checkpoint of the fold is reloaded and scored on the
    validation subjects.

    Args:
        data_path: directory with one sub-folder per subject.
        n_splits: number of GroupKFold folds.
        dataset_params: accepted for backward compatibility; currently unused.
        model_params: keyword arguments for Model.
        training_params: keyword arguments for train_evaluate().
        device: device used for training and evaluation.
        save_dir: directory receiving best_model_fold_<k>.pth checkpoints.

    Returns:
        (fold_scores, mean_acc, std_acc): list of per-fold validation
        accuracies, their mean and their standard deviation.
    """
    if model_params is None:
        model_params    = {'in_channels': PCA_COMPONENTS, 'n_classes': 10}
    if training_params is None:
        training_params = {'epochs': 20, 'lr': 1e-3}

    # LOAD ALL SUBJECTS
    X, y, groups, first_epochs, _ = load_all_subjects_data(data_path, False, "localizer")
    # Count the subjects that were actually loaded rather than re-listing the directory.
    n_subjects = len(set(groups.tolist()))
    print(f"Total trials: {len(y)} | Subjects: {n_subjects} | X dimension: {X.shape}")

    # --- GroupKFold CV ---
    gkf = GroupKFold(n_splits=n_splits)
    fold_scores = []

    for fold_idx, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
        print(f"\nFold {fold_idx+1}/{n_splits}")

        ## Create train and validation datasets and loaders
        X_train, y_train = X[train_idx], y[train_idx]
        X_val,   y_val   = X[val_idx],   y[val_idx]
        train_data   = PCADataset(X_train, y_train, augment=True)
        # Validation data reuses the scaler/PCA fit on the training fold (no leakage).
        val_data     = PCADataset(X_val, y_val, train_data.scaler, train_data.pca, augment=False)
        train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn)
        val_loader   = DataLoader(val_data,   batch_size=32, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

        ## Trains model
        save_path = f"{save_dir}/best_model_fold_{fold_idx+1}.pth"
        model     = Model(**model_params).to(device)
        init_model_weights(model)
        history   = train_evaluate(model, train_loader, val_loader, save_path, **training_params, device=device)

        # Evaluates the best checkpoint of this fold, not the last-epoch weights.
        model = Model(**model_params).to(device)
        # map_location keeps the load working when `device` differs from the
        # device the tensors were saved from.
        model.load_state_dict(torch.load(save_path, map_location=device))
        val_preds = predict(model, val_loader, device=device)
        val_acc   = (val_preds==y_val).mean()
        fold_scores.append(val_acc)
        print(f"Fold {fold_idx+1} Validation Accuracy: {val_acc:.4f}\n\n")

    mean_acc, std_acc = np.mean(fold_scores), np.std(fold_scores)
    print(f"\nMean Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    return fold_scores, mean_acc, std_acc

In [10]:
# cross_validate() accepts dataset_params but does not use it, so None is passed below.
# Model hyper-parameters: input channels must match the number of PCA components.
model_params    = {'in_channels':PCA_COMPONENTS, 'n_classes':10}
# weight_decay and early-stopping patience are not set here, so
# train_evaluate() falls back to its defaults for them.
training_params = {'epochs':EPOCHS, 'lr':LEARNING_RATE}

# Cross-validation: trains one model per fold and writes each fold's best
# checkpoint to /kaggle/working/best_model_fold_<k>.pth (read by the submission cell).
cv_scores, mean_acc, std_acc = cross_validate(
    '/kaggle/input/the-imagine-decoding-challenge/train/train', 
    n_splits=5, 
    dataset_params=None,
    model_params=model_params,
    training_params=training_params,
    device='cuda'
)

Total trials: 7200 | Subjects: 15 | X dimension: (7200, 309, 121)

Fold 1/5
Epoch 1/50 |  Train Loss: 2.6514 |  Val Loss: 2.4004 |  Train Acc: 0.102 |  Val Acc: 0.100
Epoch 2/50 |  Train Loss: 2.4425 |  Val Loss: 2.3380 |  Train Acc: 0.105 |  Val Acc: 0.101
Early Stopping Counter: 1/10
Epoch 3/50 |  Train Loss: 2.3766 |  Val Loss: 2.3139 |  Train Acc: 0.109 |  Val Acc: 0.093
Early Stopping Counter: 2/10
Epoch 4/50 |  Train Loss: 2.3525 |  Val Loss: 2.3048 |  Train Acc: 0.119 |  Val Acc: 0.099
Early Stopping Counter: 3/10
Epoch 5/50 |  Train Loss: 2.3211 |  Val Loss: 2.3039 |  Train Acc: 0.126 |  Val Acc: 0.114
Epoch 6/50 |  Train Loss: 2.3131 |  Val Loss: 2.3103 |  Train Acc: 0.125 |  Val Acc: 0.108
Early Stopping Counter: 1/10
Epoch 7/50 |  Train Loss: 2.2903 |  Val Loss: 2.3137 |  Train Acc: 0.143 |  Val Acc: 0.110
Early Stopping Counter: 2/10
Epoch 8/50 |  Train Loss: 2.2813 |  Val Loss: 2.3054 |  Train Acc: 0.143 |  Val Acc: 0.110
Early Stopping Counter: 3/10
Epoch 9/50 |  Train Lo

# Submission

In [11]:
 X, y, _, _, label_map = load_all_subjects_data("/kaggle/input/the-imagine-decoding-challenge/train/train", True, "imagine")
# NOTE(review): the line above starts with a stray space; presumably unintended.
# It appears to run only because IPython strips leading indentation from a cell - confirm and remove.
#
# Builds a dataset over all 'imagine' training epochs. Only its fitted scaler and
# PCA (and label_map) are used by the submission cell below.
# NOTE(review): cross_validate() fits the scaler/PCA on 'localizer' training folds,
# and the fold checkpoints were trained on those features, whereas the scaler/PCA
# fit here come from 'imagine' data - confirm this mismatch is intended.
train_data   = PCADataset(X, y)

In [12]:
test_dir = "/kaggle/input/the-imagine-decoding-challenge/test/test"
# Sorted so the row order of the submission does not depend on the arbitrary
# order in which the filesystem lists the subject folders.
test_subjects = sorted(os.listdir(test_dir))
results = []

# label_map maps event name -> class index; invert it to name the predictions.
inverse_map = {v: k for k, v in label_map.items()}

# Load the five fold checkpoints once. They do not depend on the subject, so
# reloading them from disk inside the subject loop was redundant work.
fold_models = []
for k in range(1, 6):
    fold_path = f"/kaggle/working/best_model_fold_{k}.pth"
    model = Model(**model_params).to("cuda")
    model.load_state_dict(torch.load(fold_path, map_location="cuda"))
    model.eval()
    fold_models.append(model)

for subject in test_subjects:
    # Load test data and apply train PCA + scaler
    # NOTE(review): train_data's scaler/PCA were fit on the 'imagine' training
    # data, while each fold model was trained on features from a scaler/PCA fit
    # on that fold's 'localizer' training split - confirm this is intended.
    X_test, _, _, _ = load_subject_data(test_dir, subject, False, 'imagine')
    # Labels are unknown for the test set; zeros are placeholders the dataset requires.
    test_data = PCADataset(X_test, np.zeros(len(X_test)), scaler=train_data.scaler, pca=train_data.pca, augment=False)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=4,
                             worker_init_fn=worker_init_fn)

    # Collect softmax probabilities from all folds
    probs_all = []
    with torch.inference_mode():
        for model in fold_models:
            fold_probs = [torch.softmax(model(X_batch.to("cuda")), dim=1).cpu()
                          for X_batch, _ in test_loader]
            probs_all.append(torch.cat(fold_probs, dim=0).numpy())

    # Average probabilities across folds
    probs_all = np.array(probs_all)            # shape: (n_folds, n_trials, n_classes)
    avg_probs = np.mean(probs_all, axis=0)     # shape: (n_trials, n_classes)

    # Take argmax to get final predicted class
    ensemble_preds = np.argmax(avg_probs, axis=1)
    pred_labels = [inverse_map[pred] for pred in ensemble_preds]

    # Trial ids are 1-based within each subject.
    for i, label in enumerate(pred_labels, start=1):
        results.append({'ID': f"{subject}_{i}", 'label': label})

submission_df = pd.DataFrame(results)
submission_df.to_csv("submission.csv", index=False)
print(f"Saved submission with {len(submission_df)} rows.")


Saved submission with 680 rows.
