In [34]:
from pathlib import Path
import argparse
import time
import copy

import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from tqdm.notebook import tqdm

from lr_scheduler import CosineLRScheduler

try:
    from eegdash.dataset import EEGChallengeDataset
    from eegdash.hbn.windows import (
        annotate_trials_with_target,
        add_aux_anchors,
        add_extras_columns,
        keep_only_recordings_with,
    )
except Exception as e:
    EEGChallengeDataset = None

from braindecode.preprocessing import Preprocessor, preprocess, create_windows_from_events
from braindecode.datasets import BaseConcatDataset
import torch.nn as nn
from model.eegmamba_jamba import EegMambaJEPA
import joblib

In [None]:
DATA_PATH = "LOL_DATASET/LOL_DATASET/HBN_DATA_FULL/"
RELEASES = ["R1", "R2", "R3"]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [29]:
all_dataset = []
for rel in RELEASES:
    name_folder = f"{rel}_mini_L100_bdf" 
    cache_dir = Path(DATA_PATH) / name_folder 

    dataset = EEGChallengeDataset(
        cache_dir = cache_dir,
        task = "contrastChangeDetection",
        mini = True,
        download = False,
        release = rel
    )

    all_dataset.append(dataset)

In [30]:
all_dataset

[<eegdash.dataset.dataset.EEGChallengeDataset at 0x7819ee614f90>,
 <eegdash.dataset.dataset.EEGChallengeDataset at 0x7819ece6cdd0>,
 <eegdash.dataset.dataset.EEGChallengeDataset at 0x7819ece7af10>]

Preprocessing the dataset

In [31]:
EPOCH_LENS_S = 2

In [32]:
def build_offline_preprocessors():
    return [
        Preprocessor(annotate_trials_with_target, 
                     target_field = "rt_from_stimulus", 
                     epoch_length = EPOCH_LENS_S, 
                     require_stimulus = True, 
                     require_response = True, 
                     apply_on_array = False),
        Preprocessor(add_aux_anchors, apply_on_array=False),
    ]


This case we don't need to save the dataset

In [43]:
preproc_dir = Path("preprocessed_dataset")
preproc_dir.mkdir(parents=True, exist_ok=True)

list_windows = []

ANCHOR = "stimulus_anchor"
SHIFT_AFTER_STIM = 0.5
WINDOW_LEN = 2.0
SFREQ = 100

preproc = build_offline_preprocessors()

for i, dataset in enumerate(all_dataset):
    preprocess(dataset, preproc, n_jobs = -1)

    dataset = keep_only_recordings_with(ANCHOR, dataset)
    windows = create_windows_from_events(
        dataset, 
        mapping = {ANCHOR: 0},
        trial_start_offset_samples = int(SHIFT_AFTER_STIM * SFREQ),                 # +0.5 s
        trial_stop_offset_samples = int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),   # +2.5 s
        window_size_samples = int(WINDOW_LEN * SFREQ),
        window_stride_samples = SFREQ,
        preload=True,
    )

    windows = add_extras_columns(
        windows,
        dataset,
        desc=ANCHOR,
        keys=("target", "rt_from_stimulus", "rt_from_trialstart",
              "stimulus_onset", "response_onset", "correct", "response_type")
    )

    list_windows.append(windows)

    save_path = preproc_dir / f"{RELEASES[i]}_windows.pkl"
    joblib.dump(windows, save_path)


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(

In [44]:
load_windows = []

for rel in RELEASES:
    load_path = preproc_dir / f"{rel}_windows.pkl"
    windows = joblib.load(load_path)
    load_windows.append(windows)

all_windows = BaseConcatDataset(load_windows)

In [45]:
all_windows.get_metadata().head()

Unnamed: 0,i_window_in_trial,i_start_in_trial,i_stop_in_trial,target,rt_from_stimulus,rt_from_trialstart,stimulus_onset,response_onset,correct,response_type,...,thepresent,diaryofawimpykid,contrastchangedetection_1,contrastchangedetection_2,contrastchangedetection_3,surroundsupp_1,surroundsupp_2,seqlearning6target,seqlearning8target,symbolsearch
0,0,4278,4478,2.13,2.13,4.93,42.284,44.414,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
1,0,4798,4998,1.96,1.96,4.76,47.484,49.444,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
2,0,5478,5678,2.02,2.02,6.42,54.284,56.304,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
3,0,6318,6518,1.72,1.72,7.72,62.684,64.404,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
4,0,6838,7038,1.8,1.8,4.6,67.884,69.684,1,left_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available


#### Spliting the train test valid

In [47]:
meta = all_windows.get_metadata()
subjects = list(meta['subject'].unique())

valid_frac = 0.1
test_frac = 0.1
seed = 2025

train_subj, valid_test_subject = train_test_split(subjects, test_size=(valid_frac + test_frac), random_state=check_random_state(seed), shuffle=True)
valid_subj, test_subj = train_test_split(valid_test_subject, test_size=test_frac/(valid_frac+test_frac), random_state=check_random_state(seed+1), shuffle=True)

subject_split = windows.split("subject")
train_sets = [ds for subj, ds in subject_split.items() if subj in train_subj]
valid_sets = [ds for subj, ds in subject_split.items() if subj in valid_subj]
test_sets = [ds for subj, ds in subject_split.items() if subj in test_subj]

train_ds = BaseConcatDataset(train_sets)
valid_ds = BaseConcatDataset(valid_sets)
test_ds = BaseConcatDataset(test_sets)

In [51]:
print(len(train_ds))
print(len(valid_ds))
print(len(test_ds))

1110
134
224


In [62]:
class ContrastChangeDataset(torch.utils.data.Dataset):
    def __init__(self, braindecode_dataset):
        self.dataset = braindecode_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        X, y, _ = self.dataset[idx]
        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

train_dataset = ContrastChangeDataset(train_ds)
valid_dataset = ContrastChangeDataset(valid_ds)
test_dataset = ContrastChangeDataset(test_ds)

In [63]:
first = train_dataset[0]
first

(tensor([[ 1.2222e-05,  5.1986e-06,  5.3049e-06,  ..., -8.8720e-07,
           7.3418e-07, -9.9474e-07],
         [ 1.1537e-05,  5.4671e-06,  5.5437e-06,  ..., -6.2149e-07,
           1.2688e-06,  3.0214e-07],
         [ 1.3410e-05,  6.8182e-06,  5.9050e-06,  ...,  8.5317e-07,
           1.7593e-06,  8.2808e-07],
         ...,
         [-8.4166e-06, -1.0888e-05, -9.3056e-06,  ...,  6.8735e-06,
           3.1253e-06,  1.5513e-06],
         [-1.2049e-05, -1.4262e-05, -1.3558e-05,  ...,  9.3226e-06,
           5.4775e-06,  2.9852e-06],
         [ 5.0000e-13,  5.0000e-13,  5.0000e-13,  ...,  5.0000e-13,
           5.0000e-13,  5.0000e-13]]),
 tensor([1.9580]))

In [64]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True )
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=True)

