In [1]:
# =======================================================
# 1) Setup and imports
# =======================================================
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from scipy.signal import butter, lfilter

torch.manual_seed(42)
np.random.seed(42)

# =======================================================
# 2) Config
# =======================================================
BASE_PATH = '/kaggle/input/mtcaic3'
CHANNELS = ['FZ', 'C3', 'CZ', 'C4']
FS = 250
MU_BAND = (8, 13)
BETA_BAND = (13, 30)
MAX_LEN = 2250  # pad/truncate to this length

# =======================================================
# 3) Bandpass filter
# =======================================================
def butter_bandpass(lowcut, highcut, fs, order=4):
    nyq = 0.5 * fs
    low, high = lowcut / nyq, highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def bandpass_filter(data, lowcut, highcut, fs):
    b, a = butter_bandpass(lowcut, highcut, fs)
    return lfilter(b, a, data)

# =======================================================
# 4) EEG Dataset with light augmentation
# =======================================================
class EEGDataset(Dataset):
    def __init__(self, df, base_path, le, augment=False):
        self.df = df.reset_index(drop=True)
        self.base_path = base_path
        self.le = le
        self.augment = augment
        if 'label' in df:
            self.labels = self.le.transform(df['label'])
        else:
            self.labels = None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        dataset = 'train' if row['id'] <= 4800 else 'validation' if row['id'] <= 4900 else 'test'
        eeg_path = f"{self.base_path}/{row['task']}/{dataset}/{row['subject_id']}/{row['trial_session']}/EEGdata.csv"
        eeg_data = pd.read_csv(eeg_path)

        samples_per_trial = 2250 if row['task'] == 'MI' else 1750
        start = (row['trial'] - 1) * samples_per_trial
        end = start + samples_per_trial

        trial = eeg_data.iloc[start:end][CHANNELS].values.T  # (4, T)

        # Pad or truncate
        if trial.shape[1] < MAX_LEN:
            pad_width = MAX_LEN - trial.shape[1]
            trial = np.pad(trial, ((0, 0), (0, pad_width)), mode='constant')
        elif trial.shape[1] > MAX_LEN:
            trial = trial[:, :MAX_LEN]

        # Augment (if training)
        if self.augment:
            # Add Gaussian noise
            noise = np.random.normal(0, 0.01, trial.shape)
            trial = trial + noise
            # Random scaling
            scale = np.random.uniform(0.9, 1.1)
            trial = trial * scale

        # Filter bands
        mu = bandpass_filter(trial, MU_BAND[0], MU_BAND[1], FS)
        beta = bandpass_filter(trial, BETA_BAND[0], BETA_BAND[1], FS)

        # Tabular features
        means = np.mean(trial, axis=1)
        diff_c3_c4 = means[1] - means[3]
        mu_power = np.mean(mu ** 2, axis=1)
        beta_power = np.mean(beta ** 2, axis=1)
        tabular = np.concatenate([means, [diff_c3_c4], mu_power, beta_power]).astype(np.float32)

        waveform = torch.tensor(trial, dtype=torch.float32)
        tabular = torch.tensor(tabular, dtype=torch.float32)

        if self.labels is not None:
            label = torch.tensor(self.labels[idx], dtype=torch.long)
            return waveform, tabular, label
        else:
            return waveform, tabular

