In [1]:
import re
import time
from collections import defaultdict
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm

In [2]:
# Subject ids used throughout the notebook: the default subject list for the fMRI
# loaders, and the order of the subject axis when validate() reports per-subject
# metrics. Subject 4 is not included.
SUBJECTS = (1, 2, 3, 5)

In [3]:
# Paths are resolved relative to the notebook's working directory.
root_dir = Path("..").resolve()

# Competition data root; the fMRI loaders read fmri/sub-XX/func/*.h5 under it.
data_dir = root_dir / "algonauts_2025.competitors"

# NOTE(review): out_dir is created here, but none of the cells shown write to it.
out_dir = Path(".") / "output/feature_encoding_v1"
out_dir.mkdir(exist_ok=True, parents=True)
print("Saving output to:", out_dir.resolve())

Saving output to: /home/connor/algonauts2025/feature_encoding/output/feature_encoding_v1


## Data loading

Aligned cross-subject fMRI data loader. The loader samples clips of synchronized activity from the same Friends episodes (or Movie10 runs) across subjects. Each fMRI clip has shape `(n_subs, sample_length, dim)`.

We also load pre-extracted stimulus features; each feature clip has shape `(sample_length, dim)`.

In [4]:
def parse_friends_run(run: str):
    """Split a Friends run name such as ``"s01e02a"`` into ``(season, episode, part)``.

    Only the start of ``run`` must match ``s<digits>e<digits><letter>``; any
    trailing text is ignored.

    Raises:
        ValueError: if ``run`` does not start with that pattern.
    """
    found = re.match(r"s([0-9]+)e([0-9]+)([a-z])", run)
    if not found:
        raise ValueError(f"Invalid friends run {run}")
    season_str, episode_str, part = found.groups()
    return int(season_str), int(episode_str), part

In [5]:
def load_algonauts2025_friends_fmri(
    root: str | Path,
    subjects: list[int] | None = None,
    seasons: list[int] | None = None,
) -> dict[str, np.ndarray]:
    """Load Friends fMRI runs that are available for every requested subject.

    Args:
        root: Competition data root containing ``fmri/sub-XX/func/*.h5``.
        subjects: Subject ids to load; defaults to ``SUBJECTS``.
        seasons: Friends seasons to keep; defaults to seasons 1-6.

    Returns:
        Mapping from run name (e.g. ``"s01e01a"``) to an array that stacks all
        subjects along axis 0, with each subject truncated to the shortest
        subject's length. Runs missing for any subject are dropped.
    """
    subjects = subjects or SUBJECTS
    seasons = seasons or list(range(1, 7))
    seasons_set = set(seasons)

    files = {}
    try:
        for sub in subjects:
            files[sub] = h5py.File(
                Path(root)
                / f"fmri/sub-{sub:02d}/func"
                / f"sub-{sub:02d}_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5"
            )

        # run name -> {subject: dataset key in that subject's file}
        episode_key_maps = defaultdict(dict)
        for sub, file in files.items():
            for key in file.keys():
                # Keys are BIDS-like "name-value" entities joined by "_".
                entities = dict(ent.split("-", 1) for ent in key.split("_"))
                episode = entities["task"]
                season, _, _ = parse_friends_run(episode)
                if season in seasons_set:
                    episode_key_maps[episode][sub] = key

        episode_list = sorted(
            episode
            for episode, sub_keys in episode_key_maps.items()
            if len(sub_keys) == len(subjects)
        )

        data = {}
        for episode in episode_list:
            # Read fully into memory so the files can be closed afterwards.
            samples = [files[sub][episode_key_maps[episode][sub]][:] for sub in subjects]
            length = min(len(sample) for sample in samples)
            data[episode] = np.stack([sample[:length] for sample in samples])
    finally:
        # The data is already copied into numpy arrays; release the HDF5 handles.
        for file in files.values():
            file.close()

    return data

In [6]:
def parse_movie10_run(run: str):
    """Split a Movie10 task name such as ``"bourne01"`` into ``(movie, part)``.

    Only the start of ``run`` must match lowercase letters followed by digits.

    Raises:
        ValueError: if ``run`` does not start with that pattern.
    """
    found = re.match(r"([a-z]+)([0-9]+)", run)
    if not found:
        raise ValueError(f"Invalid movie run {run}")
    movie, part_str = found.groups()
    return movie, int(part_str)

In [7]:
def load_algonauts2025_movie10_fmri(
    root: str | Path,
    subjects: list[int] | None = None,
    movies: list[str] | None = None,
    runs: list[int] | None = None,
) -> dict[tuple[str, int], np.ndarray]:
    """Load Movie10 fMRI segments that are available for every requested subject.

    Args:
        root: Competition data root containing ``fmri/sub-XX/func/*.h5``.
        subjects: Subject ids to load; defaults to ``SUBJECTS``.
        movies: Movie names to keep; defaults to bourne, wolf, figures, life.
        runs: Run numbers to keep; defaults to ``[1, 2]``. Keys without a
            ``run`` entity are treated as run 1.

    Returns:
        Mapping from ``(task, run)`` (e.g. ``("bourne01", 1)``) to an array that
        stacks all subjects along axis 0, with each subject truncated to the
        shortest subject's length. Segments missing for any subject are dropped.
    """
    subjects = subjects or SUBJECTS
    movies = movies or ["bourne", "wolf", "figures", "life"]
    runs = runs or [1, 2]
    movies_set = set(movies)

    files = {}
    try:
        for sub in subjects:
            files[sub] = h5py.File(
                Path(root)
                / f"fmri/sub-{sub:02d}/func"
                / f"sub-{sub:02d}_task-movie10_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_bold.h5"
            )

        # (task, run) -> {subject: dataset key in that subject's file}
        episode_key_maps = defaultdict(dict)
        for sub, file in files.items():
            for key in file.keys():
                # Keys are BIDS-like "name-value" entities joined by "_".
                entities = dict(ent.split("-", 1) for ent in key.split("_"))
                episode = entities["task"]
                run = int(entities.get("run", 1))
                movie, _ = parse_movie10_run(episode)
                if movie in movies_set and run in runs:
                    episode_key_maps[(episode, run)][sub] = key

        episode_list = sorted(
            episode
            for episode, sub_keys in episode_key_maps.items()
            if len(sub_keys) == len(subjects)
        )

        data = {}
        for episode in episode_list:
            # Read fully into memory so the files can be closed afterwards.
            samples = [files[sub][episode_key_maps[episode][sub]][:] for sub in subjects]
            length = min(len(sample) for sample in samples)
            data[episode] = np.stack([sample[:length] for sample in samples])
    finally:
        # The data is already copied into numpy arrays; release the HDF5 handles.
        for file in files.values():
            file.close()

    return data

In [8]:
# Splits: train on Friends seasons 1-5, validate on held-out season 6, and test
# out-of-distribution on Movie10 (run 1 only).
friends_train_fmri = load_algonauts2025_friends_fmri(data_dir, seasons=range(1, 6))
friends_val_fmri = load_algonauts2025_friends_fmri(data_dir, seasons=[6])
movie10_test_fmri = load_algonauts2025_movie10_fmri(data_dir, runs=[1])

In [9]:
# List the run keys loaded per split (Movie10 keys are (task, run) tuples).
print(friends_train_fmri.keys())
print(friends_val_fmri.keys())
print(movie10_test_fmri.keys())

dict_keys(['s01e01a', 's01e01b', 's01e02a', 's01e02b', 's01e03a', 's01e03b', 's01e04a', 's01e04b', 's01e05a', 's01e05b', 's01e06a', 's01e06b', 's01e07a', 's01e07b', 's01e08a', 's01e08b', 's01e09a', 's01e09b', 's01e10a', 's01e10b', 's01e11a', 's01e11b', 's01e12a', 's01e12b', 's01e13a', 's01e13b', 's01e14a', 's01e14b', 's01e15a', 's01e15b', 's01e16a', 's01e16b', 's01e17a', 's01e17b', 's01e18a', 's01e18b', 's01e19a', 's01e19b', 's01e20a', 's01e20b', 's01e21a', 's01e21b', 's01e22a', 's01e22b', 's01e23a', 's01e23b', 's01e24a', 's01e24b', 's02e01a', 's02e01b', 's02e02a', 's02e02b', 's02e03a', 's02e03b', 's02e04a', 's02e04b', 's02e05a', 's02e05b', 's02e06a', 's02e06b', 's02e07a', 's02e07b', 's02e08a', 's02e08b', 's02e09a', 's02e09b', 's02e10a', 's02e10b', 's02e11a', 's02e11b', 's02e12a', 's02e12b', 's02e13a', 's02e13b', 's02e14a', 's02e14b', 's02e15a', 's02e15b', 's02e16a', 's02e16b', 's02e17a', 's02e17b', 's02e18a', 's02e18b', 's02e19a', 's02e19b', 's02e20a', 's02e20b', 's02e21a', 's02e21b',

In [10]:
# Each run is stacked by the loader as (n_subjects, n_timepoints, n_parcels).
sample = friends_train_fmri["s01e05b"]
print("Sample shape (NTC):", sample.shape, sample.dtype)

Sample shape (NTC): (4, 468, 1000) float32


In [11]:
class Algonauts2025Dataset(IterableDataset):
    def __init__(
        self,
        fmri_data: dict[str, np.ndarray],
        feat_data: list[dict[str, np.ndarray]] | None = None,
        sample_length: int | None = 128,
        num_samples: int | None = None,
        shuffle: bool = True,
        seed: int | None = None,
    ):
        self.fmri_data = fmri_data
        self.feat_data = feat_data

        self.episode_list = list(fmri_data)
        self.sample_length = sample_length
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.seed = seed

        self._rng = np.random.default_rng(seed)
    
    def _iter_shuffle(self):
        sample_idx = 0
        while True:
            episode_order = self._rng.permutation(len(self.episode_list))

            for ii in episode_order:
                episode = self.episode_list[ii]
                feat_episode = episode[0] if isinstance(episode, tuple) else episode

                fmri = torch.from_numpy(self.fmri_data[episode]).float()
    
                if self.feat_data:
                    feats = [torch.from_numpy(data[feat_episode]).float() for data in self.feat_data]
                else:
                    feats = feat_samples = None

                # Nb, fmri and feature length often off by 1 or 2.
                # But assuming time locked to start.
                length = fmri.shape[1]
                if feats:
                    length = min(length, min(feat.shape[0] for feat in feats))

                if self.sample_length:
                    # Random segment of run
                    offset = int(self._rng.integers(0, length - self.sample_length + 1))
                    fmri_sample = fmri[:, offset: offset + self.sample_length]
                    if feats:
                        feat_samples = [
                            feat[offset: offset + self.sample_length] for feat in feats
                        ]
                else:
                    # Take full run
                    # Nb this only works for batch size 1 since runs are different length
                    fmri_sample = fmri[:, :length]
                    if feats:
                        feat_samples = [feat[:length] for feat in feats]

                if feat_samples:
                    yield episode, fmri_sample, feat_samples
                else:
                    yield episode, fmri_sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return

    def _iter_ordered(self):
        sample_idx = 0
        for episode in self.episode_list:
            feat_episode = episode[0] if isinstance(episode, tuple) else episode
            fmri = torch.from_numpy(self.fmri_data[episode]).float()
            if self.feat_data:
                feats = [torch.from_numpy(data[feat_episode]).float() for data in self.feat_data]
            else:
                feats = feat_samples = None

            length = fmri.shape[1]
            if feats:
                length = min(length, min(feat.shape[0] for feat in feats))

            sample_length = self.sample_length or length

            for offset in range(0, length - sample_length + 1, sample_length):
                fmri_sample = fmri[:, offset: offset + sample_length]
                if feats:
                    feat_samples = [feat[offset: offset + sample_length] for feat in feats]

                if feat_samples:
                    yield episode, fmri_sample, feat_samples
                else:
                    yield episode, fmri_sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return

    def __iter__(self):
        if self.shuffle:
            yield from self._iter_shuffle()
        else:
            yield from self._iter_ordered()

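## Features

Pre-extracted stimulus features available per model (layer names and feature dimension). The cells below load only a subset of these layers.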

- whisper
  - layers:
    - layers.12.fc2
    - layers.25.fc2
    - layers.31.fc2
    - layer_norm
  - dim: 1280
- internvl3_8b_8bit
  - layers:
    - language_model.model.layers.10.post_attention_layernorm
    - language_model.model.layers.15.post_attention_layernorm
    - language_model.model.layers.20.post_attention_layernorm
    - language_model.model.norm
  - dim: 3584
- Llama-3.2-1B
  - layers:
    - model.layers.7
    - model.layers.11
    - model.layers.15
  - dim: 2048

In [12]:
def load_medarc_features(
    root: str | Path,
    model: str,
    layer: str,
    series: str = "friends"
) -> dict[str, np.ndarray]:
    paths = sorted((Path(root) / model / series).rglob("*.h5"))

    features = {}
    for path in paths:
        episode = path.stem.split("_")[-1]  # friends_s01e01a, bourne01
        with h5py.File(path) as f:
            features[episode] = f[layer][:].squeeze()
    return features

In [13]:
def load_merged_features(
    path: str | Path,
    layer: str,
) -> dict[str, np.ndarray]:
    """Read ``layer`` from every top-level group of one merged HDF5 file.

    Returns a mapping from top-level key (run name) to that group's ``layer`` array.
    """
    features = {}
    with h5py.File(path) as h5:
        for run_name in h5:
            features[run_name] = h5[run_name][layer][:]
    return features

In [14]:
# Two feature stores: per-run MedARC files and merged per-model HDF5 files.
medarc_feature_root = root_dir / "features.medarc"
merged_feature_root = root_dir / "features.merged"

# "model/layer" -> {run name: feature array}. Insertion order is the order in
# which feature streams are passed to the datasets and hence to the model.
stimuli_features_friends = {}
stimuli_features_movie10 = {}

In [16]:
# MedARC features: one HDF5 file per run under <root>/<model>/<series>/, read one
# layer at a time.
# NOTE(review): the loop variable `model` is later rebound to the torch model.
medarc_models_layers = [
    ("whisper", "layers.12.fc2"),
    ("whisper", "layers.31.fc2"),
    ("internvl3_8b_8bit", "language_model.model.layers.10.post_attention_layernorm"),
    ("internvl3_8b_8bit", "language_model.model.layers.20.post_attention_layernorm"),
]

for model, layer in medarc_models_layers:
    stimuli_features_friends[f"{model}/{layer}"] = load_medarc_features(
        medarc_feature_root, model=model, layer=layer, series="friends",
    )
    stimuli_features_movie10[f"{model}/{layer}"] = load_medarc_features(
        medarc_feature_root, model=model, layer=layer, series="movie10",
    )

In [17]:
# Merged features: a single HDF5 file per model and series, with one group per
# run holding the layer arrays.
merged_models_layers = [
    ("Llama-3.2-1B", "model.layers.7"),
    ("Llama-3.2-1B", "model.layers.15"),
]

for model, layer in merged_models_layers:
    # TODO: this path is awkward
    stimuli_features_friends[f"{model}/{layer}"] = load_merged_features(
        path=merged_feature_root / f"friends/meta-llama__{model}/context-long.h5",
        layer=layer,
    )
    stimuli_features_movie10[f"{model}/{layer}"] = load_merged_features(
        path=merged_feature_root / f"movie10/meta-llama__{model}/context-long.h5",
        layer=layer,
    )

In [18]:
# Loader smoke test / benchmark: random 64-step Movie10 clips with all feature streams.
dataset = Algonauts2025Dataset(
    movie10_test_fmri,
    list(stimuli_features_movie10.values()),
    sample_length=64,
    num_samples=10000,
    shuffle=True,
    seed=42,
)

In [19]:
# Iterate the whole stream and report loader throughput.
# NOTE(review): only fMRI bytes are counted (4 bytes per float32 element); the
# feature tensors (larger per time step) are not included in MB/s.
total_bytes = 0
tic = time.monotonic()
for task, fmri_sample, feat_samples in tqdm(dataset):
    total_bytes += fmri_sample.numel() * 4
rt = time.monotonic() - tic
tput = total_bytes / 1024 ** 2 / rt 
print(f"run time={rt:.3f}s, MB/s={tput:.0f}")

187it [00:00, 1866.96it/s]

10000it [00:02, 4038.60it/s]

run time=2.480s, MB/s=3938





## Model


In [20]:
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

In [21]:
class CausalConv1d(nn.Conv1d):
    """Conv1d layer with a causal mask, to only "attend" to past time points."""
    attn_mask: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: str | int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        assert kernel_size % 2 == 1, "causal conv requires odd kernel size"
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        attn_mask = torch.zeros(kernel_size)
        attn_mask[:kernel_size // 2 + 1] = 1.0
        self.weight.data.mul_(attn_mask)
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = self.weight * self.attn_mask
        return F.conv1d(
            input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

In [22]:
class ConvLinear(nn.Module):
    """Depthwise temporal convolution followed by a linear projection.

    Channels-last in and out: ``(N, L, in_features) -> (N, L, out_features)``.
    The convolution uses ``padding="same"``, so the sequence length is preserved.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        conv_cls = nn.Conv1d if not causal else CausalConv1d
        # groups == channels: one temporal filter per input channel.
        self.conv = conv_cls(
            in_features,
            in_features,
            kernel_size=kernel_size,
            padding="same",
            groups=in_features,
        )
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # Conv1d expects channels-first (N, C, L); swap around the convolution.
        filtered = self.conv(x.transpose(-1, -2)).transpose(-1, -2)
        return self.fc(filtered)


class LinearConv(nn.Module):
    """Linear projection followed by a depthwise temporal convolution.

    Channels-last in and out: ``(N, L, in_features) -> (N, L, out_features)``.
    The convolution uses ``padding="same"``, so the sequence length is preserved.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        conv_cls = nn.Conv1d if not causal else CausalConv1d
        self.fc = nn.Linear(in_features, out_features)
        # groups == channels: one temporal filter per projected channel.
        self.conv = conv_cls(
            out_features,
            out_features,
            kernel_size=kernel_size,
            padding="same",
            groups=out_features,
        )

    def forward(self, x: torch.Tensor):
        projected = self.fc(x)
        # Conv1d expects channels-first (N, C, L); swap around the convolution.
        return self.conv(projected.transpose(-1, -2)).transpose(-1, -2)

In [23]:
# Shape check: a causal ConvLinear maps channels-last (N, L, 1000) -> (N, L, 256).
encoder = ConvLinear(
    in_features=1000,
    out_features=256,
    causal=True
)
print(encoder)

# (N, L, C)
x = torch.randn(16, 64, 1000)
embed = encoder.forward(x)
print(embed.shape)

ConvLinear(
  (conv): CausalConv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
  (fc): Linear(in_features=1000, out_features=256, bias=True)
)
torch.Size([16, 64, 256])


In [24]:
class FeatEmbed(nn.Module):
    """Normalize one stimulus feature stream and project it to the embedding width.

    ``kernel_size > 1`` adds a depthwise temporal convolution after the projection
    (``LinearConv``); otherwise the embedding is a plain ``nn.Linear``.
    ``normalize=False`` skips the LayerNorm.
    """

    def __init__(
        self,
        feat_dim: int = 2048,
        embed_dim: int = 256,
        kernel_size: int = 33,
        causal: bool = True,
        normalize: bool = True,
    ):
        super().__init__()
        self.norm = nn.Identity() if not normalize else nn.LayerNorm(feat_dim)
        if kernel_size <= 1:
            self.embed = nn.Linear(feat_dim, embed_dim)
        else:
            self.embed = LinearConv(
                feat_dim, embed_dim, kernel_size=kernel_size, causal=causal
            )

    def forward(self, input: torch.Tensor):
        normed = self.norm(input)
        return self.embed(normed)

In [25]:
class MultiSubjectConvLinearEncoderV3(nn.Module):
    """Predict per-subject fMRI from one or more stimulus feature streams.

    Each feature stream has its own ``FeatEmbed``. The embeddings are summed, then
    decoded by a shared linear head and one head per subject, and the two
    outputs are added.

    Changes from earlier versions:
    - Added support for multiple features.

    Args:
        num_subjects: Number of per-subject decoder heads (size of output axis 1).
        feat_dims: Feature dimensionality of each input stream, in input order.
        embed_dim: Shared embedding width.
        target_dim: Output width per subject.
        encoder_kernel_size: Temporal kernel of each ``FeatEmbed`` (<= 1 means
            linear only).
        decoder_kernel_size: If > 1, the per-subject heads are (non-causal)
            ``ConvLinear`` with this kernel; otherwise plain ``nn.Linear``.
        encoder_causal: Whether the ``FeatEmbed`` convolutions are causal.
        encoder_normalize: Whether each ``FeatEmbed`` applies LayerNorm first.
    """

    def __init__(
        self,
        num_subjects: int = 4,
        feat_dims: tuple[int, ...] = (2048,),
        embed_dim: int = 256,
        target_dim: int = 1000,
        encoder_kernel_size: int = 33,
        decoder_kernel_size: int = 0,
        encoder_causal: bool = True,
        encoder_normalize: bool = True,
    ):
        super().__init__()
        self.num_subjects = num_subjects

        self.feat_embeds = nn.ModuleList(
            [
                FeatEmbed(
                    feat_dim,
                    embed_dim,
                    kernel_size=encoder_kernel_size,
                    causal=encoder_causal,
                    normalize=encoder_normalize,
                )
                for feat_dim in feat_dims
            ]
        )

        if decoder_kernel_size > 1:
            decoder_linear = partial(ConvLinear, kernel_size=decoder_kernel_size)
        else:
            decoder_linear = nn.Linear

        self.shared_decoder = nn.Linear(embed_dim, target_dim)
        self.subject_decoders = nn.ModuleList(
            [
                decoder_linear(embed_dim, target_dim) for _ in range(num_subjects)
            ]
        )
        self.apply(init_weights)

    def forward(self, inputs: list[torch.Tensor]) -> torch.Tensor:
        """Predict fMRI for every subject.

        Args:
            inputs: One tensor per feature stream, each (N, L, feat_dim), in the
                same order as ``feat_dims``.

        Returns:
            Tensor of shape (N, num_subjects, L, target_dim).

        Raises:
            ValueError: if the number of inputs differs from the number of
                feature streams the model was built with. Without this check,
                ``zip`` would silently drop the extra streams.
        """
        if len(inputs) != len(self.feat_embeds):
            raise ValueError(
                f"Expected {len(self.feat_embeds)} feature inputs, got {len(inputs)}"
            )
        embed = sum(
            feat_embed(feat) for feat, feat_embed in zip(inputs, self.feat_embeds)
        )
        # (N, L, C); broadcast over the subject axis below.
        shared_output = self.shared_decoder(embed)
        subject_output = torch.stack(
            [decoder(embed) for decoder in self.subject_decoders],
            dim=1,
        )
        output = subject_output + shared_output[:, None]
        return output


def init_weights(m: nn.Module):
    """Initialize conv/linear layers with normal(std=0.02) weights and zero bias.

    Other module types are left unchanged. Layers built with ``bias=False`` are
    supported; their bias stays None.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        # NOTE: trunc_normal_'s cutoffs a/b default to absolute +/-2 (~100 std here),
        # so in practice this is an untruncated normal.
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [26]:
# Single-feature encoder used for the shape check and the checkpoint-loading test below.
encoder = MultiSubjectConvLinearEncoderV3(embed_dim=64)
print(encoder)

MultiSubjectConvLinearEncoderV3(
  (feat_embeds): ModuleList(
    (0): FeatEmbed(
      (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (embed): LinearConv(
        (fc): Linear(in_features=2048, out_features=64, bias=True)
        (conv): CausalConv1d(64, 64, kernel_size=(33,), stride=(1,), padding=same, groups=64)
      )
    )
  )
  (shared_decoder): Linear(in_features=64, out_features=1000, bias=True)
  (subject_decoders): ModuleList(
    (0-3): 4 x Linear(in_features=64, out_features=1000, bias=True)
  )
)


In [27]:
# (N, L, C)
# Output is (N, num_subjects, L, target_dim).
x = torch.randn(16, 256, 2048)
z = encoder.forward([x])
print(z.shape)

torch.Size([16, 4, 256, 1000])


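Test loading the cross-encoding checkpoint. Only the decoder weights match this model; the feature-embedding weights are reported as missing and the checkpoint's encoder weights as unexpected.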

In [28]:
# Load a checkpoint from the cross-encoding experiment. strict=False loads only
# the keys whose names match this model and reports the rest.
cross_encoding_dir = root_dir / "cross_encoding/output/cross_encoding_v3"

# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects; only
# load checkpoints from a trusted source.
cross_encoder_ckpt = torch.load(
    cross_encoding_dir / "ckpt.pt", map_location="cpu", weights_only=False
)

missing_keys, unexpected_keys = encoder.load_state_dict(
    cross_encoder_ckpt["model"], strict=False
)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)

Missing keys: ['feat_embeds.0.norm.weight', 'feat_embeds.0.norm.bias', 'feat_embeds.0.embed.fc.weight', 'feat_embeds.0.embed.fc.bias', 'feat_embeds.0.embed.conv.weight', 'feat_embeds.0.embed.conv.bias', 'feat_embeds.0.embed.conv.attn_mask']
Unexpected keys: ['weight', 'shared_encoder.weight', 'shared_encoder.bias', 'subject_encoders.0.fc.weight', 'subject_encoders.0.fc.bias', 'subject_encoders.0.conv.weight', 'subject_encoders.0.conv.bias', 'subject_encoders.1.fc.weight', 'subject_encoders.1.fc.bias', 'subject_encoders.1.conv.weight', 'subject_encoders.1.conv.bias', 'subject_encoders.2.fc.weight', 'subject_encoders.2.fc.bias', 'subject_encoders.2.conv.weight', 'subject_encoders.2.conv.bias', 'subject_encoders.3.fc.weight', 'subject_encoders.3.fc.bias', 'subject_encoders.3.conv.weight', 'subject_encoders.3.conv.bias']


## Training

In [29]:
import math

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from timm.utils import AverageMeter, random_seed

In [30]:
def train_one_epoch(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch_batches: int | None,
    device: torch.device,
):
    """Run one pass over ``train_loader``, minimizing MSE between predictions and fMRI.

    Args:
        epoch: Epoch index; combined with ``epoch_batches`` to get the global step.
        model: Maps a list of feature tensors to predictions shaped like the fMRI batch.
        train_loader: Yields ``(episode, fmri, feats)`` batches.
        optimizer: Stepped once per batch.
        epoch_batches: Batches per epoch. If None, falls back to ``len(train_loader)``;
            Algonauts2025Dataset defines no ``__len__``, so pass this explicitly for it.
        device: Device to move each batch to.

    Raises:
        RuntimeError: if the loss becomes NaN or Inf.
    """
    model.train()

    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    epoch_batches = len(train_loader) if epoch_batches is None else epoch_batches
    first_step = epoch * epoch_batches

    end = time.monotonic()
    for batch_idx, (_, sample, feats) in enumerate(train_loader):
        step = first_step + batch_idx
        feats = [feat.to(device) for feat in feats]
        sample = sample.to(device)
        batch_size = sample.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(feats)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        if not math.isfinite(loss_item):
            raise RuntimeError(f"NaN/Inf loss encountered on step {step}; exiting")

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # end of iteration timing (synchronize so GPU work is included)
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        if step % 10 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Train: {epoch:>3d} [{batch_idx:>3d}/{epoch_batches}][{step:>6d}]"
                f"  Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f"  Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        # Restart timer for next iteration
        end = time.monotonic()

In [31]:
@torch.no_grad()
def validate(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
):
    """Evaluate ``model`` on ``val_loader`` and compute Pearson-r encoding accuracy.

    Targets and predictions from all batches are concatenated along time for each
    subject. They are then correlated per output channel with ``pearsonr_score``.

    Returns:
        ``(acc, metrics)``. ``acc`` is the mean correlation, averaged over subjects.
        ``metrics`` holds:
        - per-subject correlation maps (``acc_map_sub-*``);
        - per-subject means (``acc_sub-*``);
        - their averages over subjects (``acc_map_avg``, ``acc_avg``).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()

    use_cuda = device.type == "cuda"

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    samples = []
    outputs = []

    end = time.monotonic()
    for _, sample, feats in val_loader:
        sample = sample.to(device)
        feats = [feat.to(device) for feat in feats]
        batch_size = sample.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(feats)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        # end of iteration timing
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        N, S, L, C = sample.shape
        # The metrics below index the subject axis by position in SUBJECTS.
        assert S == len(SUBJECTS), f"Expected {len(SUBJECTS)} subjects, got {S}"
        # (N, S, L, C) -> (S, N*L, C): concatenate batch items along time per subject.
        samples.append(sample.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))
        outputs.append(output.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))

        # Reset timer
        end = time.monotonic()

    if not samples:
        raise ValueError("val_loader yielded no batches")

    # (S, T, C): all time points for each subject
    samples = np.concatenate(samples, axis=1)
    outputs = np.concatenate(outputs, axis=1)

    metrics = {}

    # Encoding accuracy metrics
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(SUBJECTS):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"acc_map_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(SUBJECTS)
        acc += acc_i / len(SUBJECTS)

    metrics["acc_map_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )

    tput = batch_size / step_time_m.avg
    print(
        f"Val: {epoch:>3d}"
        f"  Loss: {loss_m.avg:#.3g}"
        f"  Acc: {accs_fmt} ({acc:.3f})"
        f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
    )

    return acc, metrics


def pearsonr_score(
    y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7
) -> np.ndarray:
    """Column-wise Pearson correlation between two 2-D arrays.

    Each column is mean-centered and scaled to unit L2 norm (``eps`` guards
    against constant columns). The per-column dot product is then the
    correlation. Returns one score per column.
    """
    assert y_true.ndim == y_pred.ndim == 2

    def _center_and_scale(arr: np.ndarray) -> np.ndarray:
        centered = arr - arr.mean(axis=0)
        return centered / (np.linalg.norm(centered, axis=0) + eps)

    true_unit = _center_and_scale(y_true)
    pred_unit = _center_and_scale(y_pred)
    return np.sum(true_unit * pred_unit, axis=0)

In [32]:
# Hyperparameters (snapshotted into `config` in the next cell).
seed = 3315
batch_size = 4
sample_length = 256  # clip length (fMRI time points) for training
n_train_samples = 500  # clips per epoch; the train dataset stops after this many
lr = 3e-4
weight_decay = 0.1
epochs = 10

embed_dim = 64
encoder_kernel_size = 33
decoder_kernel_size = 0  # <= 1 -> plain linear subject decoders

# Load decoder weights from the cross-encoding checkpoint and freeze them.
freeze_decoder = True

In [33]:
# Snapshot the hyperparameters above, looked up by name in the notebook globals.
config = {
    k: globals()[k] for k in
    [
        "seed",
        "batch_size",
        "sample_length",
        "n_train_samples",
        "lr",
        "weight_decay",
        "epochs",
        "embed_dim",
        "encoder_kernel_size",
        "decoder_kernel_size",
        "freeze_decoder",
        ]
}
print(config)

{'seed': 3315, 'batch_size': 4, 'sample_length': 256, 'n_train_samples': 500, 'lr': 0.0003, 'weight_decay': 0.1, 'epochs': 10, 'embed_dim': 64, 'encoder_kernel_size': 33, 'decoder_kernel_size': 0, 'freeze_decoder': True}


In [34]:
# timm helper; presumably seeds the global RNGs (python/numpy/torch) — confirm.
random_seed(seed)

# NOTE(review): hard-codes the second GPU ("cuda:1"); change on single-GPU machines.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

Running on: cuda:1


In [35]:
# Train: random clips from Friends s1-5, n_train_samples per pass. It uses its own
# fixed seed (42), independent of `seed` above.
train_dataset = Algonauts2025Dataset(
    friends_train_fmri,
    list(stimuli_features_friends.values()),
    sample_length=sample_length,
    num_samples=n_train_samples,
    shuffle=True,
    seed=42,
)

# Val (Friends s6) and test (Movie10): whole runs in order (sample_length=None).
val_dataset = Algonauts2025Dataset(
    friends_val_fmri,
    list(stimuli_features_friends.values()),
    sample_length=None,
    shuffle=False,
)

test_dataset = Algonauts2025Dataset(
    movie10_test_fmri,
    list(stimuli_features_movie10.values()),
    sample_length=None,
    shuffle=False,
)

In [36]:
train_loader = DataLoader(train_dataset, batch_size=batch_size)
# Whole runs differ in length and cannot be collated together, hence batch_size=1.
val_loader = DataLoader(val_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)

In [37]:
# Inspect one training batch: fMRI is (N, subjects, L, parcels); there is one
# (N, L, dim) tensor per feature stream.
_, sample, feats = next(iter(train_loader))
print("Sample shape:", tuple(sample.shape))
print("Feats shape:", [tuple(feat.shape) for feat in feats])
print("Sample dtype:", sample.dtype)
print("Feats dtype:", [feat.dtype for feat in feats])

Sample shape: (4, 4, 256, 1000)
Feats shape: [(4, 256, 1280), (4, 256, 1280), (4, 256, 3584), (4, 256, 3584), (4, 256, 2048), (4, 256, 2048)]
Sample dtype: torch.float32
Feats dtype: [torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, torch.float32]


In [38]:
# Infer each stream's feature dimension (last axis) from the loaded arrays, in the
# same order the datasets pass the streams to the model. This keeps feat_dims in
# sync with the layer lists above instead of hard-coding it.
feat_dims = [
    next(iter(features.values())).shape[-1]
    for features in stimuli_features_friends.values()
]

model = MultiSubjectConvLinearEncoderV3(
    feat_dims=feat_dims,
    embed_dim=embed_dim,
    encoder_kernel_size=encoder_kernel_size,
    decoder_kernel_size=decoder_kernel_size,
)

if freeze_decoder:
    # Initialize the decoders from the cross-encoding checkpoint and freeze them;
    # only the feature embeddings are trained.
    missing_keys, unexpected_keys = model.load_state_dict(
        cross_encoder_ckpt["model"], strict=False
    )
    # strict=False would silently leave a decoder at its random init if its keys
    # were absent (e.g. decoder_kernel_size > 1); fail instead of freezing noise.
    missing_decoder_keys = [
        key for key in missing_keys
        if key.startswith(("shared_decoder.", "subject_decoders."))
    ]
    assert not missing_decoder_keys, (
        f"Checkpoint lacks decoder weights: {missing_decoder_keys}"
    )
    for p in model.shared_decoder.parameters():
        p.requires_grad_(False)
    for p in model.subject_decoders.parameters():
        p.requires_grad_(False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

model = model.to(device)

param_count = sum(p.numel() for p in model.parameters())
train_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model:", model)
print(f"Num params: {param_count/1e6:.2f}M ({train_param_count/1e6:.2f}M)")

Missing keys: ['feat_embeds.0.norm.weight', 'feat_embeds.0.norm.bias', 'feat_embeds.0.embed.fc.weight', 'feat_embeds.0.embed.fc.bias', 'feat_embeds.0.embed.conv.weight', 'feat_embeds.0.embed.conv.bias', 'feat_embeds.0.embed.conv.attn_mask', 'feat_embeds.1.norm.weight', 'feat_embeds.1.norm.bias', 'feat_embeds.1.embed.fc.weight', 'feat_embeds.1.embed.fc.bias', 'feat_embeds.1.embed.conv.weight', 'feat_embeds.1.embed.conv.bias', 'feat_embeds.1.embed.conv.attn_mask', 'feat_embeds.2.norm.weight', 'feat_embeds.2.norm.bias', 'feat_embeds.2.embed.fc.weight', 'feat_embeds.2.embed.fc.bias', 'feat_embeds.2.embed.conv.weight', 'feat_embeds.2.embed.conv.bias', 'feat_embeds.2.embed.conv.attn_mask', 'feat_embeds.3.norm.weight', 'feat_embeds.3.norm.bias', 'feat_embeds.3.embed.fc.weight', 'feat_embeds.3.embed.fc.bias', 'feat_embeds.3.embed.conv.weight', 'feat_embeds.3.embed.conv.bias', 'feat_embeds.3.embed.conv.attn_mask', 'feat_embeds.4.norm.weight', 'feat_embeds.4.norm.bias', 'feat_embeds.4.embed.fc.w

In [39]:
# Frozen decoder parameters never receive gradients, so the optimizer step leaves them unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# The train dataset yields n_train_samples clips per pass. NOTE(review): if this
# does not divide evenly, the loader's final partial batch is not counted here.
epoch_batches = n_train_samples // batch_size

In [40]:
# Each epoch: train on a fresh stream of n_train_samples clips, then evaluate on
# Friends s6 and Movie10. Metrics are only printed; each epoch overwrites
# val_metrics/test_metrics and nothing is saved to out_dir.
for epoch in range(epochs):
    print("Train friends s1-5")
    train_one_epoch(
        epoch=epoch,
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch_batches=epoch_batches,
        device=device,
    )
    print("Eval friends s6")
    val_acc, val_metrics = validate(
        epoch=epoch,
        model=model,
        val_loader=val_loader,
        device=device,
    )
    print("Eval movie10")
    test_acc, test_metrics = validate(
        epoch=epoch,
        model=model,
        val_loader=test_loader,
        device=device,
    )

Train friends s1-5
Train:   0 [  0/125][     0]  Loss: 0.370 (0.370)  Time: 0.371,1.985 2/s  Mem: 0.00,0.00 GB
Train:   0 [ 10/125][    10]  Loss: 0.364 (0.370)  Time: 0.050,0.250 16/s  Mem: 0.00,0.00 GB
Train:   0 [ 20/125][    20]  Loss: 0.359 (0.365)  Time: 0.034,0.141 28/s  Mem: 0.00,0.00 GB
Train:   0 [ 30/125][    30]  Loss: 0.349 (0.362)  Time: 0.029,0.102 39/s  Mem: 0.00,0.00 GB
Train:   0 [ 40/125][    40]  Loss: 0.361 (0.361)  Time: 0.025,0.082 49/s  Mem: 0.00,0.00 GB
Train:   0 [ 50/125][    50]  Loss: 0.364 (0.360)  Time: 0.023,0.070 57/s  Mem: 0.00,0.00 GB
Train:   0 [ 60/125][    60]  Loss: 0.345 (0.358)  Time: 0.022,0.061 65/s  Mem: 0.00,0.00 GB
Train:   0 [ 70/125][    70]  Loss: 0.341 (0.356)  Time: 0.021,0.055 72/s  Mem: 0.00,0.00 GB
Train:   0 [ 80/125][    80]  Loss: 0.336 (0.355)  Time: 0.020,0.050 79/s  Mem: 0.00,0.00 GB
Train:   0 [ 90/125][    90]  Loss: 0.345 (0.353)  Time: 0.019,0.047 86/s  Mem: 0.00,0.00 GB
Train:   0 [100/125][   100]  Loss: 0.338 (0.352)  T

## Results

One layer per model (the first entry for each model in the layer lists above):

```
Eval friends s6
Val:   9  Loss: 0.335  Acc: 0.266,0.271,0.292,0.251 (0.270)  Time: 0.003,0.004 235/s
Eval movie10
Val:   9  Loss: 0.356  Acc: 0.214,0.190,0.208,0.177 (0.197)  Time: 0.003,0.004 257/s
```

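Two layers per model (both entries for each model in the layer lists above, i.e. the configuration in this notebook; these are not the models' first and last layers):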

```
Eval friends s6
Val:   9  Loss: 0.335  Acc: 0.270,0.276,0.296,0.255 (0.274)  Time: 0.006,0.007 150/s
Eval movie10
Val:   9  Loss: 0.356  Acc: 0.218,0.194,0.211,0.178 (0.200)  Time: 0.005,0.006 164/s
```