# Finetuning Notebook for NeuroBench

## Todo

- size statistics on train/val/test jsons

### Imports

In [1]:
import json
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.preprocessing import StandardScaler
import os
import pickle
from tqdm import tqdm
import lightning as L
import torch.nn as nn
from torch.utils.data import random_split
from lightning.pytorch.callbacks import ModelCheckpoint
import sys
import random

sys.path.append("/home/maxihuber/eeg-foundation/")
L.seed_everything(42)

[rank: 0] Seed set to 42


42

### Shorten Jsons for Debugging

In [None]:
def shorten_json(data_index_path, store_path, num_samples):
    with open(data_index_path, 'r') as f:
        index_dict = json.load(f)
    short_index_dict = {}
    task_name = list(index_dict.keys())[0]
    short_index_dict[task_name] = [sample for sample in index_dict[task_name][:num_samples]]
    short_index_dict["task_type"] = index_dict["task_type"]
    with open(store_path, 'w') as file:
        json.dump(short_index_dict, file)
        print(f"Stored to {store_path}")

shorten_json("/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_train.json",
             "/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_train_small.json",
             50)

shorten_json("/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_val.json",
             "/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_val_small.json",
             10)

shorten_json("/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_test.json",
             "/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_test_small.json",
             20)

### Adjust pathnames from kard -> maxihuber

In [None]:
def adjust_json_paths(data_index_path, store_path):
    with open(data_index_path, 'r') as f:
        index_dict = json.load(f)
    
    index_dict_new = {}
    task_name = list(index_dict.keys())[0]
    index_dict_new[task_name] = []  # Initialize the list for the task
    
    for sample in index_dict[task_name]:
        new_input = []
        for file_path in sample["input"]:
            # Remove the /itet-stor/kard path prefix and replace with /itet-stor/maxihuber
            modified_path = file_path.replace("/itet-stor/kard", "/itet-stor/maxihuber")
            new_input.append(modified_path)
        sample["input"] = new_input
        index_dict_new[task_name].append(sample)
    
    index_dict_new["task_type"] = index_dict["task_type"]
    
    with open(store_path, 'w') as file:
        json.dump(index_dict_new, file)
        print(f"Stored to {store_path}")

adjust_json_paths("/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_train.json",
                  "/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_train.json")

adjust_json_paths("/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_val.json",
                  "/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_val.json")

adjust_json_paths("/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_test.json",
                  "/itet-stor/maxihuber/deepeye_storage/index_files/EEGEyeNet_Direction_test.json")

### Load Train/Val/Test Information

In [2]:
prefix_filepath = "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_prepared/"

class_name = "Clinical"
time_col = "Time in Seconds"
task_channels = set(['AF3', 'AF4', 'AF7', 'AF8', 'AFz', 
                     'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 
                     'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPz', 'Cz', 
                     'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 
                     'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6', 'FCz', 
                     'FT7', 'FT8', 'Fp1', 'Fp2', 'Fz', 'Mastoids', 
                     'O1', 'O2', 'Oz', 'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 
                     'PO3', 'PO4', 'PO7', 'PO8', 'POz', 'Pz', 'T7', 'T8', 'TP7', 'TP8', 
                     'Veog', 'X', 'Y', 'Z'])

age = {
    "task_name": "Age",
    "task_type": "Regression",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/age.json"
}

depression = {
    "task_name": "Depression",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/cli_depression.json"
}

parkinsons = {
    "task_name": "Parkinsons",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/cli_parkinsons.json"
}

schizophrenia = {
    "task_name": "Schizophrenia",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/cli_schizophrenia.json"
}

sex = {
    "task_name": "Sex",
    "task_type": "Classification",
    "json_path": "/itet-stor/maxihuber/deepeye_storage/foundation_clinical_tasks/sex.json"
}

In [3]:
json_path = age["json_path"]
task_type = age["task_type"]
task_name = age["task_name"]

def load_index(data_index_path):
    with open(data_index_path, 'r') as f:
        train_test_dict = json.load(f)
    train_samples = train_test_dict['train']
    test_samples = train_test_dict['test']
    return train_samples, test_samples