In [65]:
first_batch = next(iter(train_loader))
first_batch[0].shape, first_batch[1].shape

(torch.Size([64, 129, 200]), torch.Size([64, 1]))

Load the model 

In [5]:
import torch.nn as nn
import torch 
from model.eegmamba_jamba import EegMambaJEPA

class FinetuneJEPA(nn.Module):
    """Simple wrapper: EegMambaJEPA backbone -> linear regression head."""
    def __init__(self, 
                 n_chans: int = 129, 
                 d_model: int = 256, 
                 n_layer: int = 8, 
                 patch_size: int = 10
                 ):
        super().__init__()
        self.backbone = EegMambaJEPA(
            d_model=d_model, 
            n_layer=n_layer, 
            n_channels=n_chans, 
            patch_size=patch_size
            )
        self.head = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T)
        z = self.backbone(x)  # (B, d_model)
        out = self.head(z)    # (B, 1)
        return out


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = FinetuneJEPA(n_chans=129, d_model=256, n_layer=8, patch_size=10)
model = model.to(DEVICE)

In [12]:
weight_path = "finetune_weight/pretrain_epoch020.pt"
state_dict = torch.load(weight_path, map_location=DEVICE)
model_state = state_dict["model_state"]

model.backbone.load_state_dict(model_state, strict = False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['target_model.cls_token', 'target_model.patch_embed.proj.weight', 'target_model.patch_embed.proj.bias', 'target_model.mamba_blocks.0.in_proj.weight', 'target_model.mamba_blocks.0.conv1d.weight', 'target_model.mamba_blocks.0.conv1d.bias', 'target_model.mamba_blocks.0.mamba_fwd.A_log', 'target_model.mamba_blocks.0.mamba_fwd.D', 'target_model.mamba_blocks.0.mamba_fwd.in_proj.weight', 'target_model.mamba_blocks.0.mamba_fwd.conv1d.weight', 'target_model.mamba_blocks.0.mamba_fwd.conv1d.bias', 'target_model.mamba_blocks.0.mamba_fwd.x_proj.weight', 'target_model.mamba_blocks.0.mamba_fwd.dt_proj.weight', 'target_model.mamba_blocks.0.mamba_fwd.dt_proj.bias', 'target_model.mamba_blocks.0.mamba_fwd.out_proj.weight', 'target_model.mamba_blocks.0.mamba_bwd.A_log', 'target_model.mamba_blocks.0.mamba_bwd.D', 'target_model.mamba_blocks.0.mamba_bwd.in_proj.weight', 'target_model.mamba_blocks.0.mamba_bwd.conv1d.weight', 'target_model.mamba_blocks.0.mamb

In [10]:
model_state = state_dict["model_state"]
model_state.keys()

odict_keys(['cls_token', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'mamba_blocks.0.in_proj.weight', 'mamba_blocks.0.conv1d.weight', 'mamba_blocks.0.conv1d.bias', 'mamba_blocks.0.mamba_fwd.A_log', 'mamba_blocks.0.mamba_fwd.D', 'mamba_blocks.0.mamba_fwd.in_proj.weight', 'mamba_blocks.0.mamba_fwd.conv1d.weight', 'mamba_blocks.0.mamba_fwd.conv1d.bias', 'mamba_blocks.0.mamba_fwd.x_proj.weight', 'mamba_blocks.0.mamba_fwd.dt_proj.weight', 'mamba_blocks.0.mamba_fwd.dt_proj.bias', 'mamba_blocks.0.mamba_fwd.out_proj.weight', 'mamba_blocks.0.mamba_bwd.A_log', 'mamba_blocks.0.mamba_bwd.D', 'mamba_blocks.0.mamba_bwd.in_proj.weight', 'mamba_blocks.0.mamba_bwd.conv1d.weight', 'mamba_blocks.0.mamba_bwd.conv1d.bias', 'mamba_blocks.0.mamba_bwd.x_proj.weight', 'mamba_blocks.0.mamba_bwd.dt_proj.weight', 'mamba_blocks.0.mamba_bwd.dt_proj.bias', 'mamba_blocks.0.mamba_bwd.out_proj.weight', 'mamba_blocks.0.out_proj.weight', 'mamba_blocks.0.norm.weight', 'mamba_blocks.0.norm.bias', 'mamba_blocks.1.in