In [1]:
from pathlib import Path
from collections import defaultdict

from torch.utils.data import DataLoader

import data

In [2]:
from importlib import reload

# Re-import the local `data` module so edits made to its source on disk are
# picked up without restarting the kernel. Because the call is the last
# expression in the cell, its return value (the module object) is displayed.
reload(data)

<module 'data' from '/home/connor/algonauts2025.clean/src/data.py'>

In [3]:
# Project root: the parent of the current working directory, made absolute.
root_dir = Path("..").resolve()

# Dataset locations, all relative to the project root.
data_dir = root_dir.joinpath("datasets/algonauts_2025.competitors")
sharded_feature_root = root_dir.joinpath("datasets/features.sharded")
merged_feature_root = root_dir.joinpath("datasets/features.merged")

In [4]:
# Load the fMRI recordings for the two stimulus sets from `data_dir`.
# Both results are used as mappings later in the notebook: they are passed to
# len(), indexed by episode key, iterated for their keys, and merged with `**`.
friends_fmri = data.load_algonauts2025_friends_fmri(data_dir)
movie10_fmri = data.load_algonauts2025_movie10_fmri(data_dir)

In [5]:
# Sanity-check that the expected number of entries was loaded for each stimulus
# set. The messages report the actual count so a failure immediately shows how
# far off the loaded data is, instead of a bare AssertionError.
assert len(friends_fmri) == 287, (
    f"expected 287 friends fMRI entries, got {len(friends_fmri)}"
)
assert len(movie10_fmri) == 61, (
    f"expected 61 movie10 fMRI entries, got {len(movie10_fmri)}"
)

In [6]:
# Inspect one friends entry to see the array layout and dtype.
# The print label calls the axes "NTC"; presumably N = subjects, T = time
# points, C = channels/parcels -- TODO confirm against data.py.
sample = friends_fmri["s01e05b"]
print("Sample shape (NTC):", sample.shape, sample.dtype)

Sample shape (NTC): (4, 468, 1000) float32


In [7]:
# Maps "<model>/<layer>" -> {episode key: feature array}. Using defaultdict(dict)
# lets the loading cells call .update() on a layer's entry without creating it first.
stimuli_features = defaultdict(dict)

In [8]:
# (model, layer) pairs whose features are stored in the sharded feature tree.
sharded_models_layers = [
    ("whisper", "layers.12.fc2"),
    ("internvl3_8b_8bit", "language_model.model.layers.20.post_attention_layernorm"),
]

for model, layer in sharded_models_layers:
    # One shared per-layer dict collects the episodes of both series;
    # friends is loaded first, then movie10.
    layer_features = stimuli_features[f"{model}/{layer}"]
    for series in ("friends", "movie10"):
        layer_features.update(
            data.load_sharded_features(
                sharded_feature_root,
                model=model,
                layer=layer,
                series=series,
            )
        )

In [9]:
# (model, layer, stem) triples whose features are stored in the merged feature tree.
merged_models_layers = [
    ("meta-llama__Llama-3.2-1B", "model.layers.7", "context-long"),
]

for model, layer, stem in merged_models_layers:
    # One shared per-layer dict collects the episodes of both series;
    # friends is loaded first, then movie10.
    layer_features = stimuli_features[f"{model}/{layer}"]
    for series in ("friends", "movie10"):
        layer_features.update(
            data.load_merged_features(
                root=merged_feature_root,
                model=model,
                layer=layer,
                series=series,
                stem=stem,
            )
        )

In [10]:
# Every fMRI key, friends first and movie10 second.
all_fmri_episodes = [*friends_fmri, *movie10_fmri]

# Same keys with the run component dropped: tuple keys such as
# ('figures01', 2) are reduced to their first element, other keys pass through.
all_fmri_episodes_no_runs = []
for episode_key in all_fmri_episodes:
    if isinstance(episode_key, tuple):
        all_fmri_episodes_no_runs.append(episode_key[0])
    else:
        all_fmri_episodes_no_runs.append(episode_key)

In [11]:
# Check that every feature layer covers exactly the same set of episodes, then
# keep the episode list of the first layer (in insertion order) as the reference.
episodes_by_layer = {layer: set(feats) for layer, feats in stimuli_features.items()}
assert episodes_by_layer, (
    "stimuli_features is empty; run the feature-loading cells first"
)

reference_layer = next(iter(episodes_by_layer))
for layer, episodes in episodes_by_layer.items():
    # Symmetric difference: episodes present in one layer but not the other.
    mismatch = episodes ^ episodes_by_layer[reference_layer]
    assert not mismatch, (
        f"{layer} and {reference_layer} disagree on {len(mismatch)} episode(s), "
        f"e.g. {sorted(mismatch, key=str)[:10]}"
    )

