# Cross-subject encoding

Predicting each subject's activity using the synchronized activity of the other subjects. (Idk if any papers have done this, probably they have. Please someone lmk if they know of a reference.)

In [1]:
import re
import time
from collections import defaultdict
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm

In [2]:
SUBJECTS = (1, 2, 3, 5)

In [3]:
root_dir = Path("..").resolve()

data_dir = root_dir / "algonauts_2025.competitors"

out_dir = Path(".") / "output/cross_encoding_v3"
out_dir.mkdir(exist_ok=True, parents=True)
print("Saving output to:", out_dir.resolve())

Saving output to: /home/connor/algonauts2025/cross_encoding/output/cross_encoding_v3


## Data loading

Aligned cross-subject fmri data loader. The data loader samples clips of synchronized activity from the same friends episodes across subjects. Each clip is shape `(n_subs, sample_length, dim)`.

We also load pre-extracted features of the shape `(sample_length, dim)`.

In [4]:
def parse_friends_run(run: str):
    match = re.match(r"s([0-9]+)e([0-9]+)([a-z])", run)
    if match is None:
        raise ValueError(f"Invalid friends run {run}")

    season = int(match.group(1))
    episode = int(match.group(2))
    part = match.group(3)
    return season, episode, part

In [5]:
def load_algonauts2025_friends_fmri(
    root: str | Path,
    subjects: list[int] | None = None,
    seasons: list[int] | None = None,
) -> dict[str, np.ndarray]:
    subjects = subjects or SUBJECTS
    seasons = seasons or list(range(1, 7))

    files = {
        sub: h5py.File(
            Path(root)
            / f"fmri/sub-{sub:02d}/func"
            / f"sub-{sub:02d}_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5"
        )
        for sub in subjects
    }

    episode_key_maps = defaultdict(dict)
    seasons_set = set(seasons)
    for sub, file in files.items():
        for key in file.keys():
            entities = dict([ent.split("-", 1) for ent in key.split("_")])
            episode = entities["task"]
            season, _, _ = parse_friends_run(episode)
            if season in seasons_set:
                episode_key_maps[episode][sub] = key

    episode_list = sorted(
        [
            episode for episode, map in episode_key_maps.items()
            if len(map) == len(subjects)
        ]
    )

    data = {}
    for episode in episode_list:
        samples = []
        length = None
        for sub in subjects:
            key = episode_key_maps[episode][sub]
            sample = files[sub][key][:]
            sub_length = len(sample)
            samples.append(sample)
            length = min(length, sub_length) if length else sub_length
        data[episode] = np.stack([sample[:length] for sample in samples])
    
    return data

In [6]:
def parse_movie10_run(run: str):
    match = re.match(r"([a-z]+)([0-9]+)", run)
    if match is None:
        raise ValueError(f"Invalid movie run {run}")

    movie = match.group(1)
    part = int(match.group(2))
    return movie, part

In [7]:
def load_algonauts2025_movie10_fmri(
    root: str | Path,
    subjects: list[int] | None = None,
    movies: list[str] | None = None,
    runs: list[int] | None = None,
) -> dict[str, np.ndarray]:
    subjects = subjects or SUBJECTS
    movies = movies or ["bourne", "wolf", "figures", "life"]
    runs = runs or [1, 2]

    files = {
        sub: h5py.File(
            Path(root)
            / f"fmri/sub-{sub:02d}/func"
            / f"sub-{sub:02d}_task-movie10_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_bold.h5"
        )
        for sub in subjects
    }

    episode_key_maps = defaultdict(dict)
    movies_set = set(movies)
    for sub, file in files.items():
        for key in file.keys():
            entities = dict([ent.split("-", 1) for ent in key.split("_")])
            episode = entities["task"]
            run = int(entities.get("run", 1))
            movie, _ = parse_movie10_run(episode)
            if movie in movies_set and run in runs:
                episode_key_maps[(episode, run)][sub] = key

    episode_list = sorted(
        [
            episode for episode, map in episode_key_maps.items()
            if len(map) == len(subjects)
        ]
    )

    data = {}
    for episode in episode_list:
        samples = []
        length = None
        for sub in subjects:
            key = episode_key_maps[episode][sub]
            sample = files[sub][key][:]
            sub_length = len(sample)
            samples.append(sample)
            length = min(length, sub_length) if length else sub_length
        data[episode] = np.stack([sample[:length] for sample in samples])
    
    return data

