In [52]:
from pathlib import Path
import argparse
import time
import copy

import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from tqdm.notebook import tqdm

from lr_scheduler import CosineLRScheduler

try:
    from eegdash.dataset import EEGChallengeDataset
    from eegdash.hbn.windows import (
        annotate_trials_with_target,
        add_aux_anchors,
        add_extras_columns,
        keep_only_recordings_with,
    )
except Exception as e:
    EEGChallengeDataset = None

from braindecode.preprocessing import Preprocessor, preprocess, create_windows_from_events
from braindecode.datasets import BaseConcatDataset
import torch.nn as nn
from model.eegmamba_jamba import EegMambaJEPA
import joblib

In [53]:
DATA_PATH = "LOL_DATASET/LOL_DATASET/HBN_DATA_FULL/"
RELEASES = ["R1", "R2", "R3"]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [54]:
all_dataset = []
for rel in RELEASES:
    name_folder = f"{rel}_mini_L100_bdf" 
    cache_dir = Path(DATA_PATH) / name_folder 

    dataset = EEGChallengeDataset(
        cache_dir = cache_dir,
        task = "contrastChangeDetection",
        mini = True,
        download = False,
        release = rel
    )

    all_dataset.append(dataset)

In [55]:
all_dataset = BaseConcatDataset(all_dataset)
len(all_dataset)

7522500

Preprocessing the dataset

In [31]:
EPOCH_LENS_S = 2

In [32]:
def build_offline_preprocessors():
    return [
        Preprocessor(annotate_trials_with_target, 
                     target_field = "rt_from_stimulus", 
                     epoch_length = EPOCH_LENS_S, 
                     require_stimulus = True, 
                     require_response = True, 
                     apply_on_array = False),
        Preprocessor(add_aux_anchors, apply_on_array=False),
    ]


This case we don't need to save the dataset

In [43]:
preproc_dir = Path("preprocessed_dataset")
preproc_dir.mkdir(parents=True, exist_ok=True)

list_windows = []

ANCHOR = "stimulus_anchor"
SHIFT_AFTER_STIM = 0.5
WINDOW_LEN = 2.0
SFREQ = 100

preproc = build_offline_preprocessors()