def load_file_data(data_index, task_channels):
    num_samples = 0
    data = {}
    outputs = {}
    srs = {}
    durs = {}
    channels = {}
    for sample in tqdm(data_index, desc="Loading data"):
        # Load and concatenate dataframe
        input_files = sample["input"]
        df = pd.DataFrame()
        for file in input_files:
            file = prefix_filepath + file
            with open(file, 'rb') as f:
                df_new = pickle.load(f)
                df = pd.concat([df, df_new], axis=0)
        start, length = int(sample["start"]), int(sample["length"])
        df = df.iloc[start:length, :]
        # Add metadata
        sr = int(1 / float(float(df[time_col].iloc[1]) - float(df[time_col].iloc[0])))
        outputs[num_samples] = sample["output"]
        srs[num_samples] = sr
        durs[num_samples] = len(df) / sr
        channels[num_samples] = list(set(df.columns) & task_channels)
        df = df[channels[num_samples]].astype(float)
        signals = torch.tensor(df.to_numpy(), dtype=torch.float32).T
        data[num_samples] = signals
        num_samples += 1
    return data, outputs, srs, durs, channels

train_index, test_index = load_index(json_path)

train_index = train_index[:10]
test_index = test_index[:5]

train_data, train_outputs, train_sr, train_dur, train_channels = load_file_data(train_index, task_channels)
test_data, test_outputs, test_sr, test_dur, test_channels = load_file_data(test_index, task_channels)

Loading data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [06:15<00:00, 37.54s/it]
Loading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:59<00:00, 47.98s/it]


### Preload Data into Run Memory

In [None]:
def get_file_to_channels(pretrain_index):
    with open(pretrain_index, 'r') as file:
        pretrain_index = json.load(file)
    file_to_channels = {}
    for ie in pretrain_index:
        file_to_channels[ie["path"]] = ie["good_channels"] if "good_channels" in ie else ie["channels"]
    return file_to_channels

def get_file_to_sr(pretrain_index):
    with open(pretrain_index, 'r') as file:
        pretrain_index = json.load(file)
    file_to_sr = {}
    for ie in pretrain_index:
        file_to_sr[ie["path"]] = ie["sr"]
    return file_to_sr



def load_file_data(data_index, file_to_channels, file_to_sr):
    num_samples = 0
    data = {}
    outputs = {}
    srs = {}
    durs = {}
    channels = {}
    for sample in tqdm(data_index, desc="Loading data"):
        input_files = sample["input"]
        df = pd.DataFrame()
        df_channels = []
        for file in input_files:
            with open(file, 'rb') as f:
                df_new = pickle.load(f)
                df = pd.concat([df, df_new], axis=0)
            #sr = file_to_sr[file]
            sr = 500
            #df_channels.append(file_to_channels[file])
        df = df.loc[sample["start"]:sample["start"]+sample["length"]-1, :]
        assert df.shape[0] == sample["length"], f"df.shape[0]={df.shape[0]}, sample['length']={sample['length']}"
        outputs[num_samples] = sample["output"]
        srs[num_samples] = sr
        durs[num_samples] = sample["length"] / sr
        #unique_channels = set(df_channels)
        unique_channels = set([f"E{i}" for i in range(1,129)] + ["AVG_REF"])
        channels[num_samples] = unique_channels
        df = df.loc[:, df.columns.intersection(unique_channels)]
        data[num_samples] = df
        num_samples += 1
    return data, outputs, srs, durs, channels

file_to_channels = get_file_to_channels(pretrain_index)
file_to_sr = get_file_to_sr(pretrain_index)

train_data, train_outputs, train_sr, train_dur, train_channels = load_file_data(train_index, file_to_channels, file_to_sr)
val_data, val_outputs, val_sr, val_dur, val_channels = load_file_data(val_index, file_to_channels, file_to_sr)
test_data, test_outputs, test_sr, test_dur, test_channels = load_file_data(test_index, file_to_channels, file_to_sr)

## EEG-Transformer

In [4]:
#########################################################################################################
# Label Encoder

from sklearn.preprocessing import LabelEncoder

# Collect all unique class labels from train and test outputs
all_outputs = list(train_outputs.values()) + list(test_outputs.values())
if isinstance(all_outputs[0], str):  # Only if outputs are strings (classification)
    label_encoder = LabelEncoder()
    label_encoder.fit(all_outputs)
else:
    label_encoder = None  # No encoding needed for regression

class FinetuneDataset(Dataset):
    def __init__(self, data, outputs, srs, durs, channels, task_type, label_encoder=None):
        self.data = data
        self.outputs = outputs
        self.srs = srs
        self.durs = durs
        self.channels = channels
        self.task_type = task_type
        self.label_encoder = label_encoder
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        signals = self.data[idx]
        output = self.outputs[idx]
        sr = self.srs[idx]
        dur = self.durs[idx]
        channels = self.channels[idx]

        if self.task_type == "Classification" and self.label_encoder is not None:
            output = self.label_encoder.transform([output])[0]  # Encode the output label
            output_tensor = torch.tensor(output, dtype=torch.long)
        else:
            output_tensor = torch.tensor([output], dtype=torch.float32)
        
        return {
            "signals": signals,
            "output": output_tensor,
            "sr": sr,
            "dur": dur,
            "channels": channels,
        }
        
