In [None]:
from pathlib import Path
from main.dataset.amigos import load_participant_data
from main.dataset.amigos.config import AmigosConfig
import mne.io

from main.core_data.media.eeg import EEGFeatureExtractor

info = mne.create_info(
    ch_names=AmigosConfig.CH_NAMES,
    ch_types=AmigosConfig.CH_TYPES,
    sfreq=AmigosConfig.original_eeg_fs
)
participant_data = load_participant_data(Path("../../../resources/AMIGOS/pre_processed_py/"))
eeg_data = participant_data["P01"]["joined_data"][0]
raw = mne.io.RawArray(eeg_data.T, info=info, verbose=False)

In [None]:
extractor = EEGFeatureExtractor(raw)
starts = extractor.pick_segments(
    4, 0.125, bands=((4, 8), (8, 13), (13, 30)), band_weights=(0.4, 0.5, 0.4)
)

In [None]:
from main.core_data.sampler import EegFeaturesAndRandLogUIntervalsSegmenter

feat_seg = EegFeaturesAndRandLogUIntervalsSegmenter(
    min_length=1,
    max_length=30,
    num_segments=20,
    extraction_jitter=0,
    anchor_identification_hop=0.125
)

In [None]:
from main.core_data.media.eeg import EEG
import numpy as np

r = feat_seg.compute_segments(EEG(data=raw, eid="1", fs=128))
a = np.array(r)

np.diff(a)

In [None]:
a

In [None]:
# Total coverage:
unique_coverage = 0
duration = raw.duration

taken_quarter_seconds = np.array([0 for i in range(int(duration * 4))])
for start, stop in a:
    start = int(start * 4)
    stop = int(stop * 4)
    taken_quarter_seconds[start:stop] += 1

unique_coverage = len(taken_quarter_seconds[taken_quarter_seconds > 0]) / (raw.duration * 4)
print(unique_coverage)  # We use 95% of the original signal

In [None]:
# Coverage of SHORT segments:
unique_coverage = 0
duration = raw.duration

taken_quarter_seconds = np.array([0 for i in range(int(duration * 4))])
for start, stop in a[a[:, 1] - a[:, 0] < 4.1]:
    start = int(start * 4)
    stop = int(stop * 4)
    taken_quarter_seconds[start:stop] += 1

unique_coverage = len(taken_quarter_seconds[taken_quarter_seconds > 0]) / (raw.duration * 4)
print(unique_coverage)  # We use 95% of the original signal

In [1]:
from braindecode.models import Labram

model = Labram(input_window_seconds=15, sfreq=200, n_chans=64, n_outputs=1)

In [2]:
import torch
from collections import OrderedDict

ckpt = torch.load("../../../dependencies/labram-base.pth", map_location="cpu", weights_only=False)

# 1) Extract the actual state dict
sd = ckpt.get("state_dict", ckpt.get("model", ckpt))
assert isinstance(sd, (dict, OrderedDict)), "No model/state_dict found in checkpoint."

# 2) Optional: normalize key names (extend as needed)
ALIASES = {
    "pos_embed": "position_embedding",
    "time_embed": "temporal_embedding",
    "head.weight": "final_layer.weight",
    "head.bias": "final_layer.bias",
    "norm.weight": "fc_norm.weight",
    "norm.bias": "fc_norm.bias",
}


def normalize(k: str) -> str:
    k = k.removeprefix("module.").removeprefix("backbone.").removeprefix("encoder.")
    # simple alias swaps
    for a, b in ALIASES.items():
        if k == a or k.endswith("." + a):
            k = k[: -len(a)] + b if k.endswith("." + a) else b
    return k


# 3) Per-layer copy: only load keys that exist and match shape
model_sd = model.state_dict()
to_load = {}
loaded, skipped, shape_mismatch = [], [], []

for k, w in sd.items():
    nk = normalize(k)
    if nk in model_sd:
        if model_sd[nk].shape == w.shape:
            to_load[nk] = w
            loaded.append(nk)
        else:
            shape_mismatch.append((nk, tuple(w.shape), tuple(model_sd[nk].shape)))
    else:
        skipped.append(k)

# 4) Load the matching subset (strict=False), report the rest
incompat = model.load_state_dict(to_load, strict=False)
print(f"Loaded {len(loaded)} layers")
if incompat.missing_keys: print("Missing in model:", incompat.missing_keys[:10], "...")
if incompat.unexpected_keys: print("Unexpected in ckpt (after mapping):", incompat.unexpected_keys[:10], "...")
if shape_mismatch: print("Shape mismatches:", shape_mismatch[:10], "...")


Loaded 0 layers
Missing in model: ['cls_token', 'position_embedding', 'temporal_embedding', 'patch_embed.segment_patch.patcher.weight', 'patch_embed.segment_patch.patcher.bias', 'patch_embed.temporal_conv.conv1.weight', 'patch_embed.temporal_conv.conv1.bias', 'patch_embed.temporal_conv.norm1.weight', 'patch_embed.temporal_conv.norm1.bias', 'patch_embed.temporal_conv.conv2.weight'] ...


In [3]:
# map checkpoint_key -> model_key
model.mapping = {
    "pos_embed": "position_embedding",
    "time_embed": "temporal_embedding",
    "head.weight": "final_layer.weight",
    "head.bias": "final_layer.bias",
    "norm.weight": "fc_norm.weight",
    "norm.bias": "fc_norm.bias",
    # add more as needed
}
model.load_state_dict(sd, strict=False)

_IncompatibleKeys(missing_keys=['cls_token', 'position_embedding', 'temporal_embedding', 'patch_embed.segment_patch.patcher.weight', 'patch_embed.segment_patch.patcher.bias', 'patch_embed.temporal_conv.conv1.weight', 'patch_embed.temporal_conv.conv1.bias', 'patch_embed.temporal_conv.norm1.weight', 'patch_embed.temporal_conv.norm1.bias', 'patch_embed.temporal_conv.conv2.weight', 'patch_embed.temporal_conv.conv2.bias', 'patch_embed.temporal_conv.norm2.weight', 'patch_embed.temporal_conv.norm2.bias', 'patch_embed.temporal_conv.conv3.weight', 'patch_embed.temporal_conv.conv3.bias', 'patch_embed.temporal_conv.norm3.weight', 'patch_embed.temporal_conv.norm3.bias', 'blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.attn.qkv.weight', 'blocks.0.attn.proj.weight', 'blocks.0.attn.proj.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.0.mlp.0.weight', 'blocks.0.mlp.0.bias', 'blocks.0.mlp.2.weight', 'blocks.0.mlp.2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1

In [10]:
model.to('cuda')
o = model(torch.randn(2, 64, 400).to('cuda'), return_all_tokens=True)

RuntimeError: The size of tensor a (129) must match the size of tensor b (961) at non-singleton dimension 1

In [8]:
o.shape

torch.Size([2, 961, 200])

In [6]:
from torch import nn

model.final_layer = nn.Identity()

In [None]:
model