# =======================================================
# 5) Improved CNN + MLP Model
# =======================================================
class CNN_MLP(nn.Module):
    def __init__(self, in_channels, tabular_dim, num_classes):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 32, 7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, 7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        self.mlp = nn.Sequential(
            nn.Linear(tabular_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.final = nn.Linear(128 + 16, num_classes)

    def forward(self, x_wave, x_tab):
        x1 = self.cnn(x_wave)
        x2 = self.mlp(x_tab)
        x = torch.cat([x1, x2], dim=1)
        return self.final(x)

# =======================================================
# 6) Prepare Data
# =======================================================
train_df = pd.read_csv(f"{BASE_PATH}/train.csv")
val_df = pd.read_csv(f"{BASE_PATH}/validation.csv")

# Filter MI task only
train_df = train_df[train_df['task'] == 'MI'].copy()
val_df = val_df[val_df['task'] == 'MI'].copy()

print("Train MI shape:", train_df.shape)
print("Val MI shape:", val_df.shape)

# Shared LabelEncoder
all_labels = pd.concat([train_df['label'], val_df['label']])
le = LabelEncoder()
le.fit(all_labels)

print("Classes:", le.classes_)

train_set = EEGDataset(train_df, BASE_PATH, le, augment=True)
val_set = EEGDataset(val_df, BASE_PATH, le, augment=False)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=2)

# =======================================================
# 7) Training Setup
# =======================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN_MLP(
    in_channels=len(CHANNELS),
    tabular_dim=4 + 1 + 4 + 4,  # means + diff + mu_power + beta_power
    num_classes=len(le.classes_)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
best_val_f1 = 0
patience = 15
patience_counter = 0

history = {'train_loss': [], 'val_loss': [], 'val_f1': []}

# =======================================================
# 8) Train with Early Stopping
# =======================================================
for epoch in range(100):
    model.train()
    train_loss = 0.0
    for x_wave, x_tab, y in train_loader:
        x_wave, x_tab, y = x_wave.to(device), x_tab.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x_wave, x_tab)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * y.size(0)

    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    y_true, y_pred = [], []
    with torch.no_grad():
        for x_wave, x_tab, y in val_loader:
            x_wave, x_tab, y = x_wave.to(device), x_tab.to(device), y.to(device)
            out = model(x_wave, x_tab)
            loss = criterion(out, y)
            val_loss += loss.item() * y.size(0)
            y_true.extend(y.cpu().numpy())
            y_pred.extend(out.argmax(dim=1).cpu().numpy())

    val_loss /= len(val_loader.dataset)
    val_f1 = f1_score(y_true, y_pred, average='macro')

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val F1={val_f1:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("✅ Early stopping triggered!")
            break

# =======================================================
# 9) Save logs
# =======================================================
np.savez('training_log.npz', **history)
print("✅ Training log saved: training_log.npz")


Train MI shape: (2400, 6)
Val MI shape: (50, 6)
Classes: ['Left' 'Right']
Epoch 1: Train Loss=357825.9558 | Val Loss=71288.6909 | Val F1=0.5895
Epoch 2: Train Loss=262504.0104 | Val Loss=62985.7613 | Val F1=0.5659
Epoch 3: Train Loss=198695.3972 | Val Loss=55712.9608 | Val F1=0.5895
Epoch 4: Train Loss=158962.3735 | Val Loss=58957.6112 | Val F1=0.4907
Epoch 5: Train Loss=129660.7802 | Val Loss=58120.9601 | Val F1=0.4652
Epoch 6: Train Loss=98752.0924 | Val Loss=45285.6936 | Val F1=0.3333
Epoch 7: Train Loss=70992.4940 | Val Loss=33537.6543 | Val F1=0.2958
Epoch 8: Train Loss=57662.9590 | Val Loss=26534.2320 | Val F1=0.4325
Epoch 9: Train Loss=48592.3281 | Val Loss=22395.1078 | Val F1=0.4167
Epoch 10: Train Loss=37619.1287 | Val Loss=16054.8467 | Val F1=0.4492
Epoch 11: Train Loss=28833.6121 | Val Loss=10973.7661 | Val F1=0.4945
Epoch 12: Train Loss=21679.8944 | Val Loss=7806.3937 | Val F1=0.4945
Epoch 13: Train Loss=15229.2676 | Val Loss=5468.4406 | Val F1=0.5098
Epoch 14: Train Loss=1

In [1]:
# =======================================================
# 1) Setup and imports
# =======================================================
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, confusion_matrix, ConfusionMatrixDisplay
from scipy.signal import butter, lfilter
import matplotlib.pyplot as plt
from mne.decoding import CSP

torch.manual_seed(42)
np.random.seed(42)

# =======================================================
# 2) Config
# =======================================================
BASE_PATH = '/kaggle/input/mtcaic3'
CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
FS = 250
MAX_LEN = 2250  # pad/truncate to this length

# =======================================================
# 3) Bandpass filter
# =======================================================
def butter_bandpass(lowcut, highcut, fs, order=4):
    nyq = 0.5 * fs
    low, high = lowcut / nyq, highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def bandpass_filter(data, lowcut, highcut, fs):
    b, a = butter_bandpass(lowcut, highcut, fs)
    return lfilter(b, a, data)

# =======================================================
# 4) EEG Dataset with CSP and light augmentation
# =======================================================
class EEGDataset(Dataset):
    def __init__(self, df, base_path, le, csp=None, augment=False):
        self.df = df.reset_index(drop=True)
        self.base_path = base_path
        self.le = le
        self.csp = csp
        self.augment = augment
        if 'label' in df:
            self.labels = self.le.transform(df['label'])
        else:
            self.labels = None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        dataset = 'train' if row['id'] <= 4800 else 'validation' if row['id'] <= 4900 else 'test'
        eeg_path = f"{self.base_path}/{row['task']}/{dataset}/{row['subject_id']}/{row['trial_session']}/EEGdata.csv"
        eeg_data = pd.read_csv(eeg_path)

        samples_per_trial = 2250 if row['task'] == 'MI' else 1750
        start = (row['trial'] - 1) * samples_per_trial
        end = start + samples_per_trial

        trial = eeg_data.iloc[start:end][CHANNELS].values.T  # (C, T)

        if trial.shape[1] < MAX_LEN:
            trial = np.pad(trial, ((0, 0), (0, MAX_LEN - trial.shape[1])), mode='constant')
        elif trial.shape[1] > MAX_LEN:
            trial = trial[:, :MAX_LEN]

        if self.augment:
            noise = np.random.normal(0, 0.01, trial.shape)
            trial += noise
            scale = np.random.uniform(0.9, 1.1)
            trial *= scale

        if self.csp is not None:
            trial = self.csp.transform(trial.T[np.newaxis, :, :])[0].T  # (CSP_components, T)

        # Simple tabular: means + diff
        means = np.mean(trial, axis=1)
        diff_c3_c4 = means[1] - means[3] if len(means) >= 4 else 0.0
        tabular = np.concatenate([means, [diff_c3_c4]]).astype(np.float32)

        waveform = torch.tensor(trial, dtype=torch.float32)
        tabular = torch.tensor(tabular, dtype=torch.float32)

        if self.labels is not None:
            label = torch.tensor(self.labels[idx], dtype=torch.long)
            return waveform, tabular, label
        else:
            return waveform, tabular

# =======================================================
# 5) Improved CNN + MLP Model
# =======================================================
class CNN_MLP(nn.Module):
    def __init__(self, in_channels, tabular_dim, num_classes):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(in_channels, 32, 7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, 7, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 7, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        self.mlp = nn.Sequential(
            nn.Linear(tabular_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.final = nn.Linear(128 + 16, num_classes)

    def forward(self, x_wave, x_tab):
        x1 = self.cnn(x_wave)
        x2 = self.mlp(x_tab)
        x = torch.cat([x1, x2], dim=1)
        return self.final(x)

# =======================================================
# 6) Prepare Data
# =======================================================
train_df = pd.read_csv(f"{BASE_PATH}/train.csv")
val_df = pd.read_csv(f"{BASE_PATH}/validation.csv")

train_df = train_df[train_df['task'] == 'MI'].copy()
val_df = val_df[val_df['task'] == 'MI'].copy()

print("Train MI shape:", train_df.shape)
print("Val MI shape:", val_df.shape)

# Shared LabelEncoder
all_labels = pd.concat([train_df['label'], val_df['label']])
le = LabelEncoder()
le.fit(all_labels)

print("Classes:", le.classes_)

# =======================================================
# 7) Fit CSP on training data
# =======================================================
print("Fitting CSP...")

X_waveforms = []
y_labels = []
for i in range(len(train_df)):
    row = train_df.iloc[i]
    eeg_path = f"{BASE_PATH}/{row['task']}/train/{row['subject_id']}/{row['trial_session']}/EEGdata.csv"
    eeg_data = pd.read_csv(eeg_path)
    samples_per_trial = 2250
    start = (row['trial'] - 1) * samples_per_trial
    end = start + samples_per_trial
    trial = eeg_data.iloc[start:end][CHANNELS].values.T
    if trial.shape[1] < MAX_LEN:
        trial = np.pad(trial, ((0, 0), (0, MAX_LEN - trial.shape[1])))
    elif trial.shape[1] > MAX_LEN:
        trial = trial[:, :MAX_LEN]
    X_waveforms.append(trial)
    y_labels.append(row['label'])

X_waveforms = np.stack(X_waveforms, axis=0)  # (N, C, T)
X_waveforms = np.transpose(X_waveforms, (0, 2, 1))  # (N, T, C)
y_labels = le.transform(y_labels)

csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)
csp.fit(X_waveforms, y_labels)

# =======================================================
# 8) Create Datasets and Loaders
# =======================================================
train_set = EEGDataset(train_df, BASE_PATH, le, csp=csp, augment=True)
val_set = EEGDataset(val_df, BASE_PATH, le, csp=csp, augment=False)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=2)

# =======================================================
# 9) Training Setup
# =======================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN_MLP(
    in_channels=4,  # CSP components
    tabular_dim=4 + 1,  # means + diff
    num_classes=len(le.classes_)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
best_val_f1 = 0
patience = 15
patience_counter = 0

history = {'train_loss': [], 'val_loss': [], 'val_f1': []}

# =======================================================
# 10) Train with Early Stopping
# =======================================================
for epoch in range(100):
    model.train()
    train_loss = 0.0
    for x_wave, x_tab, y in train_loader:
        x_wave, x_tab, y = x_wave.to(device), x_tab.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x_wave, x_tab)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * y.size(0)

    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    y_true, y_pred = [], []
    with torch.no_grad():
        for x_wave, x_tab, y in val_loader:
            x_wave, x_tab, y = x_wave.to(device), x_tab.to(device), y.to(device)
            out = model(x_wave, x_tab)
            loss = criterion(out, y)
            val_loss += loss.item() * y.size(0)
            y_true.extend(y.cpu().numpy())
            y_pred.extend(out.argmax(dim=1).cpu().numpy())

    val_loss /= len(val_loader.dataset)
    val_f1 = f1_score(y_true, y_pred, average='macro')

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_f1'].append(val_f1)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f} | Val F1={val_f1:.4f}")

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("✅ Early stopping triggered!")
            break

# =======================================================
# 11) Save logs
# =======================================================
np.savez('training_log.npz', **history)
print("✅ Training log saved: training_log.npz")

# =======================================================
# 12) Confusion Matrix
# =======================================================
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

all_true, all_pred = [], []
with torch.no_grad():
    for x_wave, x_tab, y in val_loader:
        x_wave, x_tab = x_wave.to(device), x_tab.to(device)
        out = model(x_wave, x_tab)
        all_true.extend(y.numpy())
        all_pred.extend(out.argmax(dim=1).cpu().numpy())

cm = confusion_matrix(all_true, all_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_)
disp.plot(cmap=plt.cm.Blues)
plt.title("Validation Confusion Matrix")
plt.show()


Train MI shape: (2400, 6)
Val MI shape: (50, 6)
Classes: ['Left' 'Right']
Fitting CSP...
Computing rank from data with rank=None
    Using tolerance 1e+09 (2.2e-16 eps * 2250 dim * 2.1e+21  max singular value)
    Estimated rank (data): 2250
    data: rank 2250 computed from 2250 data channels with 0 projectors
Reducing data rank from 2250 -> 2250
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Done.


AxisError: Caught AxisError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_35/4202493367.py", line 86, in __getitem__
    means = np.mean(trial, axis=1)
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py", line 3504, in mean
    return _methods._mean(a, axis=axis, dtype=dtype,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py", line 106, in _mean
    rcount = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py", line 77, in _count_reduce_items
    items *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
numpy.exceptions.AxisError: axis 1 is out of bounds for array of dimension 1