full_train_dataset = FinetuneDataset(train_data, train_outputs, train_sr, train_dur, train_channels, task_type=task_type, label_encoder=label_encoder)
test_dataset = FinetuneDataset(test_data, test_outputs, test_sr, test_dur, test_channels, task_type=task_type, label_encoder=label_encoder)

# Define the split ratio
train_ratio = 0.8
val_ratio = 0.2

# Calculate lengths for train and validation sets
total_size = len(full_train_dataset)
train_size = int(train_ratio * total_size)
val_size = total_size - train_size

# Split the dataset
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

#########################################################################################################
# DataLoaders
import torchaudio
from src.data.transforms import (
    crop_spg,
    custom_fft,
    normalize_spg,
)

self_win_shifts = [.25, .5, 1, 2, 4, 8]
self_patch_size = 16
self_win_shift_factor = .25
self_max_win_shift = self_win_shifts[-1]
self_max_y_datapoints = 4_000

def get_nr_y_patches(win_size, sr):
    return int((sr / 2 * win_size + 1) / self_patch_size)

def get_nr_x_patches(win_size, dur):
    win_shift = win_size * self_win_shift_factor
    x_datapoints_per_second = 1 / win_shift
    x_datapoints = dur * x_datapoints_per_second + 1
    return int(x_datapoints // self_patch_size)

channel_name_map_path = '/home/maxihuber/eeg-foundation/src/data/components/channels_to_id.json'
with open(channel_name_map_path, "r") as file:
    self_channel_name_map = json.load(file)

def self_get_generic_channel_name(channel_name):
    channel_name = channel_name.lower()
    # Remove "eeg " prefix if present
    if channel_name.startswith("eeg "):
        channel_name = channel_name[4:]
    # Simplify names with a dash and check if it ends with "-"
    if "-" in channel_name:
        if channel_name.endswith("-"):
            return "None"
        return channel_name.split("-")[0]
    return channel_name

def self_encode_mean(mean, win_size):
    y_datapoints = mean.shape[0]
    encoded_mean = torch.zeros(self_max_y_datapoints)
    step_size = int(self_max_win_shift // win_size)
    end_idx = step_size * y_datapoints
    indices = torch.arange(0, end_idx, step_size)
    encoded_mean[indices] = mean.squeeze_().float()
    encoded_mean.unsqueeze_(1)
    return encoded_mean

#########################################################################################################
# collate_fn
def sample_collate_fn(batch):

    signals, output, sr, dur, channels = batch[0]["signals"], batch[0]["output"], batch[0]["sr"], batch[0]["dur"], batch[0]["channels"]

    if dur > 3_600:
        dur = 3_600
        signals = signals[:, :3_600*sr]
    
    #print(f"[collate_fn] sr={sr}")
    
    # TODO: compute spectrograms for each win_size
    # gives a new dimension (S) in batch
    # need another extra transformer after the encoder
    # (B, 1, H, W) -> (S*B, 1, H, W)
    valid_win_shifts = [
        win_shift
        for win_shift in self_win_shifts
        if get_nr_y_patches(win_shift, sr) >= 1
        and get_nr_x_patches(win_shift, dur) >= 1
    ]

    # list holding assembled tensors for varying window shifts
    full_batch = {}   

    for win_size in valid_win_shifts:
        
        fft = torchaudio.transforms.Spectrogram(
            n_fft=int(sr * win_size),
            win_length=int(sr * win_size),
            hop_length=int(sr * win_size * self_win_shift_factor),
            normalized=True,
        )
    
        spg_list = []
        chn_list = []
        mean_list = []
        std_list = []
    
        for signal, channel in zip(signals, channels):
            
            # Channel information
            channel_name = self_get_generic_channel_name(channel)
            channel = self_channel_name_map[channel_name] if channel_name in self_channel_name_map else self_channel_name_map["None"]
    
            # Spectrogram Computation & Cropping
            spg = fft(signal)
            spg = spg**2
            spg = crop_spg(spg, self_patch_size)
            
            H_new, W_new = spg.shape[0], spg.shape[1]
            h_new, w_new = H_new // self_patch_size, W_new // self_patch_size
    
            # Prepare channel information (per-patch)
            channel = torch.full((h_new, w_new), channel, dtype=torch.float16)
            
            spg, mean, std = normalize_spg(spg)
            mean = self_encode_mean(mean, win_size)
            std = self_encode_mean(std, win_size)
            
            spg_list.append(spg)
            chn_list.append(channel)
            mean_list.append(mean)
            std_list.append(std)
        
        win_batch = torch.stack(spg_list)
        win_channels = torch.stack(chn_list)
        win_means = torch.stack(mean_list)
        win_stds = torch.stack(std_list)
        
        win_batch.unsqueeze_(1)
        win_channels = win_channels.flatten(1)
        win_means = win_means.transpose(1, 2)
        win_stds = win_stds.transpose(1, 2)
        
        full_batch[win_size] = {
            "batch": win_batch,
            "channels": win_channels,
            "means": win_means,
            "stds": win_stds
        }
        #print(f"[collate_fn] win_size={win_size}: {win_batch.shape}")
        
    # == Finished iterating over all possible window shifts
   
    return full_batch, output

train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=sample_collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=sample_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=sample_collate_fn)

#########################################################################################################
# Model
# == Metrics ==
def rmse(y_true, y_pred):
    return torch.sqrt(torch.mean((y_true - y_pred) ** 2))

from sklearn.metrics import balanced_accuracy_score

def balanced_accuracy(y_true, y_pred):
    return balanced_accuracy_score(y_true, y_pred)

from functools import partial

from src.models.mae_rope_encoder import EncoderViTRoPE
from src.models.components.vit_rope import (
    Flexible_RoPE_Layer_scale_init_Block,
    FlexibleRoPEAttention,
    compute_axial_cis,
    select_freqs_cis,
)
from timm.models.vision_transformer import Mlp as Mlp

from torch.nn import TransformerEncoderLayer
class SingleTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(SingleTransformerEncoderLayer, self).__init__()
        self.encoder_layer = TransformerEncoderLayer(d_model, nhead)

    def forward(self, src):
        return self.encoder_layer(src)

def mean_aggregation(tokens):
    return torch.mean(torch.stack(tokens), dim=0)

from sklearn.metrics import balanced_accuracy_score
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchmetrics

class FineTuningModel(L.LightningModule):
    def __init__(self, encoder, frozen_encoder, task_type, learning_rate, mask_ratio):
        super(FineTuningModel, self).__init__()
        
        self.task_type = task_type
        self.learning_rate = learning_rate
        self.mask_ratio = mask_ratio

        # Pretrained network
        self.encoder = encoder       
        if frozen_encoder:
            self.freeze_encoder()

        # Finetuning network
        self.finetune_time_transformer = Flexible_RoPE_Layer_scale_init_Block(
            dim=384,
            num_heads=6,
            mlp_ratio=4,
            qkv_bias=True,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            Attention_block=FlexibleRoPEAttention,
            Mlp_block=Mlp,
            init_values=1e-4,
        )
        
        # Single 1D transformer encoder layer
        self.finetune_channel_transformer = SingleTransformerEncoderLayer(
            d_model=384,  # Match the dimension used in finetune_time_transformer
            nhead=6       # Number of heads in the multiheadattention models
        )

        # Modular aggregation method on channel tokens
        self.win_shift_aggregation = mean_aggregation
        
        if task_type == "Regression":
            out_dim = 1
            print(f"[FT.__init__] Regression: out_dim={out_dim}")
            self.head = nn.Linear(encoder.encoder_embed_dim, out_dim)
            self.criterion = nn.MSELoss()
        else:
            out_dim = len(set(v for k, v in train_outputs.items()))
            print(f"[FT.__init__] Classification: out_dim={out_dim}")
            self.head = nn.Linear(encoder.encoder_embed_dim, out_dim)
            self.criterion = nn.CrossEntropyLoss()

        # Initialize buffers for validation and test predictions and targets
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, full_x):
        
        x_embeds = {}
        H_W = {}
        
        for win_size, x_win in full_x.items():
            spgs = x_win["batch"]
            channels = x_win["channels"]
            means = x_win["means"]
            stds = x_win["stds"]
            B, C, H, W = spgs.shape
            # TODO: split into less rows if necessary because of CUDA error
            #nr_tokens = B * C * H * W
            #if nr_tokens > max_nr_tokens:
            #    pass
            x_emb, _, _, nr_meta_patches = self.encoder(
                x=spgs,
                means=means,
                stds=stds,
                channels=channels,
                win_size=win_size,
                mask_ratio=self.mask_ratio,
            )
            # TODO: 
            x_embeds[win_size] = x_emb
            H_W[win_size] = (H, W)
            #print(f"[FT.forward, after self.encoder] x_emb.shape: {x_emb.shape}")

        # Pass through time-transformer
        for win_size, x_emb in x_embeds.items():
            freqs_cis = select_freqs_cis(
                self.encoder, self.encoder.encoder_freqs_cis, H_W[win_size][0], H_W[win_size][1], win_size, x_emb.device
            )
            x_emb = self.finetune_time_transformer(x_emb, freqs_cis=freqs_cis, nr_meta_tokens=nr_meta_patches)
            #print(f"[FT.forward, after self.time_transformer] x_emb.shape: {x_emb.shape}")
            x_emb = x_emb[:, 0]
            #print(f"[FT.forward, after time-token] x_emb.shape: {x_emb.shape}")
            x_embeds[win_size] = x_emb

        # Pass through channel-transformer
        tokens = []
        for win_size, x_emb in x_embeds.items():
            x_emb = x_emb.unsqueeze(0)
            #print(f"[FT.forward, before channel-token] x_emb.shape: {x_emb.shape}")
            x_emb = self.finetune_channel_transformer(x_emb)  # Adding a batch dimension
            x_emb = x_emb[0, 0]
            #print(f"[FT.forward, after channel-token] x_emb.shape: {x_emb.shape}")
            tokens.append(x_emb)

        #print(f"[FT.forward] len(tokens): {len(tokens)}")
        # Average over all window shifts
        smart_token = self.win_shift_aggregation(tokens)
        #print(f"[FT.forward] smart_token.shape: {smart_token.shape}")

        # Pass through head
        y_hat = self.head(smart_token)
        
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(input=y_hat, target=y)
        self.log('train_loss', loss)
        print(f"[training_step] y_hat={y_hat}, y={y} -> loss={loss}")
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(input=y_hat, target=y)
        self.log('val_loss', loss, prog_bar=True)
        print(f"[validation_step] y_hat={y_hat}, y={y} -> loss={loss}")

        if self.task_type == "Classification":
            y_pred = torch.argmax(y_hat, dim=1)
            self.validation_step_outputs.append((y.cpu(), y_pred.cpu()))
        elif self.task_type == "Regression":
            self.validation_step_outputs.append((y.cpu(), y_hat.cpu()))

        return loss

    def on_validation_epoch_end(self):
        if self.task_type == "Classification":
            y_true = torch.cat([x[0] for x in self.validation_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.validation_step_outputs], dim=0)
            balanced_acc = balanced_accuracy_score(y_true.numpy(), y_pred.numpy())
            self.log('val_balanced_accuracy', balanced_acc, prog_bar=True)
        elif self.task_type == "Regression":
            y_true = torch.cat([x[0] for x in self.validation_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.validation_step_outputs], dim=0)
            rmse_value = rmse(y_true, y_pred)
            self.log('val_rmse', rmse_value, prog_bar=True)

        # Clear the buffer
        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(input=y_hat, target=y)
        self.log('test_loss', loss, prog_bar=True)

        if self.task_type == "Classification":
            y_pred = torch.argmax(y_hat, dim=1)
            self.test_step_outputs.append((y.cpu(), y_pred.cpu()))
        elif self.task_type == "Regression":
            self.test_step_outputs.append((y.cpu(), y_hat.cpu()))

        return loss

    def on_test_epoch_end(self):
        if self.task_type == "Classification":
            y_true = torch.cat([x[0] for x in self.test_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.test_step_outputs], dim=0)
            balanced_acc = balanced_accuracy_score(y_true.numpy(), y_pred.numpy())
            self.log('test_balanced_accuracy', balanced_acc, prog_bar=True)
        elif self.task_type == "Regression":
            y_true = torch.cat([x[0] for x in self.test_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.test_step_outputs], dim=0)
            rmse_value = rmse(y_true, y_pred)
            self.log('test_rmse', rmse_value, prog_bar=True)

        # Clear the buffer
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.head.parameters(), lr=self.learning_rate)

    def freeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        for param in self.encoder.parameters():
            param.requires_grad = True