for i, dataset in enumerate(all_dataset):
    preprocess(dataset, preproc, n_jobs = -1)

    dataset = keep_only_recordings_with(ANCHOR, dataset)
    windows = create_windows_from_events(
        dataset, 
        mapping = {ANCHOR: 0},
        trial_start_offset_samples = int(SHIFT_AFTER_STIM * SFREQ),                 # +0.5 s
        trial_stop_offset_samples = int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),   # +2.5 s
        window_size_samples = int(WINDOW_LEN * SFREQ),
        window_stride_samples = SFREQ,
        preload=True,
    )

    windows = add_extras_columns(
        windows,
        dataset,
        desc=ANCHOR,
        keys=("target", "rt_from_stimulus", "rt_from_trialstart",
              "stimulus_onset", "response_onset", "correct", "response_type")
    )

    list_windows.append(windows)

    save_path = preproc_dir / f"{RELEASES[i]}_windows.pkl"
    joblib.dump(windows, save_path)


Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_('stimulus_anchor')]
Used Annotations descriptions: [np.str_(

In [44]:
load_windows = []

for rel in RELEASES:
    load_path = preproc_dir / f"{rel}_windows.pkl"
    windows = joblib.load(load_path)
    load_windows.append(windows)

all_windows = BaseConcatDataset(load_windows)

In [45]:
all_windows.get_metadata().head()

Unnamed: 0,i_window_in_trial,i_start_in_trial,i_stop_in_trial,target,rt_from_stimulus,rt_from_trialstart,stimulus_onset,response_onset,correct,response_type,...,thepresent,diaryofawimpykid,contrastchangedetection_1,contrastchangedetection_2,contrastchangedetection_3,surroundsupp_1,surroundsupp_2,seqlearning6target,seqlearning8target,symbolsearch
0,0,4278,4478,2.13,2.13,4.93,42.284,44.414,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
1,0,4798,4998,1.96,1.96,4.76,47.484,49.444,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
2,0,5478,5678,2.02,2.02,6.42,54.284,56.304,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
3,0,6318,6518,1.72,1.72,7.72,62.684,64.404,1,right_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available
4,0,6838,7038,1.8,1.8,4.6,67.884,69.684,1,left_buttonPress,...,available,available,available,available,available,available,available,unavailable,available,available


#### Spliting the train test valid

In [47]:
meta = all_windows.get_metadata()
subjects = list(meta['subject'].unique())

valid_frac = 0.1
test_frac = 0.1
seed = 2025

train_subj, valid_test_subject = train_test_split(subjects, test_size=(valid_frac + test_frac), random_state=check_random_state(seed), shuffle=True)
valid_subj, test_subj = train_test_split(valid_test_subject, test_size=test_frac/(valid_frac+test_frac), random_state=check_random_state(seed+1), shuffle=True)

subject_split = windows.split("subject")
train_sets = [ds for subj, ds in subject_split.items() if subj in train_subj]
valid_sets = [ds for subj, ds in subject_split.items() if subj in valid_subj]
test_sets = [ds for subj, ds in subject_split.items() if subj in test_subj]

train_ds = BaseConcatDataset(train_sets)
valid_ds = BaseConcatDataset(valid_sets)
test_ds = BaseConcatDataset(test_sets)

In [51]:
print(len(train_ds))
print(len(valid_ds))
print(len(test_ds))

1110
134
224


In [62]:
class ContrastChangeDataset(torch.utils.data.Dataset):
    def __init__(self, braindecode_dataset):
        self.dataset = braindecode_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        X, y, _ = self.dataset[idx]
        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

train_dataset = ContrastChangeDataset(train_ds)
valid_dataset = ContrastChangeDataset(valid_ds)
test_dataset = ContrastChangeDataset(test_ds)

In [63]:
first = train_dataset[0]
first

(tensor([[ 1.2222e-05,  5.1986e-06,  5.3049e-06,  ..., -8.8720e-07,
           7.3418e-07, -9.9474e-07],
         [ 1.1537e-05,  5.4671e-06,  5.5437e-06,  ..., -6.2149e-07,
           1.2688e-06,  3.0214e-07],
         [ 1.3410e-05,  6.8182e-06,  5.9050e-06,  ...,  8.5317e-07,
           1.7593e-06,  8.2808e-07],
         ...,
         [-8.4166e-06, -1.0888e-05, -9.3056e-06,  ...,  6.8735e-06,
           3.1253e-06,  1.5513e-06],
         [-1.2049e-05, -1.4262e-05, -1.3558e-05,  ...,  9.3226e-06,
           5.4775e-06,  2.9852e-06],
         [ 5.0000e-13,  5.0000e-13,  5.0000e-13,  ...,  5.0000e-13,
           5.0000e-13,  5.0000e-13]]),
 tensor([1.9580]))

In [64]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True )
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=True)

In [65]:
first_batch = next(iter(train_loader))
first_batch[0].shape, first_batch[1].shape

(torch.Size([64, 129, 200]), torch.Size([64, 1]))

Load the model 

In [5]:
import torch.nn as nn
import torch 
from model.eegmamba_jamba import EegMambaJEPA

class FinetuneJEPA(nn.Module):
    """Simple wrapper: EegMambaJEPA backbone -> linear regression head."""
    def __init__(self, 
                 n_chans: int = 129, 
                 d_model: int = 256, 
                 n_layer: int = 8, 
                 patch_size: int = 10
                 ):
        super().__init__()
        self.backbone = EegMambaJEPA(
            d_model=d_model, 
            n_layer=n_layer, 
            n_channels=n_chans, 
            patch_size=patch_size
            )
        self.head = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T)
        z = self.backbone(x)  # (B, d_model)
        out = self.head(z)    # (B, 1)
        return out


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = FinetuneJEPA(n_chans=129, d_model=256, n_layer=8, patch_size=10)
model = model.to(DEVICE)

In [12]:
weight_path = "finetune_weight/pretrain_epoch020.pt"
state_dict = torch.load(weight_path, map_location=DEVICE)
model_state = state_dict["model_state"]

