# Cross-subject encoding

Predict each subject's fMRI activity from the synchronized activity of the other subjects. (I don't know whether any papers have done this, though they probably have. If you know of a reference, please let me know.)

## Data loader

The data loader samples clips of synchronized activity from the same Friends episodes across subjects. Each clip has shape `(n_subs, sample_length, dim)`.

In [1]:
import re
import time
from collections import defaultdict
from pathlib import Path

import h5py
import numpy as np
from torch.utils.data import IterableDataset
from tqdm import tqdm

# Subject IDs used by default when a dataset is built without `subjects`.
SUBJECTS = [1, 2, 3, 5]

# Absolute project root; the fMRI data directory is resolved beneath it.
ROOT = Path("/ocean/projects/med220004p/clane2/algonauts25")
ALGONAUTS_2025_FMRI_ROOT = ROOT / "data/algonauts_2025_fmri"

In [2]:
class Algonauts2025FriendsFmri(IterableDataset):
    def __init__(
        self,
        root: str | Path,
        subjects: list[int] | None = None,
        seasons: list[int] | None = None,
        sample_length: int = 128,
        num_samples: int | None = None,
        shuffle: bool = True,
        keep_in_memory: bool = False,
        seed: int | None = None,
    ):
        self.root = root
        self.subjects = subjects or SUBJECTS
        self.seasons = seasons or list(range(1, 7))
        self.sample_length = sample_length
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.keep_in_memory = keep_in_memory
        self.seed = seed

        self._files = {
            sub: h5py.File(
                Path(root)
                / f"data/sub-{sub:02d}/func" 
                / f"sub-{sub:02d}_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5"
            )
            for sub in self.subjects
        }

        self._task_key_maps = defaultdict(dict)
        seasons_set = set(self.seasons)
        for sub, file in self._files.items():
            for key in file.keys():
                task = key.split("-")[-1]  # 'ses-066_task-s06e24d'
                season, _, _ = _parse_friends_task(task)
                if season in seasons_set:
                    self._task_key_maps[task][sub] = key

        self._task_list = sorted(
            [
                task for task, map in self._task_key_maps.items()
                if len(map) == len(self.subjects)
            ]
        )
        
        if self.keep_in_memory:
            self._data = defaultdict(dict)
            for sub in self.subjects:
                for task in self._task_list:
                    key = self._task_key_maps[task][sub]
                    self._data[sub][key] = self._files[sub][key][:]
        else:
            self._data = None

        self._rng = np.random.default_rng(seed)
    
    def _iter_shuffle(self):
        sample_idx = 0
        while True:
            task_order = self._rng.permutation(len(self._task_list))

            for ii in task_order:
                task = self._task_list[ii]

                keys = [self._task_key_maps[task][sub] for sub in self.subjects]
                datas = [self._get_data(sub, key) for sub, key in zip(self.subjects, keys)]
                length = min(len(data) for data in datas)
                
                offset = int(self._rng.integers(0, length - self.sample_length + 1))
                sample = np.stack(
                    [data[offset: offset + self.sample_length] for data in datas]
                )

                yield sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return
    
    def _iter_ordered(self):
        sample_idx = 0
        for task in self._task_list:
            keys = [self._task_key_maps[task][sub] for sub in self.subjects]
            datas = [self._get_data(sub, key) for sub, key in zip(self.subjects, keys)]
            length = min(len(data) for data in datas)

            for offset in range(0, length - self.sample_length + 1, self.sample_length):
                sample = np.stack(
                    [data[offset: offset + self.sample_length] for data in datas]
                )
                yield sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return
    
    def _get_data(self, sub: int, key: str):
        maps = self._data if self.keep_in_memory else self._files
        return maps[sub][key]
    
    def __iter__(self):
        if self.shuffle:
            yield from self._iter_shuffle()
        else:
            yield from self._iter_ordered()


def _parse_friends_task(task: str) -> tuple[int, int, str]:
    """Split a Friends task label such as ``"s06e24d"`` into its parts.

    Args:
        task: Label beginning with ``s<season>e<episode><part>``.

    Returns:
        ``(season, episode, part)`` with season and episode as ints and part
        as a single lowercase letter.

    Raises:
        ValueError: If `task` does not start with the expected pattern.
    """
    match = re.match(r"s([0-9]+)e([0-9]+)([a-z])", task)
    if match is None:
        raise ValueError(f"unrecognized Friends task label: {task!r}")
    season, episode, part = match.groups()
    return int(season), int(episode), part