#########################################################################################################
# Instantiate
# Load the checkpoint
chkpt_path = "/itet-stor/maxihuber/net_scratch/checkpoints/977598/epoch=0-step=32807-val_loss=133.55.ckpt"
checkpoint = torch.load(chkpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']
state_dict = {k.replace("net.encoder.", ""): v for k, v in state_dict.items() if "net.encoder." in k}

# Initialize the encoder and load the state dict
encoder = EncoderViTRoPE(channel_name_map_path)
encoder.load_state_dict(state_dict)

# Instantiate the fine-tuning model
fine_tuning_model = FineTuningModel(encoder=encoder, frozen_encoder=True, task_type=task_type, learning_rate=0.05, mask_ratio=0)

#########################################################################################################
# Load the checkpoint
chkpt_path = "/itet-stor/maxihuber/net_scratch/checkpoints/977598/epoch=0-step=32807-val_loss=133.55.ckpt"
checkpoint = torch.load(chkpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']
state_dict = {k.replace("net.encoder.", ""): v for k, v in state_dict.items() if "net.encoder." in k}

# Initialize the encoder and load the state dict
encoder = EncoderViTRoPE(channel_name_map_path)
encoder.load_state_dict(state_dict)

# Instantiate the fine-tuning model
fine_tuning_model = FineTuningModel(encoder=encoder, frozen_encoder=True, task_type=task_type, learning_rate=0.001, mask_ratio=0)

#########################################################################################################
# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath=f"/itet-stor/maxihuber/deepeye_storage/finetune_ckpts/{task_name}",
    filename="{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    monitor="val_loss",
    mode="min",
)

# Train the model
trainer = L.Trainer(
    max_epochs=2,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
)

trainer.fit(fine_tuning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

trainer.test(model=fine_tuning_model, dataloaders=test_loader)

[FT.__init__] Regression: out_dim=1
[FT.__init__] Regression: out_dim=1


/itet-stor/maxihuber/net_scratch/conda_envs/fastenv/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.9 /itet-stor/maxihuber/net_scratch/conda_envs/faste ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/itet-stor/maxihuber/net_scratch/conda_envs/fastenv/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory 

Sanity Checking: |                                                                                            …

/itet-stor/maxihuber/net_scratch/conda_envs/fastenv/lib/python3.9/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 62. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[validation_step] y_hat=tensor([-0.8159], device='cuda:0'), y=tensor([70.], device='cuda:0') -> loss=5014.888671875
[validation_step] y_hat=tensor([-0.8416], device='cuda:0'), y=tensor([79.], device='cuda:0') -> loss=6374.6845703125


Training: |                                                                                                   …

[training_step] y_hat=tensor([-0.7036], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([79.], device='cuda:0') -> loss=6352.6572265625
[training_step] y_hat=tensor([-0.5966], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([64.], device='cuda:0') -> loss=4172.72412109375
[training_step] y_hat=tensor([-0.2418], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([64.], device='cuda:0') -> loss=4127.005859375
[training_step] y_hat=tensor([-0.1002], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([73.], device='cuda:0') -> loss=5343.64208984375
[training_step] y_hat=tensor([0.2895], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([70.], device='cuda:0') -> loss=4859.55224609375
[training_step] y_hat=tensor([0.4622], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([77.], device='cuda:0') -> loss=5858.0341796875
[training_step] y_hat=tensor([0.5085], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([69.], device='cuda:0') -> loss=4691.091796875
[training_step] y_hat=

Validation: |                                                                                                 …

[validation_step] y_hat=tensor([0.7757], device='cuda:0'), y=tensor([70.], device='cuda:0') -> loss=4792.00537109375
[validation_step] y_hat=tensor([1.0515], device='cuda:0'), y=tensor([79.], device='cuda:0') -> loss=6075.96875
[training_step] y_hat=tensor([1.1959], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([70.], device='cuda:0') -> loss=4734.001953125
[training_step] y_hat=tensor([1.4763], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([64.], device='cuda:0') -> loss=3909.214599609375
[training_step] y_hat=tensor([1.4598], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([64.], device='cuda:0') -> loss=3911.272705078125
[training_step] y_hat=tensor([1.6852], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([73.], device='cuda:0') -> loss=5085.8046875
[training_step] y_hat=tensor([2.0882], device='cuda:0', grad_fn=<ViewBackward0>), y=tensor([64.], device='cuda:0') -> loss=3833.0654296875
[training_step] y_hat=tensor([2.0856], device='cuda:0', grad_fn=<ViewBackward

Validation: |                                                                                                 …

[validation_step] y_hat=tensor([2.3570], device='cuda:0'), y=tensor([70.], device='cuda:0') -> loss=4575.56884765625
[validation_step] y_hat=tensor([2.9262], device='cuda:0'), y=tensor([79.], device='cuda:0') -> loss=5787.22509765625


`Trainer.fit` stopped: `max_epochs=2` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                                                                    …

[{'test_loss': 4157.4736328125, 'test_rmse': 64.47847747802734}]

# Baseline Models

## Data Handling & Model Template

In [11]:
sys.path.append('/home/maxihuber/eeg-foundation/src/models/components/Baselines')

# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import balanced_accuracy_score
import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint

class SimpleDataset(Dataset):
    def __init__(self, data, outputs, task_type, label_encoder=None):
        self.data = data
        self.outputs = outputs
        self.task_type = task_type
        self.label_encoder = label_encoder
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        signals = self.data[idx]
        output = self.outputs[idx]
        
        if self.task_type == "Classification" and self.label_encoder is not None:
            output = self.label_encoder.transform([output])[0]  # Encode the output label
            output_tensor = torch.tensor(output, dtype=torch.long)
        else:
            output_tensor = torch.tensor([output], dtype=torch.float32)
        
        return {
            "signals": signals,
            "output": output_tensor,
        }

durs = [df.shape[1] for idx, df in train_data.items()] + [df.shape[1] for idx, df in test_data.items()]
n_chns = [df.shape[0] for idx, df in train_data.items()] + [df.shape[0] for idx, df in test_data.items()]
dur_90 = int(np.percentile(durs, 90))
chn_90 = int(np.percentile(n_chns, 90))

def pad_tensor(tensor, target_height, target_width):
    current_height, current_width = tensor.shape

    # Pad height if necessary
    if current_height < target_height:
        padding_height = target_height - current_height
        padding = torch.zeros((padding_height, current_width), dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding), dim=0)
    else:
        tensor = tensor[:target_height, :]

    # Pad width if necessary
    if current_width < target_width:
        padding_width = target_width - current_width
        padding = torch.zeros((tensor.shape[0], padding_width), dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding), dim=1)
    else:
        tensor = tensor[:, :target_width]

    return tensor

train_data_pad = {k: pad_tensor(signals, chn_90, dur_90) for k, signals in train_data.items()}
test_data_pad = {k: pad_tensor(signals, chn_90, dur_90) for k, signals in test_data.items()}

full_train_dataset = SimpleDataset(train_data_pad, train_outputs, task_type=task_type, label_encoder=label_encoder)
test_dataset = SimpleDataset(test_data_pad, test_outputs, task_type=task_type, label_encoder=label_encoder)

# Define the split ratio
train_ratio, val_ratio = 0.8, 0.2

# Calculate lengths for train and validation sets
total_size = len(full_train_dataset)
train_size = int(train_ratio * total_size)
val_size = total_size - train_size

# Split the dataset
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

def sample_collate_fn(batch):
    signals, output = batch[0]["signals"], batch[0]["output"]
    signals = signals.unsqueeze(0)
    print(f"[sample_collate_fn] signals.shape: {signals.shape}")
    return signals, output

# Define the baseline model class
class BaselineModel(L.LightningModule):
    def __init__(self, task_type, learning_rate):
        super(BaselineModel, self).__init__()
        
        self.task_type = task_type
        self.learning_rate = learning_rate

        if task_type == "Regression":
            self.out_dim = 1
            self.criterion = nn.MSELoss()
        else:
            self.out_dim = len(set(v for k, v in train_outputs.items()))
            self.criterion = nn.CrossEntropyLoss()

        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

        if self.task_type == "Classification":
            y_pred = torch.argmax(y_hat, dim=1)
            self.validation_step_outputs.append((y.cpu(), y_pred.cpu()))
        elif self.task_type == "Regression":
            self.validation_step_outputs.append((y.cpu(), y_hat.cpu()))

        return loss

    def on_validation_epoch_end(self):
        if self.task_type == "Classification":
            y_true = torch.cat([x[0] for x in self.validation_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.validation_step_outputs], dim=0)
            balanced_acc = balanced_accuracy_score(y_true.numpy(), y_pred.numpy())
            self.log('val_balanced_accuracy', balanced_acc, prog_bar=True)
        elif self.task_type == "Regression":
            y_true = torch.cat([x[0] for x in self.validation_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.validation_step_outputs], dim=0)
            rmse_value = rmse(y_true, y_pred)
            self.log('val_rmse', rmse_value, prog_bar=True)

        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('test_loss', loss, prog_bar=True)

        if self.task_type == "Classification":
            y_pred = torch.argmax(y_hat, dim=1)
            self.test_step_outputs.append((y.cpu(), y_pred.cpu()))
        elif self.task_type == "Regression":
            self.test_step_outputs.append((y.cpu(), y_hat.cpu()))

        return loss

    def on_test_epoch_end(self):
        if self.task_type == "Classification":
            y_true = torch.cat([x[0] for x in self.test_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.test_step_outputs], dim=0)
            balanced_acc = balanced_accuracy_score(y_true.numpy(), y_pred.numpy())
            self.log('test_balanced_accuracy', balanced_acc, prog_bar=True)
        elif self.task_type == "Regression":
            y_true = torch.cat([x[0] for x in self.test_step_outputs], dim=0)
            y_pred = torch.cat([x[1] for x in self.test_step_outputs], dim=0)
            rmse_value = rmse(y_true, y_pred)
            self.log('test_rmse', rmse_value, prog_bar=True)

        self.test_step_outputs.clear()

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

# CNN ====================================================

In [None]:
from src.models.components.Baselines.DL_Models.torch_models.CNN.CNN import CNN

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=sample_collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=sample_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=sample_collate_fn)

class CNNBaselineModel(BaselineModel):
    def __init__(self, task_type, learning_rate):
        super(CNNBaselineModel, self).__init__(task_type, learning_rate)
        
        self.model = CNN(
            model_name=None,
            path=None,
            loss="dice-loss",
            model_number=None,
            batch_size=1,
            input_shape=[chn_90, dur_90],
            output_shape=self.out_dim
        )

baseline_model = CNNBaselineModel(task_type=task_type, learning_rate=0.001)

# EEGNet ====================================================

In [12]:
from src.models.components.Baselines.DL_Models.torch_models.EEGNet.eegNet import EEGNet

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=sample_collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=sample_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=sample_collate_fn)

class EEGNetBaselineModel(BaselineModel):
    def __init__(self, task_type, learning_rate):
        super(EEGNetBaselineModel, self).__init__(task_type, learning_rate)
        
        self.model = EEGNet(
            model_name=None,
            path=None,
            loss="dice-loss",
            model_number=None,
            batch_size=1,
            input_shape=[chn_90, dur_90],
            output_shape=self.out_dim
        )

baseline_model = EEGNetBaselineModel(task_type=task_type, learning_rate=0.001)

2024-06-03 16:27:29,093 - root - INFO - Dice weights: [0.86686687 0.10610611 0.02702703]
2024-06-03 16:27:29,095 - root - INFO - Using loss fct: DiceLoss()
2024-06-03 16:27:29,130 - root - INFO - Number of model parameters: 7196013
2024-06-03 16:27:29,132 - root - INFO - Number of trainable parameters: 7196013


# UNet ======================================================

In [97]:
from src.models.components.Baselines.DL_Models.torch_models.UNet.UNet import UNet

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=sample_collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=sample_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=sample_collate_fn)

class UNetBaselineModel(BaselineModel):
    def __init__(self, task_type, learning_rate):
        super(UNetBaselineModel, self).__init__(task_type, learning_rate)
        
        self.model = UNet(
            model_name=None,
            path=None,
            loss="dice-loss",
            model_number=None,
            batch_size=1,
            input_shape=[chn_90, dur_90],
            output_shape=self.out_dim
        )

baseline_model = UNetBaselineModel(task_type=task_type, learning_rate=0.001)

2024-06-03 14:47:14,547 - root - INFO - Dice weights: [0.86686687 0.10610611 0.02702703]
2024-06-03 14:47:14,553 - root - INFO - Using loss fct: DiceLoss()


# xDAWN + LDA ===============================================

# Training

In [13]:
# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath=f"/itet-stor/maxihuber/deepeye_storage/finetune_ckpts/{task_name}_baseline",
    filename="{epoch:02d}-{val_loss:.2f}",
    save_top_k=3,
    monitor="val_loss",
    mode="min",
)

# Train the model
trainer = L.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
)

trainer.fit(baseline_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model=baseline_model, dataloaders=test_loader)

Trainer will use only 1 of 6 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=6)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]

  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | criterion | MSELoss | 0      | train
1 | model     | EEGNet  | 7.2 M  | train
----------------------------------------------
7.2 M     Trainable params
0         Non-trainable params
7.2 M     Total params
28.784    Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

[sample_collate_fn] signals.shape: torch.Size([1, 62, 112151])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x62 and 500x1)

## Metrics

Classification: Balanced Accuracy

Regression: mean euclidean distance for 2d eyenet, else RMSE