# Prepare the folder

In [77]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# If the drive is already mounted the call only prints a notice; pass
# force_remount=True to drive.mount to remount it.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [78]:
!git clone https://github.com/atomiiw/EEG-Model-Fine-tune.git

fatal: destination path 'EEG-Model-Fine-tune' already exists and is not an empty directory.


In [40]:
%cd EEG-Model-Fine-tune/MIRepNet

/content/EEG-Model-Fine-tune/MIRepNet


In [None]:
!pip install -r MIRepNet/requirements.txt

# Baseline Performance: Before Fine-tuning
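Current output: one of 4 classes, {0, 1, 2, 3} (the pretrained classification head)   
Expected output: one of 9 classes, {0, 1, ..., 7, 8}   
Current accuracy: 8%–15% (varies from run to run)   
Accuracy from random guessing over 9 classes: 1/9 ≈ 11%     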

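Why does the accuracy differ on every run?  
The loader reports 'Loaded 108/110 parameters from pretrained model'.   
The 2 parameters that were not loaded belong to the final layer, so they are randomly initialized each time the model is built.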


In [41]:
# `os` is not imported anywhere earlier in the notebook, so import it here;
# otherwise this cell raises NameError on a fresh kernel.
import os

# Sanity check: the relative data/ and weight/ paths used below assume the
# working directory is the MIRepNet project folder.
print("Working dir:", os.getcwd())


Working dir: /content/EEG-Model-Fine-tune/MIRepNet


In [42]:
import torch
import numpy as np
from model.mlm import mlm_mask
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# ---- Settings ----
DATASET_NAME = "basic"
WEIGHT_PATH = "weight/MIRepNet.pth"   # checkpoint pretrained with a 4-class head
BATCH_SIZE = 32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Read the test split from disk ----
X = np.load(f'data/{DATASET_NAME}/X_test.npy')       # trials, shape (N, 128, 200)
y = np.load(f'data/{DATASET_NAME}/labels_test.npy')  # integer labels, shape (N,)
print("Loaded data:", X.shape, y.shape)

# ---- Wrap the arrays so they can be batched ----
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
# Fixed order: this is evaluation only, so no shuffling.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# ---- Build the network with the pretrained weights ----
# n_classes stays at 4 on purpose: this cell measures the untouched pretrained
# model, even though the labels in this dataset span more classes.
model = mlm_mask(
    emb_size=256,
    depth=6,
    n_classes=4,
    pretrainmode=False,
    pretrain=WEIGHT_PATH,
)
model = model.to(DEVICE)
model.eval()

# ---- Count exact label matches over the whole split ----
correct, total = 0, 0

with torch.no_grad():
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        # The model returns a pair; only the second element (class scores) is used.
        _unused, scores = model(batch_x)
        # With a 4-way head the predictions can only fall in {0, 1, 2, 3}.
        predicted = scores.argmax(dim=1)
        correct += int(predicted.eq(batch_y).sum())
        total += batch_y.shape[0]

accuracy = correct / total * 100
print(f"\n✅ Raw pretrained MIRepNet accuracy on your dataset: {accuracy:.2f}%")
print(f"Correct: {correct} / {total}")

Loaded data: (144, 128, 200) (144,)
Loaded 108/110 parameters from pretrained model

✅ Raw pretrained MIRepNet accuracy on your dataset: 11.81%
Correct: 17 / 144


# Train across all patients

In [43]:
%cd MIRepNet

[Errno 2] No such file or directory: 'MIRepNet'
/content/EEG-Model-Fine-tune/MIRepNet


In [31]:
# Fine-tune MIRepNet on the "basic" dataset with a 9-class output for 10 epochs.
# NOTE(review): the log of this run scores 877 of the 1096 trials, so
# --val_split 0.8 appears to hold out 80% for validation and leave only ~20%
# for training; the per-patient run uses 0.2. Confirm 0.8 is intended.
!python finetune.py --dataset_name basic --model_name MIRepNet --num_classes 9 --val_split 0.8 --epochs 10


Starting EEG Classification with Configurable Hyperparameters

original data shape: (1096, 128, 200) labels shape: (1096,)
preprocessed data shape: (1096, 128, 200) preprocessed labels shape: (1096,)
Loaded 108/110 parameters from pretrained model
Seed: 666, Subject: 0

Predicted: [5 5 5 7 5 5 5 5 7 5 5 7 5 5 5 5 5 5 5 5 5 5 5 7 7 5 5 5 5 5 5 7]
Actual:    [5 0 6 1 3 5 3 0 1 1 8 1 8 2 6 7 2 8 2 4 6 7 5 4 7 5 2 3 7 7 8 3]
Got 102 out of 877 correct. Accuracy: 11.630558722919043.
Predicted: [7 6 3 7 3 3 3 3 7 3 1 7 5 7 8 3 3 5 3 3 3 3 3 2 8 8 3 5 3 3 3 3]
Actual:    [5 0 6 1 3 5 3 0 1 1 8 1 8 2 6 7 2 8 2 4 6 7 5 4 7 5 2 3 7 7 8 3]
Got 126 out of 877 correct. Accuracy: 14.367160775370582.
Predicted: [3 3 3 3 3 3 3 3 3 3 5 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]
Actual:    [5 0 6 1 3 5 3 0 1 1 8 1 8 2 6 7 2 8 2 4 6 7 5 4 7 5 2 3 7 7 8 3]
Got 127 out of 877 correct. Accuracy: 14.481185860889395.
Predicted: [7 3 7 7 7 7 3 5 7 7 7 3 7 7 3 7 7 5 3 7 7 7 7 3 7 3 3 7 5 3 3 3]
Actual:    [5 0 

# Train on each patient individually




In [139]:
# Fine-tune MIRepNet on the "S14_testing" dataset with a 9-class output for
# 5 epochs; the log of this run scores 258 of the 1290 trials (--val_split 0.2).
# The run also reports saving the Euclidean Alignment matrix to
# ./weight/S14_testing_EA_matrix.npy, which the evaluation cell below reloads.
!python finetune.py --dataset_name S14_testing --model_name MIRepNet --num_classes 9 --val_split 0.2 --epochs 5


Starting EEG Classification with Configurable Hyperparameters

original data shape: (1290, 128, 200) labels shape: (1290,)
preprocessed data shape: (1290, 128, 200) preprocessed labels shape: (1290,)
🔧 Using local process_and_replace_loader from finetune.py
🔧 Using local process_and_replace_loader from finetune.py
✅ Saved EA matrix to ./weight/S14_testing_EA_matrix.npy
Loaded 108/110 parameters from pretrained model
Seed: 666, Subject: 0

Predicted: [3 0 0 7 8 0 7 0 7 6 0 6 0 6 8 6 2 4 6 7 3 4 8 5 8 0 4 5 0 8 0 5]
Actual:    [3 8 0 2 2 7 7 0 2 6 8 6 0 6 2 6 8 4 6 2 3 4 8 5 2 8 4 6 6 8 0 5]
Got 164 out of 258 correct. Accuracy: 63.565891472868216.
Predicted: [3 8 0 2 2 8 7 0 2 6 8 6 0 6 2 6 8 4 6 2 3 4 8 5 2 8 4 5 6 8 0 5]
Actual:    [3 8 0 2 2 7 7 0 2 6 8 6 0 6 2 6 8 4 6 2 3 4 8 5 2 8 4 6 6 8 0 5]
Got 227 out of 258 correct. Accuracy: 87.98449612403101.
Predicted: [3 8 0 2 2 7 7 0 2 6 8 6 0 6 2 6 8 4 6 2 3 4 8 5 2 8 4 6 6 8 0 5]
Actual:    [3 8 0 2 2 7 7 0 2 6 8 6 0 6 2 6 8 4 6 2 3 4 8

# Test on new data

In [140]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from scipy.linalg import fractional_matrix_power
from model.mlm import mlm_mask
from utils.channel_list import use_channels_names, channel_positions
from scipy.spatial.distance import cdist

# ---- Settings ----
DATASET_NAME = "S14_testing"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Trials and labels ----
# NOTE(review): this dataset has the same name and the same number of trials
# (1290) as the one passed to finetune.py above; confirm these files are really
# held-out data, otherwise the accuracy reported below includes training trials.
X = np.load(f"data/{DATASET_NAME}/X.npy")  # trials, shape (N, 128, 200)
y = np.load(f"data/{DATASET_NAME}/labels.npy")

# ---- Step 1: Euclidean Alignment ----
# Left-multiply every trial by the reference matrix stored under weight/
# (matrix on the left, trial not transposed).
refEA = np.load(f"weight/{DATASET_NAME}_EA_matrix.npy")
X_ea = np.zeros_like(X)
for trial_idx, trial in enumerate(X):
    X_ea[trial_idx] = refEA @ trial

# ==== 2️⃣ Channel interpolation to 45 channels ====
def pad_missing_channels_diff(x, target_channels, actual_channels):
    """Map a batch of trials onto ``target_channels`` through a linear mixing matrix.

    A target channel whose name appears in ``actual_channels`` is copied from
    the source row at that name's (first) index. Any other target channel is
    an inverse-distance weighted average of all source rows, with distances
    taken between the entries of the module-level ``channel_positions`` table.

    Args:
        x: array of shape (batch, channels, samples).
        target_channels: channel names of the desired output layout.
        actual_channels: channel names describing the rows of ``x``.

    Returns:
        Float array of shape (batch, len(target_channels), samples).

    Raises:
        KeyError: if any name in either list is missing from ``channel_positions``.
    """
    n_trials, n_src, n_samples = x.shape
    src_pos = np.array([channel_positions[name] for name in actual_channels])
    dst_pos = np.array([channel_positions[name] for name in target_channels])

    # Column of the first occurrence of every source channel name.
    first_seen = {}
    for col, name in enumerate(actual_channels):
        first_seen.setdefault(name, col)

    n_dst = len(target_channels)
    mixing = np.zeros((n_dst, n_src))
    for row, name in enumerate(target_channels):
        col = first_seen.get(name)
        if col is not None:
            # Channel is present in the source: copy it unchanged.
            mixing[row, col] = 1.0
            continue
        # Channel is absent: weight every source channel by 1/distance
        # (the epsilon avoids division by zero) and normalise to sum to one.
        gaps = cdist([dst_pos[row]], src_pos)[0]
        inverse = 1 / (gaps + 1e-6)
        mixing[row] = inverse / inverse.sum()

    padded = np.zeros((n_trials, n_dst, n_samples))
    for k, trial in enumerate(x):
        padded[k] = mixing @ trial
    return padded

# ---- Step 2: project the aligned trials onto the model's channel layout ----
# NOTE(review): use_channels_names is passed as both the target and the actual
# layout, so every target name is found in the source list and nothing is
# interpolated; as written this appears to pick the first
# len(use_channels_names) rows of the 128-row input. Confirm this matches the
# preprocessing used during fine-tuning.
X_final = pad_missing_channels_diff(X_ea, use_channels_names, use_channels_names)
print("Shape after projection:", X_final.shape)  # (N, 45, 200) in the recorded run

# ---- Step 3: wrap the arrays so they can be batched ----
X_tensor = torch.tensor(X_final, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
eval_set = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(eval_set, batch_size=32)

# ---- Step 4: rebuild the 9-class network and restore the fine-tuned weights ----
model = mlm_mask(emb_size=256, depth=6, n_classes=9, pretrainmode=False).to(DEVICE)
finetuned_state = torch.load(
    f"weight/{DATASET_NAME}_MIRepNet_finetuned.pth",
    map_location=DEVICE,
)
model.load_state_dict(finetuned_state)
model.eval()

# ---- Step 5: count exact label matches over every batch ----
correct, total = 0, 0
with torch.no_grad():
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        # The model returns a pair; only the second element (class scores) is used.
        _unused, scores = model(batch_x)
        predicted = scores.argmax(dim=1)
        correct += int(predicted.eq(batch_y).sum())
        total += batch_y.shape[0]

acc = correct / total * 100
print(f"✅ Test accuracy on {DATASET_NAME}: {acc:.2f}% ({correct}/{total})")

Shape after projection: (1290, 45, 200)
✅ Test accuracy on S14_testing: 99.22% (1280/1290)