model.backbone.load_state_dict(model_state, strict = False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['target_model.cls_token', 'target_model.patch_embed.proj.weight', 'target_model.patch_embed.proj.bias', 'target_model.mamba_blocks.0.in_proj.weight', 'target_model.mamba_blocks.0.conv1d.weight', 'target_model.mamba_blocks.0.conv1d.bias', 'target_model.mamba_blocks.0.mamba_fwd.A_log', 'target_model.mamba_blocks.0.mamba_fwd.D', 'target_model.mamba_blocks.0.mamba_fwd.in_proj.weight', 'target_model.mamba_blocks.0.mamba_fwd.conv1d.weight', 'target_model.mamba_blocks.0.mamba_fwd.conv1d.bias', 'target_model.mamba_blocks.0.mamba_fwd.x_proj.weight', 'target_model.mamba_blocks.0.mamba_fwd.dt_proj.weight', 'target_model.mamba_blocks.0.mamba_fwd.dt_proj.bias', 'target_model.mamba_blocks.0.mamba_fwd.out_proj.weight', 'target_model.mamba_blocks.0.mamba_bwd.A_log', 'target_model.mamba_blocks.0.mamba_bwd.D', 'target_model.mamba_blocks.0.mamba_bwd.in_proj.weight', 'target_model.mamba_blocks.0.mamba_bwd.conv1d.weight', 'target_model.mamba_blocks.0.mamb

In [10]:
model_state = state_dict["model_state"]
model_state.keys()

odict_keys(['cls_token', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'mamba_blocks.0.in_proj.weight', 'mamba_blocks.0.conv1d.weight', 'mamba_blocks.0.conv1d.bias', 'mamba_blocks.0.mamba_fwd.A_log', 'mamba_blocks.0.mamba_fwd.D', 'mamba_blocks.0.mamba_fwd.in_proj.weight', 'mamba_blocks.0.mamba_fwd.conv1d.weight', 'mamba_blocks.0.mamba_fwd.conv1d.bias', 'mamba_blocks.0.mamba_fwd.x_proj.weight', 'mamba_blocks.0.mamba_fwd.dt_proj.weight', 'mamba_blocks.0.mamba_fwd.dt_proj.bias', 'mamba_blocks.0.mamba_fwd.out_proj.weight', 'mamba_blocks.0.mamba_bwd.A_log', 'mamba_blocks.0.mamba_bwd.D', 'mamba_blocks.0.mamba_bwd.in_proj.weight', 'mamba_blocks.0.mamba_bwd.conv1d.weight', 'mamba_blocks.0.mamba_bwd.conv1d.bias', 'mamba_blocks.0.mamba_bwd.x_proj.weight', 'mamba_blocks.0.mamba_bwd.dt_proj.weight', 'mamba_blocks.0.mamba_bwd.dt_proj.bias', 'mamba_blocks.0.mamba_bwd.out_proj.weight', 'mamba_blocks.0.out_proj.weight', 'mamba_blocks.0.norm.weight', 'mamba_blocks.0.norm.bias', 'mamba_blocks.1.in

### Loading Challenge 2

In [59]:
from pathlib import Path
import argparse
import time
import copy

import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
from tqdm.notebook import tqdm

from lr_scheduler import CosineLRScheduler

try:
    from eegdash.dataset import EEGChallengeDataset
    from eegdash.hbn.windows import (
        annotate_trials_with_target,
        add_aux_anchors,
        add_extras_columns,
        keep_only_recordings_with,
    )
except Exception as e:
    EEGChallengeDataset = None

from braindecode.preprocessing import Preprocessor, preprocess, create_windows_from_events
from braindecode.datasets import BaseConcatDataset
import torch.nn as nn
from model.eegmamba_jamba import EegMambaJEPA
import joblib
from braindecode.datasets.base import EEGWindowsDataset, BaseConcatDataset, BaseDataset
from braindecode.preprocessing import create_fixed_length_windows



In [122]:
DATA_PATH = "LOL_DATASET/LOL_DATASET/HBN_DATA_FULL/"
RELEASES = ["R1", "R2"]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SFREQ = 100

In [123]:
all_datasets = []

for rel in RELEASES:
    name_folder = f"{rel}_mini_L100_bdf" 
    cache_dir = Path(DATA_PATH) / name_folder 

    dataset = EEGChallengeDataset(
        cache_dir = cache_dir,
        task = "contrastChangeDetection",
        mini = True,
        download = False,
        release = rel,
        description_fields = [
            "subject",
            "session",
            "run",
            "task",
            "age",
            "gender",
            "sex",
            "p_factor",
            # "internalizing",
            # "externalizing",
            # "ehq_total",
            # "commercial_use",
            # "full_pheno",
            # "attention"
            
        ],

        
    )

    all_datasets.append(dataset)

 ## 2. Wrap the data into a PyTorch-compatible dataset

 The class below defines a dataset wrapper that will extract 2-second windows,
 uniformly sampled over the whole signal. In addition, it will add useful information
 about the extracted windows, such as the externalizing factor, the subject or the task.


In [125]:
class DatasetWrapper(BaseDataset):
    def __init__(
        self, 
        dataset: EEGWindowsDataset,
        crop_size_samples: int,
        target_name: str = "externalizing",
        seed = None,
    ):
        
        self.dataset = dataset
        self.crop_size_samples = crop_size_samples
        self.target_name = target_name
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        X, _, crop_inds = self.dataset[idx]

        target = self.dataset.description[self.target_name]
        print(target)
        target = float(target)

        # Additional information
        infos = {
            "subject": self.dataset.description["subject"],
            "sex": self.dataset.description["sex"],
            "age": float(self.dataset.description["age"]),
            "task": self.dataset.description["task"],
            "session": self.dataset.description.get("session", None) or "",
            "run": self.dataset.description.get("run", None) or "",
        }

        # Random cropping the EEG waves to the size of crop_size_samples 
        i_window_in_trial, i_start, i_stop = crop_inds
        assert i_stop - i_start >= self.crop_size_samples, f"{i_stop=} {i_start=}"
        start_offset = self.rng.integers(0, i_stop - i_start - self.crop_size_samples + 1)
        i_start = i_start + start_offset 
        i_stop  = i_start + self.crop_size_samples

        return X, target, (i_window_in_trial, i_start, i_stop), infos

In [126]:
import tqdm.notebook as tqdm
import math

preproc_dir = Path("preprocessed_dataset")
preproc_dir.mkdir(parents=True, exist_ok=True)

list_windows = []

for idx, data in enumerate(all_datasets):
    filter_data = BaseConcatDataset(
       [
            ds 
            for ds in data.datasets
            if ds.raw.n_times >= 4 * SFREQ
            and len(ds.raw.ch_names) == 129
            and not math.isnan(ds.description["p_factor"])
       ] 
    )

    # Create 4-seconds windows with 2-seconds stride
    windows_ds = create_fixed_length_windows(
        filter_data,
        window_size_samples=4 * SFREQ,
        window_stride_samples=2 * SFREQ,
        drop_last_window=True,
    )
    windows_ds = BaseConcatDataset(
            [DatasetWrapper(ds, crop_size_samples=2 * SFREQ) for ds in windows_ds.datasets]
    )

    list_windows.append(windows_ds)

    save_path = preproc_dir / f"{RELEASES[idx]}_windows.pkl"
    joblib.dump(windows_ds, save_path)





In [127]:
windows_ds[0]

0.62


(array([[ 6.4831329e-05, -6.1362621e-06,  1.8794045e-08, ...,
         -5.5702254e-05, -5.8715075e-05, -5.6652494e-05],
        [ 2.9976351e-05, -3.7022841e-05, -2.9796010e-05, ...,
         -9.4364950e-05, -1.0916590e-04, -1.0272729e-04],
        [ 1.0300881e-05, -4.5629182e-05, -3.3451066e-05, ...,
          3.5722020e-05,  3.0152774e-05,  3.3893011e-05],
        ...,
        [ 8.1039987e-05,  5.8947244e-05,  6.8126959e-05, ...,
         -1.8899624e-05, -7.7549084e-06, -9.7040192e-06],
        [ 7.7256365e-05,  1.5659152e-05,  2.8310096e-05, ...,
         -2.1114942e-05, -1.2811515e-05, -1.4344756e-05],
        [ 5.0000005e-13,  5.0000005e-13,  5.0000005e-13, ...,
          5.0000005e-13,  5.0000005e-13,  5.0000005e-13]],
       shape=(129, 400), dtype=float32),
 0.62,
 (0, np.int64(80), np.int64(280)),
 {'subject': 'NDARAB793GL3',
  'sex': 'M',
  'age': 13.4391,
  'task': 'contrastChangeDetection',
  'session': '',
  'run': '1'})

#### Testing the whole procedure 

In [160]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import joblib, math, numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from braindecode.datasets.base import EEGWindowsDataset, BaseConcatDataset, BaseDataset
from tqdm.notebook import tqdm
import typing
from braindecode.preprocessing import create_fixed_length_windows
from torch.utils.data import DataLoader, random_split

In [161]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SFREQ = 100
CROP_SEC = 2
WINDOW_SEC = 4
STRIDE_SEC = 2

DATA_PATH = Path("LOL_DATASET/LOL_DATASET/HBN_DATA_FULL/")
RELEASES = ["R1", "R2", "R3"]

TASK_NAMES = [
    "contrastChangeDetection"
    # "RestingState", "DespicableMe", "DiaryOfAWimpyKid", "FunwithFractals",
    # "ThePresent", "contrastChangeDetection", "seqLearning6target",
    # "seqLearning8target", "surroundSupp", "symbolSearch"
]

 2. Meta Encoder (task + sex + age)


In [205]:
class MetaEncoder:
    def __init__(self):
        self.task_enc = LabelEncoder()
        self.sex_enc = LabelEncoder()
        self.age_scaler = StandardScaler()

    def fit(self, metas: typing.List[dict]):
        tasks = [m["task"] for m in metas]
        sexes = [m["sex"] for m in metas]
        ages = [[m["age"]] for m in metas]
        self.task_enc.fit(tasks)
        self.sex_enc.fit(sexes)
        self.age_scaler.fit(ages)
        self.dim = len(self.task_enc.classes_) + len(self.sex_enc.classes_) + 1
        return self

    def transform(self, meta: dict) -> torch.Tensor:
        t = self.task_enc.transform([meta["task"]])[0]
        s = self.sex_enc.transform([meta["sex"]])[0]
        a = self.age_scaler.transform([[meta["age"]]])[0, 0]
        vec = torch.zeros(self.dim, dtype=torch.float32)
        vec[t] = 1.0
        vec[len(self.task_enc.classes_) + s] = 1.0
        vec[-1] = a
        return vec

In [237]:
raw_datasets = []
meta_for_encoder = []

for rel in RELEASES:
    folder = f"{rel}_mini_L100_bdf"
    cache_dir = DATA_PATH / folder
    for task in TASK_NAMES:
        ds = EEGChallengeDataset(
            cache_dir=cache_dir,
            task=task,
            mini=True,
            download=False,
            release=rel,
            description_fields=[
                "subject",
                "session",
                "run",
                "task",
                "age",
                "gender",
                "sex",
                "p_factor",
            ],
        )
        raw_datasets.append(ds)
        for sub_ds in ds.datasets:
            d = sub_ds.description
            if not math.isnan(d.get("externalizing", math.nan)):
                meta_for_encoder.append({
                    "task": d["task"],
                    "sex": d["sex"],
                    "age": float(d["age"]),
                })

In [238]:
# Fit global encoder
meta_encoder = MetaEncoder().fit(meta_for_encoder)
META_DIM = meta_encoder.dim
print(f"Meta embedding dim: {META_DIM}")

Meta embedding dim: 4


In [239]:
class CropMetaWrapper(BaseDataset):
    def __init__(self, windows_ds, 
                        crop_samples, 
                        meta_encoder, 
                        target_name="externalizing"):
        
        self.windows_ds = windows_ds
        self.crop_samples = crop_samples
        self.meta_encoder = meta_encoder
        self.target_name = target_name
        self.rng = np.random.default_rng(2025)  # fixed seed

    def __len__(self):
        return len(self.windows_ds)

    def __getitem__(self, idx):
        X, _, crop_inds = self.windows_ds[idx]  # X: (C, 4*SFREQ)

        # Target
        target = float(self.windows_ds.description[self.target_name])

        # Meta
        desc = self.windows_ds.description
        meta_dict = {
            "task": desc["task"],
            "sex": desc["sex"],
            "age": float(desc["age"]),
        }
        meta_vec = self.meta_encoder.transform(meta_dict)

        # Random 2s crop
        i_win, i_start, i_stop = crop_inds


        assert i_stop - i_start >= self.crop_samples

        # FIXED: .integers instead of .randint
        offset = self.rng.integers(0, i_stop - i_start - self.crop_samples + 1)
        i_start = i_start + offset
        i_stop = i_start + self.crop_samples
        X_crop = X[:, offset : offset + self.crop_samples]  # (C, 2*SFREQ)

        # Infos
        infos = {
            "subject": desc["subject"],
            "session": desc.get("session", ""),
            "run": desc.get("run", ""),
            "task": desc["task"],
            "sex": desc["sex"],
            "age": float(desc["age"]),
        }

        return torch.tensor(X_crop), meta_vec, target, (i_win, i_start, i_stop), infos

In [240]:
preproc_dir = Path("preprocessed_dataset")
preproc_dir.mkdir(parents=True, exist_ok=True)
list_windows = []

for rel in RELEASES:
    rel_raw = [ds for ds in raw_datasets if ds.release == rel]
    filtered = BaseConcatDataset([
        sub_ds for ds in rel_raw for sub_ds in ds.datasets
        if (sub_ds.raw.n_times >= 4 * SFREQ
            and len(sub_ds.raw.ch_names) == 129
            and not math.isnan(sub_ds.description.get("externalizing", math.nan)))
    ])
    windows = create_fixed_length_windows(
        filtered,
        window_size_samples= 4 * SFREQ,
        window_stride_samples= 2 * SFREQ,
        drop_last_window=True,
    )
    windows_ds = BaseConcatDataset(
        [CropMetaWrapper(
            ds, crop_samples=CROP_SEC * SFREQ, meta_encoder=meta_encoder
        ) for ds in windows.datasets
        ]
    )

    list_windows.append(windows_ds)
    joblib.dump(windows_ds, preproc_dir / f"{rel}_windows.pkl")

In [241]:
len(list_windows)

3

In [242]:
wrapped_windows = BaseConcatDataset(list_windows)

In [243]:
len(wrapped_windows)

37381

In [244]:
x, _, _, _, _ = wrapped_windows[5]
x.shape

torch.Size([129, 200])

In [245]:
# === RANDOM SPLIT BY LENGTH ===
total_len = len(wrapped_windows)
train_len = int(0.8 * total_len)
valid_len = int(0.1 * total_len)
test_len  = total_len - train_len - valid_len

train_ds, valid_ds, test_ds = random_split(
    wrapped_windows,
    [train_len, valid_len, test_len],
    generator=torch.Generator().manual_seed(2025)  # reproducible
)

# === DATALOADERS ===
BATCH_SIZE = 16
NUM_WORKERS = 4

def collate_fn(batch):
    X, meta, y, wins, infos = zip(*batch)

    return (
        torch.stack(X),
        torch.stack(meta),
        torch.tensor(y, dtype=torch.float32).unsqueeze(1),
    )

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)

In [246]:
batch = next(iter(train_loader))
batch[0].shape, batch[1].shape, batch[2].shape

(torch.Size([16, 129, 200]), torch.Size([16, 4]), torch.Size([16, 1]))

In [176]:
class FinetuneJEPAWithMeta(nn.Module):
    def __init__(
        self,
        d_model: int = 256,
        n_layer: int = 8,
        patch_size: int = 10,
        meta_dim: int = META_DIM,        # ← from meta_encoder
        dropout: float = 0.1,
    ):
        super().__init__()
        # 1. EEG Backbone (your JEPA-pretrained Mamba)
        self.backbone = EegMambaJEPA(
            d_model=d_model,
            n_layer=n_layer,
            n_channels=129,
            patch_size=patch_size,
        )

        # 2. Meta Projection (meta → d_model)
        self.meta_proj = nn.Sequential(
            nn.Linear(meta_dim, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model),
        )

        # 3. Fusion + Head
        self.fusion_norm = nn.LayerNorm(d_model * 2)
        self.head = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1)
        )

    def forward(self, eeg: torch.Tensor, meta: torch.Tensor):
        """
        eeg:  (B, 129, 200)
        meta: (B, META_DIM)
        """
        # EEG → (B, d_model)
        z_eeg = self.backbone(eeg)
        print(z_eeg.shape)

        # Meta → (B, d_model)
        z_meta = self.meta_proj(meta)
        print(z_meta.shape)

        # Fuse
        z = torch.cat([z_eeg, z_meta], dim=-1)   # (B, 2*d_model)
        z = self.fusion_norm(z)

        # Predict
        out = self.head(z).squeeze(-1)           # (B,)
        return out

In [178]:
egg.shape

torch.Size([16, 129, 0])

In [177]:
model = FinetuneJEPAWithMeta(d_model=256, meta_dim=META_DIM).to(DEVICE)
batch = next(iter(train_loader))
egg, meta, y = batch
model(egg.to(DEVICE), meta.to(DEVICE)).shape

RuntimeError: Calculated padded input size per channel: (0). Kernel size: (10). Kernel size can't be greater than actual input size