In [8]:
friends_train_fmri = load_algonauts2025_friends_fmri(data_dir, seasons=range(1, 6))
friends_val_fmri = load_algonauts2025_friends_fmri(data_dir, seasons=[6])
movie10_test_fmri = load_algonauts2025_movie10_fmri(data_dir, runs=[1])

In [9]:
print(friends_train_fmri.keys())
print(friends_val_fmri.keys())
print(movie10_test_fmri.keys())

dict_keys(['s01e01a', 's01e01b', 's01e02a', 's01e02b', 's01e03a', 's01e03b', 's01e04a', 's01e04b', 's01e05a', 's01e05b', 's01e06a', 's01e06b', 's01e07a', 's01e07b', 's01e08a', 's01e08b', 's01e09a', 's01e09b', 's01e10a', 's01e10b', 's01e11a', 's01e11b', 's01e12a', 's01e12b', 's01e13a', 's01e13b', 's01e14a', 's01e14b', 's01e15a', 's01e15b', 's01e16a', 's01e16b', 's01e17a', 's01e17b', 's01e18a', 's01e18b', 's01e19a', 's01e19b', 's01e20a', 's01e20b', 's01e21a', 's01e21b', 's01e22a', 's01e22b', 's01e23a', 's01e23b', 's01e24a', 's01e24b', 's02e01a', 's02e01b', 's02e02a', 's02e02b', 's02e03a', 's02e03b', 's02e04a', 's02e04b', 's02e05a', 's02e05b', 's02e06a', 's02e06b', 's02e07a', 's02e07b', 's02e08a', 's02e08b', 's02e09a', 's02e09b', 's02e10a', 's02e10b', 's02e11a', 's02e11b', 's02e12a', 's02e12b', 's02e13a', 's02e13b', 's02e14a', 's02e14b', 's02e15a', 's02e15b', 's02e16a', 's02e16b', 's02e17a', 's02e17b', 's02e18a', 's02e18b', 's02e19a', 's02e19b', 's02e20a', 's02e20b', 's02e21a', 's02e21b',

In [10]:
sample = friends_train_fmri["s01e05b"]
print("Sample shape (NTC):", sample.shape, sample.dtype)

Sample shape (NTC): (4, 468, 1000) float32


In [11]:
class Algonauts2025Dataset(IterableDataset):
    def __init__(
        self,
        fmri_data: dict[str, np.ndarray],
        feat_data: list[dict[str, np.ndarray]] | None = None,
        sample_length: int | None = 128,
        num_samples: int | None = None,
        shuffle: bool = True,
        seed: int | None = None,
    ):
        self.fmri_data = fmri_data
        self.feat_data = feat_data

        self.episode_list = list(fmri_data)
        self.sample_length = sample_length
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.seed = seed

        self._rng = np.random.default_rng(seed)
    
    def _iter_shuffle(self):
        sample_idx = 0
        while True:
            episode_order = self._rng.permutation(len(self.episode_list))

            for ii in episode_order:
                episode = self.episode_list[ii]
                fmri = torch.from_numpy(self.fmri_data[episode])

                if self.feat_data:
                    feats = [torch.from_numpy(data[episode]) for data in self.feat_data]
                else:
                    feats = feat_samples = None

                # Nb, fmri and feature length often off by 1 or 2.
                # But assuming time locked to start.
                length = fmri.shape[1]
                if feats:
                    length = min(length, min(feat.shape[0] for feat in feats))

                if self.sample_length:
                    # Random segment of run
                    offset = int(self._rng.integers(0, length - self.sample_length + 1))
                    fmri_sample = fmri[:, offset: offset + self.sample_length]
                    if feats:
                        feat_samples = [
                            feat[offset: offset + self.sample_length] for feat in feats
                        ]
                else:
                    # Take full run
                    # Nb this only works for batch size 1 since runs are different length
                    fmri_sample = fmri[:, :length]
                    if feats:
                        feat_samples = [feat[:length] for feat in feats]

                if feat_samples:
                    yield episode, fmri_sample, feat_samples
                else:
                    yield episode, fmri_sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return

    def _iter_ordered(self):
        sample_idx = 0
        for episode in self.episode_list:
            fmri = torch.from_numpy(self.fmri_data[episode])
            if self.feat_data:
                feats = [torch.from_numpy(data[episode]) for data in self.feat_data]
            else:
                feats = feat_samples = None

            length = fmri.shape[1]
            if feats:
                length = min(length, min(feat.shape[0] for feat in feats))

            sample_length = self.sample_length or length

            for offset in range(0, length - sample_length + 1, sample_length):
                fmri_sample = fmri[:, offset: offset + sample_length]
                if feats:
                    feat_samples = [feat[offset: offset + sample_length] for feat in feats]

                if feat_samples:
                    yield episode, fmri_sample, feat_samples
                else:
                    yield episode, fmri_sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return

    def __iter__(self):
        if self.shuffle:
            yield from self._iter_shuffle()
        else:
            yield from self._iter_ordered()

In [12]:
friends_train_dataset = Algonauts2025Dataset(
    friends_train_fmri,
    sample_length=64,
    num_samples=10000,
    shuffle=True,
    seed=42,
)

In [13]:
total_bytes = 0
tic = time.monotonic()
for task, fmri_sample in tqdm(friends_train_dataset):
    total_bytes += fmri_sample.numel() * 4
rt = time.monotonic() - tic
tput = total_bytes / 1024 ** 2 / rt 
print(f"run time={rt:.3f}s, MB/s={tput:.0f}")

10000it [00:00, 170811.23it/s]

run time=0.061s, MB/s=160208





## Model

Model architecture is a simple linear encoder and decoder for each subject. The encoder/decoder is "factorized" into a depthwise conv1d (to align data temporally), and a linear projection (to align data spatially).

For each subject, the input to the decoder is the average of the latents for the other three subjects.

In [14]:
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

In [15]:
class CausalConv1d(nn.Conv1d):
    """Conv1d layer with a causal mask, to only "attend" to past time points."""
    attn_mask: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: str | int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        assert kernel_size % 2 == 1, "causal conv requires odd kernel size"
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        attn_mask = torch.zeros(kernel_size)
        attn_mask[:kernel_size // 2 + 1] = 1.0
        self.weight.data.mul_(attn_mask)
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = self.weight * self.attn_mask
        return F.conv1d(
            input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

In [16]:
class ConvLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        conv_layer = CausalConv1d if causal else nn.Conv1d
        self.conv = conv_layer(
            in_features,
            in_features,
            kernel_size=kernel_size,
            padding="same",
            groups=in_features,
        )
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # x: (N, L, C)
        x = x.transpose(-1, -2)
        x = self.conv(x)
        x = x.transpose(-1, -2)
        x = self.fc(x)
        return x


class LinearConv(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        conv_layer = CausalConv1d if causal else nn.Conv1d
        self.fc = nn.Linear(in_features, out_features)
        self.conv = conv_layer(
            out_features,
            out_features,
            kernel_size=kernel_size,
            padding="same",
            groups=out_features,
        )

    def forward(self, x: torch.Tensor):
        # x: (N, L, C)
        x = self.fc(x)
        x = x.transpose(-1, -2)
        x = self.conv(x)
        x = x.transpose(-1, -2)
        return x

In [17]:
encoder = ConvLinear(
    in_features=1000,
    out_features=256,
    causal=True
)
print(encoder)

# (N, L, C)
x = torch.randn(16, 64, 1000)
embed = encoder.forward(x)
print(embed.shape)

ConvLinear(
  (conv): CausalConv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
  (fc): Linear(in_features=1000, out_features=256, bias=True)
)
torch.Size([16, 64, 256])


In [44]:
class CrossSubjectConvLinearEncoderV2(nn.Module):
    """
    - Minor refactoring
    - Added subject shared linear projections
    """
    weight: torch.Tensor

    def __init__(
        self,
        num_subjects: int = 4,
        fmri_dim: int = 1000,
        embed_dim: int = 256,
        encoder_kernel_size: int = 11,
        decoder_kernel_size: int = 11,
        normalize: bool = False,
        with_shared_encoder: bool = True,
        with_shared_decoder: bool = True,
        with_subject_encoders: bool = True,
        with_subject_decoders: bool = True,
    ):
        super().__init__()
        assert with_shared_encoder or with_subject_encoders
        assert with_shared_decoder or with_subject_decoders

        self.num_subjects = num_subjects

        if with_shared_encoder:
            self.shared_encoder = nn.Linear(fmri_dim, embed_dim)
        else:
            self.register_module("shared_encoder", None)

        if with_subject_encoders:
            if encoder_kernel_size > 1:
                encoder_fn = partial(LinearConv, kernel_size=encoder_kernel_size)
            else:
                encoder_fn = nn.Linear
            self.subject_encoders = nn.ModuleList(
                [encoder_fn(fmri_dim, embed_dim) for _ in range(num_subjects)]
            )
        else:
            self.register_module("subject_encoders", None)

        self.norm = nn.LayerNorm(embed_dim) if normalize else nn.Identity()

        if with_shared_decoder:
            self.shared_decoder = nn.Linear(embed_dim, fmri_dim)
        else:
            self.register_module("shared_decoder", None)
        
        if with_subject_decoders:
            if decoder_kernel_size > 1:
                decoder_fn = partial(ConvLinear, kernel_size=decoder_kernel_size)
            else:
                decoder_fn = nn.Linear
            self.subject_decoders = nn.ModuleList(
                [decoder_fn(embed_dim, fmri_dim) for _ in range(num_subjects)]
            )
        else:
            self.register_module("subject_decoders", None)

        # todo: could learn the averaging weights
        weight = (1.0 - torch.eye(self.num_subjects)) / (self.num_subjects - 1.0)
        self.register_buffer("weight", weight)
        self.apply(init_weights)
    
    def forward(self, input: torch.Tensor):
        # input: (N, S, L, C)
        # subject specific encoders

        if self.shared_encoder is not None:
            shared_embed = self.shared_encoder(input)
        else:
            shared_embed = 0.0
        
        if self.subject_encoders is not None:
            subject_embeds = torch.stack(
                [encoder(input[:, ii]) for ii, encoder in enumerate(self.subject_encoders)],
                dim=1,
            )
        else:
            subject_embeds = 0.0

        embed = self.norm(shared_embed + subject_embeds)

        # average pool the latents for all but target subject
        embed = torch.einsum("nslc,ts->ntlc", embed, self.weight)

        # subject specific decoders
        if self.shared_decoder is not None:
            shared_output = self.shared_decoder(embed)
        else:
            shared_output = 0.0
        
        if self.subject_decoders is not None:
            subject_outputs = torch.stack(
                [decoder(embed[:, ii]) for ii, decoder in enumerate(self.subject_decoders)],
                dim=1,
            )
        else:
            subject_outputs = 0.0
        output = shared_output + subject_outputs
        return output
    

def init_weights(m: nn.Module):
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

In [45]:
cross_encoder = CrossSubjectConvLinearEncoderV2()
print(cross_encoder)

# (N, S, L, C)
x = torch.randn(16, 4, 64, 1000)
z = cross_encoder.forward(x)
print(z.shape)

CrossSubjectConvLinearEncoderV2(
  (shared_encoder): Linear(in_features=1000, out_features=256, bias=True)
  (subject_encoders): ModuleList(
    (0-3): 4 x LinearConv(
      (fc): Linear(in_features=1000, out_features=256, bias=True)
      (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
    )
  )
  (norm): Identity()
  (shared_decoder): Linear(in_features=256, out_features=1000, bias=True)
  (subject_decoders): ModuleList(
    (0-3): 4 x ConvLinear(
      (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
      (fc): Linear(in_features=256, out_features=1000, bias=True)
    )
  )
)
torch.Size([16, 4, 64, 1000])


## Training

Basic training loop, AdamW, no lr decay, no bells and whistles.

In [35]:
import math

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from timm.utils import AverageMeter, random_seed

In [36]:
def train_one_epoch(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch_batches: int | None,
    device: torch.device,
):
    model.train()
    
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    epoch_batches = len(train_loader) if epoch_batches is None else epoch_batches
    first_step = epoch * epoch_batches

    end = time.monotonic()
    for batch_idx, (_, sample) in enumerate(train_loader):
        step = first_step + batch_idx
        sample = sample.to(device)
        batch_size = sample.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(sample)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        if math.isnan(loss_item) or math.isinf(loss_item):
            raise RuntimeError("NaN/Inf loss encountered on step %d; exiting", step)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # end of iteration timing
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        if step % 10 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Train: {epoch:>3d} [{batch_idx:>3d}/{epoch_batches}][{step:>6d}]"
                f"  Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f"  Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        # Restart timer for next iteration
        end = time.monotonic()

In [37]:
@torch.no_grad()
def validate(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
):
    model.eval()

    use_cuda = device.type == "cuda"

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    samples = []
    outputs = []

    end = time.monotonic()
    for batch_idx, (_, sample) in enumerate(val_loader):
        sample = sample.to(device)
        batch_size = sample.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(sample)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        # end of iteration timing
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        N, S, L, C = sample.shape
        assert N, S == (1, 4)
        samples.append(sample.cpu().numpy().swapaxes(0, 1).reshape((S, N*L, C)))
        outputs.append(output.cpu().numpy().swapaxes(0, 1).reshape((S, N*L, C)))

        # Reset timer
        end = time.monotonic()

    # (S, N, C)
    samples = np.concatenate(samples, axis=1)
    outputs = np.concatenate(outputs, axis=1)

    metrics = {}

    # Encoding accuracy metrics
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(SUBJECTS):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"acc_map_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(SUBJECTS)
        acc += acc_i / len(SUBJECTS)

    metrics["acc_map_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )

    tput = batch_size / step_time_m.avg
    print(
        f"Val: {epoch:>3d}"
        f"  Loss: {loss_m.avg:#.3g}"
        f"  Acc: {accs_fmt} ({acc:.3f})"
        f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
    )

    return acc, metrics


def pearsonr_score(
    y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7
) -> np.ndarray:
    assert y_true.ndim == y_pred.ndim == 2

    y_true = y_true - y_true.mean(axis=0)
    y_true = y_true / (np.linalg.norm(y_true, axis=0) + eps)

    y_pred = y_pred - y_pred.mean(axis=0)
    y_pred = y_pred / (np.linalg.norm(y_pred, axis=0) + eps)

    score = (y_true * y_pred).sum(axis=0)
    return score

In [64]:
seed = 3315
batch_size = 16
sample_length = 64
n_train_samples = 2000
lr = 3e-4
weight_decay = 0.001
epochs = 10

embed_dim = 64
encoder_kernel_size = 7
decoder_kernel_size = 0
with_shared = True
with_subject = True

In [65]:
config = {
    k: globals()[k] for k in
    [
        "seed",
        "batch_size",
        "sample_length",
        "n_train_samples",
        "lr",
        "weight_decay",
        "epochs",
        "embed_dim",
        "encoder_kernel_size",
        "decoder_kernel_size",
        "with_shared",
        "with_subject",
        ]
}
print(config)

{'seed': 3315, 'batch_size': 16, 'sample_length': 64, 'n_train_samples': 2000, 'lr': 0.0003, 'weight_decay': 0.001, 'epochs': 10, 'embed_dim': 64, 'encoder_kernel_size': 7, 'decoder_kernel_size': 0, 'with_shared': True, 'with_subject': True}


In [66]:
random_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

Running on: cuda


In [67]:
train_dataset = Algonauts2025Dataset(
    friends_train_fmri,
    sample_length=sample_length,
    num_samples=n_train_samples,
    shuffle=True,
    seed=42,
)

val_dataset = Algonauts2025Dataset(
    friends_val_fmri,
    sample_length=None,
    shuffle=False,
)

test_dataset = Algonauts2025Dataset(
    movie10_test_fmri,
    sample_length=None,
    shuffle=False,
)

In [68]:
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)

In [69]:
_, sample = next(iter(train_loader))
print("Sample shape:", tuple(sample.shape))

Sample shape: (16, 4, 64, 1000)


In [70]:
model = CrossSubjectConvLinearEncoderV2(
    embed_dim=embed_dim,
    encoder_kernel_size=encoder_kernel_size,
    decoder_kernel_size=decoder_kernel_size,
    with_shared_encoder=with_shared,
    with_shared_decoder=with_shared,
    with_subject_encoders=with_subject,
    with_subject_decoders=with_subject,
)
model = model.to(device)

param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model:", model)
print(f"Num params: {param_count/1e6:.2f}M")

Model: CrossSubjectConvLinearEncoderV2(
  (shared_encoder): Linear(in_features=1000, out_features=64, bias=True)
  (subject_encoders): ModuleList(
    (0-3): 4 x LinearConv(
      (fc): Linear(in_features=1000, out_features=64, bias=True)
      (conv): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=same, groups=64)
    )
  )
  (norm): Identity()
  (shared_decoder): Linear(in_features=64, out_features=1000, bias=True)
  (subject_decoders): ModuleList(
    (0-3): 4 x Linear(in_features=64, out_features=1000, bias=True)
  )
)
Num params: 0.65M


In [71]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
epoch_batches = n_train_samples // batch_size

In [72]:
for epoch in range(epochs):
    print("Train friends s1-5")
    train_one_epoch(
        epoch=epoch,
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch_batches=epoch_batches,
        device=device,
    )
    print("Eval friends s6")
    val_acc, val_metrics = validate(
        epoch=epoch,
        model=model,
        val_loader=val_loader,
        device=device,
    )
    print("Eval movie10")
    test_acc, test_metrics = validate(
        epoch=epoch,
        model=model,
        val_loader=test_loader,
        device=device,
    )

Train friends s1-5
Train:   0 [  0/125][     0]  Loss: 0.363 (0.363)  Time: 0.013,0.021 757/s  Mem: 0.19,0.23 GB
Train:   0 [ 10/125][    10]  Loss: 0.361 (0.370)  Time: 0.005,0.009 1738/s  Mem: 0.20,0.27 GB
Train:   0 [ 20/125][    20]  Loss: 0.342 (0.365)  Time: 0.004,0.008 2057/s  Mem: 0.20,0.27 GB
Train:   0 [ 30/125][    30]  Loss: 0.346 (0.359)  Time: 0.004,0.007 2190/s  Mem: 0.20,0.27 GB
Train:   0 [ 40/125][    40]  Loss: 0.347 (0.355)  Time: 0.003,0.007 2273/s  Mem: 0.20,0.27 GB
Train:   0 [ 50/125][    50]  Loss: 0.350 (0.352)  Time: 0.003,0.007 2323/s  Mem: 0.20,0.27 GB
Train:   0 [ 60/125][    60]  Loss: 0.323 (0.349)  Time: 0.003,0.007 2355/s  Mem: 0.20,0.27 GB
Train:   0 [ 70/125][    70]  Loss: 0.330 (0.346)  Time: 0.003,0.007 2379/s  Mem: 0.20,0.27 GB
Train:   0 [ 80/125][    80]  Loss: 0.332 (0.344)  Time: 0.003,0.007 2401/s  Mem: 0.20,0.27 GB
Train:   0 [ 90/125][    90]  Loss: 0.327 (0.342)  Time: 0.003,0.007 2400/s  Mem: 0.20,0.27 GB
Train:   0 [100/125][   100]  Lo

## Results

```
seed = 3315
batch_size = 16
sample_length = 64
n_train_samples = 2000
lr = 3e-4
weight_decay = 0.001
epochs = 10

embed_dim = 64
kernel_size = 0
with_shared = True
with_subject = True
```

```
Eval friends s6
Val:   9  Loss: 0.316  Acc: 0.340,0.339,0.365,0.314 (0.340)  Time: 0.001,0.002 534/s
Eval movie10
Val:   9  Loss: 0.330  Acc: 0.307,0.282,0.304,0.272 (0.291)  Time: 0.001,0.002 564/s
```

```
embed_dim = 128
```
```
Eval friends s6
Val:   9  Loss: 0.316  Acc: 0.341,0.340,0.365,0.312 (0.339)  Time: 0.001,0.002 446/s
Eval movie10
Val:   9  Loss: 0.330  Acc: 0.307,0.281,0.303,0.269 (0.290)  Time: 0.001,0.002 519/s
```

```
embed_dim = 32
```

```
Eval friends s6
Val:   9  Loss: 0.317  Acc: 0.336,0.336,0.362,0.311 (0.336)  Time: 0.001,0.002 514/s
Eval movie10
Val:   9  Loss: 0.330  Acc: 0.304,0.279,0.303,0.270 (0.289)  Time: 0.001,0.002 582/s
```

```
embed_dim = 64
kernel_size = 7
```

```
Eval friends s6
Val:   9  Loss: 0.314  Acc: 0.349,0.349,0.375,0.322 (0.349)  Time: 0.001,0.002 433/s
Eval movie10
Val:   9  Loss: 0.328  Acc: 0.318,0.289,0.314,0.278 (0.300)  Time: 0.001,0.002 472/s
```

```
embed_dim = 32
kernel_size = 7
```

```
Eval friends s6
Val:   9  Loss: 0.315  Acc: 0.343,0.344,0.369,0.318 (0.343)  Time: 0.001,0.002 436/s
Eval movie10
Val:   9  Loss: 0.329  Acc: 0.312,0.284,0.310,0.275 (0.295)  Time: 0.001,0.002 452/s
```

```
embed_dim = 64
encoder_kernel_size = 7
decoder_kernel_size = 7
```
```
Eval friends s6
Val:   9  Loss: 0.314  Acc: 0.348,0.348,0.373,0.322 (0.348)  Time: 0.001,0.002 450/s
Eval movie10
Val:   9  Loss: 0.328  Acc: 0.317,0.289,0.312,0.277 (0.299)  Time: 0.001,0.002 470/s
```

In [73]:
with open(out_dir / "ckpt.pt", "wb") as f:
    torch.save(
        {
            "config": config,
            "model": model.state_dict(),
            "val_metrics": val_metrics,
            "val_acc": val_acc,
            "test_metrics": test_metrics,
            "test_acc": test_acc,
        },
        f,
    )