# Bound once, as a flat list of episode keys in the reference layer's order.
all_feat_episodes = list(stimuli_features[reference_layer])

In [12]:
# Report every episode whose feature arrays disagree on the size of their
# first axis across layers.
for episode in all_feat_episodes:
    shapes = {
        name: per_episode[episode].shape
        for name, per_episode in stimuli_features.items()
    }
    first_axis_sizes = {shape[0] for shape in shapes.values()}
    if len(first_axis_sizes) > 1:
        print(episode, shapes)

s01e03b {'whisper/layers.12.fc2': (472, 1280), 'internvl3_8b_8bit/language_model.model.layers.20.post_attention_layernorm': (472, 3584), 'meta-llama__Llama-3.2-1B/model.layers.7': (471, 2048)}
s02e20a {'whisper/layers.12.fc2': (449, 1280), 'internvl3_8b_8bit/language_model.model.layers.20.post_attention_layernorm': (449, 3584), 'meta-llama__Llama-3.2-1B/model.layers.7': (448, 2048)}
s03e03a {'whisper/layers.12.fc2': (453, 1280), 'internvl3_8b_8bit/language_model.model.layers.20.post_attention_layernorm': (483, 3584), 'meta-llama__Llama-3.2-1B/model.layers.7': (453, 2048)}
s03e03b {'whisper/layers.12.fc2': (453, 1280), 'internvl3_8b_8bit/language_model.model.layers.20.post_attention_layernorm': (466, 3584), 'meta-llama__Llama-3.2-1B/model.layers.7': (453, 2048)}
s03e04a {'whisper/layers.12.fc2': (473, 1280), 'internvl3_8b_8bit/language_model.model.layers.20.post_attention_layernorm': (450, 3584), 'meta-llama__Llama-3.2-1B/model.layers.7': (473, 2048)}
s03e04b {'whisper/layers.12.fc2': (

In [13]:
# Compare episode coverage between the features and the fMRI data.
feat_episode_set = set(all_feat_episodes)
fmri_episode_set = set(all_fmri_episodes_no_runs)

# First line: episodes with features but no fMRI; second line: the reverse.
print(sorted(feat_episode_set.difference(fmri_episode_set)))
print(sorted(fmri_episode_set.difference(feat_episode_set)))

['s04e01a', 's04e01b', 's04e13b', 's05e20a', 's06e03a', 's07e01a', 's07e01b', 's07e02a', 's07e02b', 's07e03a', 's07e03b', 's07e04a', 's07e04b', 's07e05a', 's07e05b', 's07e06a', 's07e06b', 's07e07a', 's07e07b', 's07e08a', 's07e08b', 's07e09a', 's07e09b', 's07e10a', 's07e10b', 's07e11a', 's07e11b', 's07e12a', 's07e12b', 's07e13a', 's07e13b', 's07e14a', 's07e14b', 's07e15a', 's07e15b', 's07e16a', 's07e16b', 's07e16c', 's07e17a', 's07e17b', 's07e18a', 's07e18b', 's07e19a', 's07e19b', 's07e20a', 's07e20b', 's07e21a', 's07e21b', 's07e22a', 's07e22b', 's07e23a', 's07e23b', 's07e23c', 's07e23d']
[]


In [14]:
# List the fMRI keys that are (name, run) tuples with run == 2.
for ep in all_fmri_episodes:
    if not isinstance(ep, tuple):
        continue  # plain string keys carry no run index
    if ep[1] != 2:
        continue
    print(ep)

('figures01', 2)
('figures02', 2)
('figures03', 2)
('figures04', 2)
('figures05', 2)
('figures06', 2)
('figures07', 2)
('figures08', 2)
('figures09', 2)
('figures10', 2)
('figures11', 2)
('figures12', 2)
('life01', 2)
('life02', 2)
('life03', 2)
('life04', 2)
('life05', 2)


In [15]:
# Predicate selecting the training split; range(6) covers season numbers 0-5,
# and the two listed movies are included alongside them.
keep_for_training = data.episode_filter(seasons=range(6), movies=["bourne", "wolf"])

# Apply the predicate to every fMRI key, preserving the original order.
train_episode_list = [ep for ep in all_fmri_episodes if keep_for_training(ep)]

In [16]:
# Show the selected training episodes, one key per line.
for train_episode in train_episode_list:
    print(train_episode)

s01e01a
s01e01b
s01e02a
s01e02b
s01e03a
s01e03b
s01e04a
s01e04b
s01e05a
s01e05b
s01e06a
s01e06b
s01e07a
s01e07b
s01e08a
s01e08b
s01e09a
s01e09b
s01e10a
s01e10b
s01e11a
s01e11b
s01e12a
s01e12b
s01e13a
s01e13b
s01e14a
s01e14b
s01e15a
s01e15b
s01e16a
s01e16b
s01e17a
s01e17b
s01e18a
s01e18b
s01e19a
s01e19b
s01e20a
s01e20b
s01e21a
s01e21b
s01e22a
s01e22b
s01e23a
s01e23b
s01e24a
s01e24b
s02e01a
s02e01b
s02e02a
s02e02b
s02e03a
s02e03b
s02e04a
s02e04b
s02e05a
s02e05b
s02e06a
s02e06b
s02e07a
s02e07b
s02e08a
s02e08b
s02e09a
s02e09b
s02e10a
s02e10b
s02e11a
s02e11b
s02e12a
s02e12b
s02e13a
s02e13b
s02e14a
s02e14b
s02e15a
s02e15b
s02e16a
s02e16b
s02e17a
s02e17b
s02e18a
s02e18b
s02e19a
s02e19b
s02e20a
s02e20b
s02e21a
s02e21b
s02e22a
s02e22b
s02e23a
s02e23b
s02e24a
s02e24b
s03e01a
s03e01b
s03e02a
s03e02b
s03e03a
s03e03b
s03e04a
s03e04b
s03e05a
s03e05b
s03e06a
s03e06b
s03e07a
s03e07b
s03e08a
s03e08b
s03e09a
s03e09b
s03e10a
s03e10b
s03e11a
s03e11b
s03e12a
s03e12b
s03e13a
s03e13b
s03e14a
s03e14b
s03e15a


In [17]:
# Single lookup table over both stimulus sets. movie10 entries are applied
# last, so they would win if a key ever appeared in both mappings.
merged_fmri = dict(friends_fmri)
merged_fmri.update(movie10_fmri)

In [18]:
# Build the training dataset from the selected episodes, the merged fMRI
# lookup, and one feature mapping per layer.
train_dataset = data.Algonauts2025Dataset(
    episode_list=train_episode_list,
    fmri_data=merged_fmri,
    # One {episode: features} mapping per layer, in stimuli_features insertion order.
    feat_data=list(stimuli_features.values()),
    # Matches the length-64 axis of the fmri/feature tensors printed by the
    # loader loop -- presumably time points per sample; TODO confirm in data.py.
    sample_length=64,
    # NOTE(review): presumably the number of samples drawn per pass -- confirm.
    num_samples=400,
    shuffle=True,
    seed=42,
)

In [19]:
# Batch the dataset four samples at a time. No shuffle argument is passed to
# the DataLoader; the dataset itself was constructed with shuffle=True, seed=42.
train_loader = DataLoader(train_dataset, batch_size=4)

In [20]:
# Walk the whole loader and print, per batch, the episode keys plus the
# shapes of the fMRI tensor and of each feature tensor.
for batch in train_loader:
    print("episode:", batch["episode"])
    print("fmri:", batch["fmri"].shape)
    feature_shapes = [tensor.shape for tensor in batch["features"]]
    print("features:", feature_shapes)

episode: ['s01e13a', 'wolf17', 's02e19a', 's02e24b']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s01e20a', 's01e24a', 's03e22a', 's01e07a']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['wolf11', 's03e03a', 's05e01a', 's04e20a']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s05e17b', 's05e18a', 's03e25a', 's05e03b']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s02e23a', 's04e05a', 'wolf08', 's01e02b']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s05e18b', 's03e02a', 's03e25b', 's04e22b']
fmri: torch.Size([4, 4, 64, 1000])
features:

episode: ['s01e21b', 's03e08a', 's01e12a', 's01e10b']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s04e10b', 's03e09b', 's05e23d', 's01e08a']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s03e23b', 's05e02b', 's01e19a', 's03e01b']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s03e21a', 's04e02b', 's01e20b', 's01e06b']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s05e03a', 's05e17b', 's02e04b', 'bourne07']
fmri: torch.Size([4, 4, 64, 1000])
features: [torch.Size([4, 64, 1280]), torch.Size([4, 64, 3584]), torch.Size([4, 64, 2048])]
episode: ['s05e19a', 'bourne08', 'wolf10', 's01e15b']
fmri: torch.Size([4, 4, 64, 1000])
featu