In [3]:
# Throughput check: iterate season 6 in order from memory and time it.
dataset = Algonauts2025FriendsFmri(
    root=ALGONAUTS_2025_FMRI_ROOT,
    seasons=[6],
    sample_length=64,
    num_samples=None,
    shuffle=False,
    keep_in_memory=True,
    seed=42,
)

total_bytes = 0
tic = time.monotonic()
for sample in tqdm(dataset):
    # nbytes reflects the array's real dtype instead of assuming 4 bytes/elem.
    total_bytes += sample.nbytes
rt = time.monotonic() - tic
tput = total_bytes / 1024 ** 2 / rt
print(f"run time={rt:.3f}s, MB/s={tput:.0f}")

336it [00:00, 12811.81it/s]

run time=0.030s, MB/s=10944





## Model

The model is a simple linear encoder and decoder for each subject. Each encoder/decoder is "factorized" into a depthwise conv1d (to align the data temporally) and a linear projection (to align the data spatially).

For each subject, the input to the decoder is the average of the latents for the other three subjects.

In [4]:
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

In [5]:
class CausalConv1d(nn.Conv1d):
    """Conv1d layer with a causal mask, to only "attend" to past time points."""
    attn_mask: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: str | int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        assert kernel_size % 2 == 1, "causal conv requires odd kernel size"
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        attn_mask = torch.zeros(kernel_size)
        attn_mask[:kernel_size // 2 + 1] = 1.0
        self.weight.data.mul_(attn_mask)
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = self.weight * self.attn_mask
        return F.conv1d(
            input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

In [6]:
class ConvLinearEncoder(nn.Module):
    """Depthwise temporal conv followed by a linear projection to `embed_dim`.

    Takes input of shape ``(N, L, C)`` with ``C == in_features`` and returns
    ``(N, L, embed_dim)``. Set `causal` to use ``CausalConv1d`` for the
    temporal filter instead of a plain ``nn.Conv1d``.
    """

    def __init__(
        self,
        in_features: int,
        embed_dim: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        if causal:
            make_conv = CausalConv1d
        else:
            make_conv = nn.Conv1d
        # groups == channels: one independent temporal filter per feature.
        self.conv = make_conv(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding="same",
            groups=in_features,
        )
        self.fc = nn.Linear(in_features=in_features, out_features=embed_dim)

    def forward(self, x: torch.Tensor):
        # Conv1d expects channels before length: (N, L, C) -> (N, C, L) -> back.
        filtered = self.conv(x.transpose(-1, -2)).transpose(-1, -2)
        return self.fc(filtered)


class ConvLinearDecoder(nn.Module):
    """Linear projection from `embed_dim` followed by a depthwise temporal conv.

    Takes input of shape ``(N, L, embed_dim)`` and returns
    ``(N, L, out_features)``. Set `causal` to use ``CausalConv1d`` for the
    temporal filter instead of a plain ``nn.Conv1d``.
    """

    def __init__(
        self,
        embed_dim: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        if causal:
            make_conv = CausalConv1d
        else:
            make_conv = nn.Conv1d
        self.fc = nn.Linear(in_features=embed_dim, out_features=out_features)
        # groups == channels: one independent temporal filter per feature.
        self.conv = make_conv(
            in_channels=out_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding="same",
            groups=out_features,
        )

    def forward(self, x: torch.Tensor):
        # Project first, then filter along time with channels-first layout.
        projected = self.fc(x)
        return self.conv(projected.transpose(-1, -2)).transpose(-1, -2)

In [7]:
# Smoke test: build a causal encoder and check the output shape.
encoder = ConvLinearEncoder(
    in_features=1000,
    embed_dim=256,
    causal=True
)
print(encoder)

# (N, L, C)
x = torch.randn(16, 64, 1000)
# Call the module itself rather than .forward() so nn.Module hooks run.
embed = encoder(x)
print(embed.shape)

ConvLinearEncoder(
  (conv): CausalConv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
  (fc): Linear(in_features=1000, out_features=256, bias=True)
)
torch.Size([16, 64, 256])


In [8]:
class CrossSubjectConvLinearEncoder(nn.Module):
    """Predict each subject's signal from the other subjects' latents.

    Every subject has its own encoder and decoder. For target subject ``t``
    the decoder input is the unweighted mean of the encoded latents of all
    subjects except ``t``, so a subject never sees its own data.

    Args:
        num_subjects: Number of subjects; must be at least 2.
        encoder_fn: Zero-argument callable building one encoder (a class or,
            as used in this file, a ``functools.partial``).
        decoder_fn: Zero-argument callable building one decoder.
        embed_dim: Latent width; only used to size the optional LayerNorm.
        normalize: If true, apply LayerNorm to the latents before averaging.

    Raises:
        ValueError: If `num_subjects` is below 2, or if the input passed to
            `forward` has a different number of subjects.
    """
    weight: torch.Tensor

    def __init__(
        self,
        num_subjects: int,
        encoder_fn: type[nn.Module],
        decoder_fn: type[nn.Module],
        embed_dim: int = 256,
        normalize: bool = False,
    ):
        super().__init__()
        # With one subject there is nobody to average over and the weight
        # matrix below would be 0/0 = NaN.
        if num_subjects < 2:
            raise ValueError(
                f"num_subjects must be at least 2, got {num_subjects}"
            )
        self.num_subjects = num_subjects
        # todo: could also consider having a shared group encoder/decoder
        self.encoders = nn.ModuleList([encoder_fn() for _ in range(num_subjects)])
        self.norm = nn.LayerNorm(embed_dim) if normalize else nn.Identity()
        self.decoders = nn.ModuleList([decoder_fn() for _ in range(num_subjects)])

        # todo: could learn the averaging weights
        # weight[t, s] = 1 / (S - 1) for s != t and 0 on the diagonal.
        weight = (1.0 - torch.eye(self.num_subjects)) / (self.num_subjects - 1.0)
        self.register_buffer("weight", weight)
        self.apply(init_weights)

    def forward(self, input: torch.Tensor):
        # input: (N, S, L, C)
        if input.size(1) != self.num_subjects:
            raise ValueError(
                f"expected {self.num_subjects} subjects on dim 1, "
                f"got input of shape {tuple(input.shape)}"
            )
        # subject specific encoders
        embed = torch.stack(
            [encoder(input[:, ii]) for ii, encoder in enumerate(self.encoders)],
            dim=1,
        )
        embed = self.norm(embed)
        # average pool the latents for all but target subject
        embed = torch.einsum("nslc,ts->ntlc", embed, self.weight)
        # subject specific decoders
        output = torch.stack(
            [decoder(embed[:, ii]) for ii, decoder in enumerate(self.decoders)],
            dim=1,
        )
        return output


def init_weights(m: nn.Module):
    """Initialise conv and linear layers in place; leave other modules alone.

    Weights are drawn from a truncated normal with std 0.02 and biases are
    set to zero. Meant to be passed to ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Layers built with bias=False have no bias tensor to initialise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [9]:
# Smoke test: four subjects, 1000 input features, 256-d latents.
encoder_fn = partial(ConvLinearEncoder, 1000, 256)
decoder_fn = partial(ConvLinearDecoder, 256, 1000)

cross_encoder = CrossSubjectConvLinearEncoder(
    num_subjects=4,
    encoder_fn=encoder_fn,
    decoder_fn=decoder_fn,
    embed_dim=256,
)
print(cross_encoder)

# (N, S, L, C)
x = torch.randn(16, 4, 64, 1000)
# Call the module itself rather than .forward() so nn.Module hooks run.
z = cross_encoder(x)
print(z.shape)

CrossSubjectConvLinearEncoder(
  (encoders): ModuleList(
    (0-3): 4 x ConvLinearEncoder(
      (conv): Conv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
      (fc): Linear(in_features=1000, out_features=256, bias=True)
    )
  )
  (norm): Identity()
  (decoders): ModuleList(
    (0-3): 4 x ConvLinearDecoder(
      (fc): Linear(in_features=256, out_features=1000, bias=True)
      (conv): Conv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
    )
  )
)
torch.Size([16, 4, 64, 1000])


## Training

Basic training loop, AdamW, no lr decay, no bells and whistles.

In [10]:
import math

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from timm.utils import AverageMeter, random_seed

In [11]:
def train_one_epoch(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch_batches: int | None,
    device: torch.device,
):
    """Run one pass over `train_loader`, updating `model` in place.

    The loss is the MSE between the model output and its own input batch.
    A progress line is printed whenever the global step is a multiple of 10.

    Args:
        epoch: Zero-based epoch index, used for the global step and logging.
        model: Model to train.
        train_loader: Loader yielding batches that the model both consumes
            and is asked to reconstruct.
        optimizer: Optimizer stepped once per batch.
        epoch_batches: Batches per epoch, used for step numbering and the
            progress display; ``None`` falls back to ``len(train_loader)``.
        device: Device the batches are moved to.

    Raises:
        RuntimeError: If the loss becomes NaN or infinite.
    """
    model.train()

    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    epoch_batches = len(train_loader) if epoch_batches is None else epoch_batches
    first_step = epoch * epoch_batches

    end = time.monotonic()
    for batch_idx, sample in enumerate(train_loader):
        step = first_step + batch_idx
        sample = sample.to(device)
        batch_size = sample.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(sample)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        # Exceptions do not interpolate %-style arguments the way logging
        # calls do, so the step number is formatted into the message here.
        if not math.isfinite(loss_item):
            raise RuntimeError(
                f"NaN/Inf loss encountered on step {step}; exiting"
            )

        # compute gradients and take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # end of iteration timing
        if use_cuda:
            # Wait for queued GPU work so the step time is meaningful.
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        if step % 10 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Train: {epoch:>3d} [{batch_idx:>3d}/{epoch_batches}][{step:>6d}]"
                f"  Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f"  Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        # Restart timer for next iteration
        end = time.monotonic()

In [12]:
@torch.no_grad()
def validate(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
):
    """Evaluate `model` on `val_loader` and print a one-line summary.

    Reports the mean MSE loss and, per subject, the Pearson correlation
    between input and prediction computed over all collected batches.

    Args:
        epoch: Epoch index, used only in the printed line.
        model: Model to evaluate; switched to eval mode.
        val_loader: Loader yielding batches of shape ``(N, S, L, C)``.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-subject correlation scores.

    Raises:
        ValueError: If `val_loader` yields no batches.
    """
    model.eval()

    use_cuda = device.type == "cuda"

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    samples = []
    outputs = []

    end = time.monotonic()
    for batch_idx, sample in enumerate(val_loader):
        sample = sample.to(device)
        batch_size = sample.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(sample)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        # end of iteration timing
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        samples.append(sample.cpu().numpy())
        outputs.append(output.cpu().numpy())

        # Reset timer
        end = time.monotonic()

    if not samples:
        raise ValueError("val_loader yielded no batches; nothing to validate")

    # (N, S, L, C)
    samples = np.concatenate(samples)
    outputs = np.concatenate(outputs)

    # Label scores with the subjects of the dataset actually being evaluated,
    # falling back to the module default when the dataset does not expose them.
    subjects = getattr(val_loader.dataset, "subjects", SUBJECTS)
    accs = {
        f"acc_s{sub}": pearsonr_score(samples[:, ii], outputs[:, ii])
        for ii, sub in enumerate(subjects)
    }
    accs_fmt = ",".join(f"{acc:.3f}" for acc in accs.values())
    acc = sum(accs.values()) / len(accs)

    tput = batch_size / step_time_m.avg
    print(
        f"Val: {epoch:>3d}"
        f"  Loss: {loss_m.avg:#.3g}"
        f"  Acc: {accs_fmt} ({acc:.3f})"
        f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
    )

    return acc


def pearsonr_score(
    y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7
) -> np.ndarray:
    """Mean Pearson correlation between `y_true` and `y_pred` over features.

    Both arrays are flattened to ``(-1, n_features)`` using their last axis.
    Each feature column is centred and scaled to unit norm, with `eps` added
    to the norm to avoid dividing by zero, and the per-feature correlations
    are averaged into a single value.
    """

    def _unit_columns(values: np.ndarray) -> np.ndarray:
        # Collapse all leading axes, then centre and normalise each column.
        flat = values.reshape(-1, values.shape[-1])
        centered = flat - flat.mean(axis=0)
        return centered / (np.linalg.norm(centered, axis=0) + eps)

    per_feature = np.sum(_unit_columns(y_true) * _unit_columns(y_pred), axis=0)
    return per_feature.mean()

In [13]:
# Experiment configuration.
seed = 3315  # global RNG seed passed to random_seed
batch_size = 16  # clips per batch for both loaders
sample_length = 64  # time points per clip
n_train_samples = 2000  # clips drawn per training epoch
embed_dim = 256  # latent width of each encoder/decoder
kernel_size = 11  # temporal conv kernel width
lr = 3e-4  # AdamW learning rate
weight_decay = 0.001  # AdamW weight decay
epochs = 10  # number of train/validate rounds

In [14]:
# Seed the RNGs via timm's helper before anything random is created.
random_seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", device)

Running on: cpu


In [15]:
# Train on seasons 1-5: random windows, capped at n_train_samples per pass.
train_dataset = Algonauts2025FriendsFmri(
    root=ALGONAUTS_2025_FMRI_ROOT,
    seasons=range(1, 6),
    sample_length=sample_length,
    num_samples=n_train_samples,
    shuffle=True,
    keep_in_memory=True,
    seed=42,
)

# Validate on season 6: ordered, non-overlapping windows. No seed is passed;
# the dataset's RNG is only used in shuffle mode.
val_dataset = Algonauts2025FriendsFmri(
    root=ALGONAUTS_2025_FMRI_ROOT,
    seasons=[6],
    sample_length=sample_length,
    shuffle=False,
    keep_in_memory=True,
)

In [16]:
# Batch both datasets; drop_last=True discards a trailing partial batch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True)

In [17]:
# Pull a single batch to sanity-check the collated shape.
batch = next(iter(train_loader))
print("Batch shape:", tuple(batch.shape))

Batch shape: (16, 4, 64, 1000)


In [18]:
# Zero-argument factories, so each subject gets its own encoder and decoder
# with 1000 input/output features.
encoder_fn = partial(ConvLinearEncoder, 1000, embed_dim, kernel_size=kernel_size)
decoder_fn = partial(ConvLinearDecoder, embed_dim, 1000, kernel_size=kernel_size)

model = CrossSubjectConvLinearEncoder(
    num_subjects=4,
    encoder_fn=encoder_fn,
    decoder_fn=decoder_fn,
    embed_dim=embed_dim,
)
model = model.to(device)

# Count trainable parameters only.
param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model:", model)
print(f"Num params: {param_count/1e6:.2f}M")

Model: CrossSubjectConvLinearEncoder(
  (encoders): ModuleList(
    (0-3): 4 x ConvLinearEncoder(
      (conv): Conv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
      (fc): Linear(in_features=1000, out_features=256, bias=True)
    )
  )
  (norm): Identity()
  (decoders): ModuleList(
    (0-3): 4 x ConvLinearDecoder(
      (fc): Linear(in_features=256, out_features=1000, bias=True)
      (conv): Conv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
    )
  )
)
Num params: 2.15M


In [19]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# Batches per epoch, used for step numbering and the progress display.
epoch_batches = n_train_samples // batch_size

In [20]:
# Alternate one training pass and one validation pass per epoch. The score
# returned by validate() is not stored; it is only printed.
for epoch in range(epochs):
    train_one_epoch(
        epoch=epoch,
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch_batches=epoch_batches,
        device=device,
    )
    validate(
        epoch=epoch,
        model=model,
        val_loader=val_loader,
        device=device,
    )

Train:   0 [  0/125][     0]  Loss: 0.359 (0.359)  Time: 0.012,0.258 62/s  Mem: 0.00,0.00 GB
Train:   0 [ 10/125][    10]  Loss: 0.368 (0.371)  Time: 0.006,0.156 103/s  Mem: 0.00,0.00 GB
Train:   0 [ 20/125][    20]  Loss: 0.361 (0.373)  Time: 0.006,0.161 100/s  Mem: 0.00,0.00 GB
Train:   0 [ 30/125][    30]  Loss: 0.378 (0.373)  Time: 0.006,0.157 102/s  Mem: 0.00,0.00 GB
Train:   0 [ 40/125][    40]  Loss: 0.378 (0.373)  Time: 0.006,0.163 98/s  Mem: 0.00,0.00 GB
Train:   0 [ 50/125][    50]  Loss: 0.375 (0.372)  Time: 0.006,0.159 100/s  Mem: 0.00,0.00 GB
Train:   0 [ 60/125][    60]  Loss: 0.352 (0.371)  Time: 0.005,0.159 101/s  Mem: 0.00,0.00 GB
Train:   0 [ 70/125][    70]  Loss: 0.357 (0.369)  Time: 0.005,0.158 101/s  Mem: 0.00,0.00 GB
Train:   0 [ 80/125][    80]  Loss: 0.357 (0.367)  Time: 0.005,0.157 102/s  Mem: 0.00,0.00 GB
Train:   0 [ 90/125][    90]  Loss: 0.347 (0.365)  Time: 0.005,0.156 103/s  Mem: 0.00,0.00 GB
Train:   0 [100/125][   100]  Loss: 0.350 (0.364)  Time: 0.005