In [72]:
import json
import os
import re
import time
from collections import defaultdict
from pathlib import Path

import h5py
import numpy as np
import torch
from torch.utils.data import IterableDataset
from matplotlib import pyplot as plt
from tqdm import tqdm

In [2]:
# Subject IDs used throughout the notebook (sub-04 is not included).
SUBJECTS = 1, 2, 3, 5

In [3]:
# Input roots. Override via environment variables to run on another machine;
# the defaults are the original author's local paths.
data_dir = Path(
    os.environ.get("ALGONAUTS_DATA_DIR", "/home/clane/algonauts_2025.competitors")
)
feat_dir = Path(
    os.environ.get("ALGONAUTS_FEAT_DIR", "/home/clane/algonauts2025.huggingface/features")
)

# Results and checkpoints are written under ./output (relative to the kernel's cwd).
out_dir = Path(".") / "output"
out_dir.mkdir(exist_ok=True)
print("Saving output to:", out_dir.resolve())

Saving output to: /home/clane/algonauts2025/long_context_encoding/output


## Data loading

Aligned cross-subject fMRI data loader. The loader samples clips of synchronized activity from the same *Friends* episodes across subjects. Each clip has shape `(n_subs, sample_length, dim)`.

We also load pre-extracted stimulus features with shape `(sample_length, dim)`.

In [4]:
def parse_friends_run(run: str):
    """Split a Friends run id such as ``"s06e24d"`` into ``(season, episode, part)``.

    Matching is anchored at the start of ``run`` only, so trailing characters
    after the part letter are ignored.

    Returns:
        ``(season, episode, part)`` with season/episode as ints and part as a
        single lowercase letter.

    Raises:
        ValueError: If ``run`` does not start with ``s<digits>e<digits><letter>``.
    """
    found = re.match(r"s([0-9]+)e([0-9]+)([a-z])", run)
    if not found:
        raise ValueError(f"Invalid friends run {run}")

    season_str, episode_str, part = found.groups()
    return int(season_str), int(episode_str), part

In [5]:
class Algonauts2025FriendsFmri:
    """In-memory, cross-subject aligned Friends fMRI runs.

    Reads each subject's Schaefer-1000 parcellated BOLD HDF5 file. It keeps the
    runs (e.g. ``"s01e05b"``) that belong to ``seasons`` and exist for *every*
    requested subject. Each run is stored as one array of shape
    ``(n_subjects, T, n_parcels)``, with ``T`` truncated to the shortest
    subject. All data is copied into memory and the HDF5 files are closed
    before ``__init__`` returns.

    Args:
        root: Directory containing ``sub-XX/func/...bold.h5`` files.
        subjects: Subject IDs to load (default: ``SUBJECTS``).
        seasons: Seasons to keep (default: 1-6).
    """

    def __init__(
        self,
        root: str | Path,
        subjects: list[int] | None = None,
        seasons: list[int] | None = None,
    ):
        self.root = root
        self.subjects = subjects or SUBJECTS
        self.seasons = seasons or list(range(1, 7))

        files = {}
        try:
            for sub in self.subjects:
                files[sub] = h5py.File(
                    Path(root)
                    / f"sub-{sub:02d}/func"
                    / f"sub-{sub:02d}_task-friends_space-MNI152NLin2009cAsym_atlas-Schaefer18_parcel-1000Par7Net_desc-s123456_bold.h5",
                    "r",
                )

            # episode -> {subject: hdf5 dataset key}
            episode_key_maps = defaultdict(dict)
            seasons_set = set(self.seasons)
            for sub, file in files.items():
                for key in file.keys():
                    episode = key.split("-")[-1]  # 'ses-066_task-s06e24d' -> 's06e24d'
                    season, _, _ = parse_friends_run(episode)
                    if season in seasons_set:
                        episode_key_maps[episode][sub] = key

            # Only keep runs that every requested subject has.
            episode_list = sorted(
                episode
                for episode, sub_keys in episode_key_maps.items()
                if len(sub_keys) == len(self.subjects)
            )

            data = {}
            for episode in episode_list:
                samples = [
                    files[sub][episode_key_maps[episode][sub]][:]
                    for sub in self.subjects
                ]
                # Subjects can differ by a few time points; align on the shortest.
                length = min(len(sample) for sample in samples)
                data[episode] = np.stack([sample[:length] for sample in samples])
        finally:
            # Everything needed was copied into memory above; release file handles.
            for file in files.values():
                file.close()

        self.episode_list = episode_list
        self._data = data

    def get(self, episode: str) -> np.ndarray:
        """Return the stacked ``(n_subjects, T, n_parcels)`` array for ``episode``."""
        return self._data[episode]

In [6]:
# Load all six seasons for every subject in SUBJECTS into memory.
fmri_data = Algonauts2025FriendsFmri(seasons=range(1, 7), root=data_dir / "fmri")

In [7]:
# Sanity check: one run as (subjects, time, parcels).
sample = fmri_data.get("s01e05b")
print(f"Sample shape (NTC): {sample.shape} {sample.dtype}")

Sample shape (NTC): (4, 468, 1000) float32


In [8]:
class Algonauts2025FriendsFeatures:
    """In-memory per-episode stimulus features read from one HDF5 file.

    Each top-level key of the file is an episode id (e.g. ``"s01e05b"``). It is
    indexed by layer name, and only ``layer`` is loaded. The file is closed
    once the data has been copied into memory.

    Args:
        path: Path to the features HDF5 file.
        layer: Layer name to read under each episode (e.g. ``"model.layers.11"``).
        episode_list: Episodes to load. ``None`` *or an empty list* loads every
            episode in the file.
    """

    def __init__(
        self,
        path: str | Path,
        layer: str,
        episode_list: list[str] | None = None,
    ):
        self.path = path
        self.layer = layer

        # Falsy episode_list (None or empty) means "no filter".
        wanted = set(episode_list) if episode_list else None

        data = {}
        loaded = []
        with h5py.File(path, "r") as file:
            for key in file:
                if wanted is None or key in wanted:
                    data[key] = file[key][layer][:]
                    loaded.append(key)

        # Episodes actually present, in file iteration order.
        self.episode_list = loaded
        self._data = data

    def get(self, episode: str) -> np.ndarray:
        """Return the feature array for ``episode`` (time on the first axis)."""
        return self._data[episode]

In [9]:
# Layer-11 Llama-3.2-1B features (short context, window 16), restricted to runs with fMRI.
feat_data = Algonauts2025FriendsFeatures(
    episode_list=fmri_data.episode_list,
    layer="model.layers.11",
    path=feat_dir / "friends" / "meta-llama__Llama-3.2-1B" / "context-short_window-16.h5",
)

In [10]:
# Sanity check: features for the same run as (time, channels).
feat = feat_data.get("s01e05b")
print(f"Feature shape (TC): {feat.shape} {feat.dtype}")

Feature shape (TC): (468, 2048) float32


In [26]:
class Algonauts2025FriendsDataset(IterableDataset):
    def __init__(
        self,
        feat_data: Algonauts2025FriendsFeatures,
        fmri_data: Algonauts2025FriendsFmri,
        seasons: list[int] | None = None,
        sample_length: int | None = 128,
        num_samples: int | None = None,
        shuffle: bool = True,
        seed: int | None = None,
    ):
        self.feat_data = feat_data
        self.fmri_data = fmri_data
        self.seasons = seasons or list(range(1, 7))

        episode_list = []
        seasons_set = set(self.seasons)
        for episode in fmri_data.episode_list:
            season, _, _ = parse_friends_run(episode)
            if season in seasons_set:
                episode_list.append(episode)
        self.episode_list = episode_list

        self.sample_length = sample_length
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.seed = seed

        self._rng = np.random.default_rng(seed)
    
    def _iter_shuffle(self):
        sample_idx = 0
        while True:
            episode_order = self._rng.permutation(len(self.episode_list))

            for ii in episode_order:
                episode = self.episode_list[ii]
                feat = torch.from_numpy(self.feat_data.get(episode))
                fmri = torch.from_numpy(self.fmri_data.get(episode))

                # Nb, fmri and feature length often off by 1 or 2.
                # But assuming time locked to start.
                length = min(feat.shape[0], fmri.shape[1])

                if self.sample_length:
                    # Random segment of run
                    offset = int(self._rng.integers(0, length - self.sample_length + 1))
                    feat_sample = feat[offset: offset + self.sample_length]
                    fmri_sample = fmri[:, offset: offset + self.sample_length]
                else:
                    # Take full run
                    # Nb this only works for batch size 1 since runs are different length
                    feat_sample = feat[:length]
                    fmri_sample = fmri[:, :length]

                yield episode, feat_sample, fmri_sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return

    def _iter_ordered(self):
        sample_idx = 0
        for episode in self.episode_list:
            feat = torch.from_numpy(self.feat_data.get(episode))
            fmri = torch.from_numpy(self.fmri_data.get(episode))

            length = min(feat.shape[0], fmri.shape[1])
            sample_length = self.sample_length or length

            for offset in range(0, length - sample_length + 1, sample_length):
                feat_sample = feat[offset: offset + sample_length]
                fmri_sample = fmri[:, offset: offset + sample_length]
                yield episode, feat_sample, fmri_sample

                sample_idx += 1
                if self.num_samples and sample_idx >= self.num_samples:
                    return

    def __iter__(self):
        if self.shuffle:
            yield from self._iter_shuffle()
        else:
            yield from self._iter_ordered()

In [27]:
# Training-style dataset (seasons 1-5, random 64-step clips) for the throughput check below.
dataset = Algonauts2025FriendsDataset(
    feat_data=feat_data,
    fmri_data=fmri_data,
    seasons=range(1, 6),
    sample_length=64,
    num_samples=10_000,
    shuffle=True,
    seed=42,
)

In [29]:
# Raw iteration throughput of the in-memory dataset (no DataLoader, no GPU copy).
total_bytes = 0
tic = time.monotonic()
for episode, feat_sample, fmri_sample in tqdm(dataset):
    # Use each tensor's element size rather than assuming 4-byte float32.
    total_bytes += (
        feat_sample.numel() * feat_sample.element_size()
        + fmri_sample.numel() * fmri_sample.element_size()
    )
rt = time.monotonic() - tic
tput = total_bytes / 1024 ** 2 / rt
print(f"run time={rt:.3f}s, MB/s={tput:.0f}")

10000it [00:00, 50584.73it/s]

run time=0.207s, MB/s=71338





## Model


In [14]:
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

In [15]:
class CausalConv1d(nn.Conv1d):
    """Conv1d layer with a causal mask, to only "attend" to past time points."""
    attn_mask: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: str | int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        assert kernel_size % 2 == 1, "causal conv requires odd kernel size"
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        attn_mask = torch.zeros(kernel_size)
        attn_mask[:kernel_size // 2 + 1] = 1.0
        self.weight.data.mul_(attn_mask)
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = self.weight * self.attn_mask
        return F.conv1d(
            input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

In [16]:
class ConvLinear(nn.Module):
    """Depthwise temporal conv over the input channels, followed by a linear projection.

    Input and output are ``(N, L, C)``. The conv runs along ``L`` with one
    filter per input channel (``groups=in_features``) and ``"same"`` padding,
    so the sequence length is preserved. With ``causal=True`` the conv only
    sees current and past time steps.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        if causal:
            conv_cls = CausalConv1d
        else:
            conv_cls = nn.Conv1d

        self.conv = conv_cls(
            in_features,
            in_features,
            kernel_size=kernel_size,
            padding="same",
            groups=in_features,
        )
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # Conv1d wants channels before time: (N, L, C) -> (N, C, L) -> back.
        filtered = self.conv(x.transpose(-1, -2)).transpose(-1, -2)
        return self.fc(filtered)


class LinearConv(nn.Module):
    """Linear projection followed by a depthwise temporal conv over the output channels.

    Input and output are ``(N, L, C)``. The conv has one filter per output
    channel (``groups=out_features``) and ``"same"`` padding, so the sequence
    length is preserved. With ``causal=True`` the conv only sees current and
    past time steps.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        if causal:
            conv_cls = CausalConv1d
        else:
            conv_cls = nn.Conv1d

        self.fc = nn.Linear(in_features, out_features)
        self.conv = conv_cls(
            out_features,
            out_features,
            kernel_size=kernel_size,
            padding="same",
            groups=out_features,
        )

    def forward(self, x: torch.Tensor):
        projected = self.fc(x)
        # Conv1d wants channels before time: (N, L, C) -> (N, C, L) -> back.
        return self.conv(projected.transpose(-1, -2)).transpose(-1, -2)

In [17]:
# Smoke test: a causal ConvLinear maps (N, L, 1000) -> (N, L, 256), preserving length.
encoder = ConvLinear(
    in_features=1000,
    out_features=256,
    causal=True,
)
print(encoder)

# (N, L, C)
x = torch.randn(16, 64, 1000)
# Call the module rather than .forward() so registered hooks run.
embed = encoder(x)
print(embed.shape)

ConvLinear(
  (conv): CausalConv1d(1000, 1000, kernel_size=(11,), stride=(1,), padding=same, groups=1000)
  (fc): Linear(in_features=1000, out_features=256, bias=True)
)
torch.Size([16, 64, 256])


In [18]:
class MultiSubjectConvLinearEncoder(nn.Module):
    """Predict every subject's fMRI from shared stimulus features.

    Features ``(N, L, feat_dim)`` are layer-normalized and projected to
    ``embed_dim``. A shared ``ConvLinear`` head and one ``ConvLinear`` head per
    subject each map the embedding to ``target_dim``. The shared prediction
    is added to each subject's prediction, giving ``(N, num_subjects, L, target_dim)``.
    All conv/linear layers are re-initialized with ``init_weights``.
    """

    # NOTE(review): no `weight` attribute is ever assigned on this module;
    # the annotation looks vestigial.
    weight: torch.Tensor

    def __init__(
        self,
        num_subjects: int = 4,
        feat_dim: int = 2048,
        embed_dim: int = 256,
        target_dim: int = 1000,
        kernel_size: int = 11,
        causal: bool = False,
    ):
        super().__init__()
        self.num_subjects = num_subjects

        self.norm = nn.LayerNorm(feat_dim)
        self.feat_embed = nn.Linear(feat_dim, embed_dim)

        def make_head():
            return ConvLinear(
                embed_dim, target_dim, kernel_size=kernel_size, causal=causal
            )

        self.shared_encoder = make_head()
        self.subject_encoders = nn.ModuleList(
            make_head() for _ in range(num_subjects)
        )
        self.apply(init_weights)

    def forward(self, input: torch.Tensor):
        # input: (N, L, D) -> output: (N, S, L, C)
        embed = self.feat_embed(self.norm(input))
        per_subject = torch.stack(
            [head(embed) for head in self.subject_encoders], dim=1
        )
        shared = self.shared_encoder(embed).unsqueeze(1)
        return per_subject + shared


def init_weights(m: nn.Module):
    """Initialize conv/linear layers: truncated-normal weights (std 0.02), zero bias.

    Meant for ``nn.Module.apply``. Other module types (e.g. LayerNorm) keep
    their default initialization. Layers built with ``bias=False`` are
    supported.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
        # NOTE: trunc_normal_'s cutoffs a/b are absolute (default -2, 2), so at
        # std=0.02 the truncation is effectively inactive.
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [19]:
# Default multi-subject encoder: a shared head plus one head per subject in SUBJECTS.
encoder = MultiSubjectConvLinearEncoder(num_subjects=len(SUBJECTS))
print(encoder)

MultiSubjectConvLinearEncoder(
  (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  (feat_embed): Linear(in_features=2048, out_features=256, bias=True)
  (shared_encoder): ConvLinear(
    (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
    (fc): Linear(in_features=256, out_features=1000, bias=True)
  )
  (subject_encoders): ModuleList(
    (0-3): 4 x ConvLinear(
      (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
      (fc): Linear(in_features=256, out_features=1000, bias=True)
    )
  )
)


In [20]:
# (N, L, C) -> (N, S, L, C) for the multi-subject encoder.
x = torch.randn(16, 64, 2048)
# Call the module rather than .forward() so registered hooks run.
z = encoder(x)
print(z.shape)

torch.Size([16, 4, 64, 1000])


## Training

In [21]:
import math

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from timm.utils import AverageMeter, random_seed

In [22]:
def train_one_epoch(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch_batches: int | None,
    device: torch.device,
):
    """Run one optimization pass over ``train_loader`` with an MSE objective.

    Each batch is ``(episodes, feat, sample)``. The model maps ``feat`` to a
    prediction that is compared against ``sample`` with ``F.mse_loss``.
    Progress is printed every 10 global steps.

    Args:
        epoch: Zero-based epoch index, used to compute the global step.
        model: Model to train; switched to train mode.
        train_loader: Loader yielding ``(episodes, feat, sample)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        epoch_batches: Batches per epoch, used for the global step and logging.
            If ``None``, ``len(train_loader)`` is used; that fails for loaders
            over an ``IterableDataset`` without ``__len__``.
        device: Device that batches are moved to.

    Returns:
        The sample-weighted mean training loss over the epoch.

    Raises:
        RuntimeError: If the loss becomes NaN or Inf.
    """
    model.train()

    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    epoch_batches = len(train_loader) if epoch_batches is None else epoch_batches
    first_step = epoch * epoch_batches

    end = time.monotonic()
    for batch_idx, (_, feat, sample) in enumerate(train_loader):
        step = first_step + batch_idx
        feat = feat.to(device)
        sample = sample.to(device)
        batch_size = feat.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(feat)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        if not math.isfinite(loss_item):
            # Format eagerly: exceptions don't %-interpolate extra args like logging does.
            raise RuntimeError(f"NaN/Inf loss encountered on step {step}; exiting")

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # end of iteration timing (sync so GPU work is included)
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        if step % 10 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Train: {epoch:>3d} [{batch_idx:>3d}/{epoch_batches}][{step:>6d}]"
                f"  Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f"  Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        # Restart timer for next iteration
        end = time.monotonic()

    return loss_m.avg

In [23]:
@torch.no_grad()
def validate(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
):
    """Evaluate ``model`` on ``val_loader`` and compute encoding accuracy.

    Predictions and targets of all batches are concatenated along time per
    subject. Accuracy is the Pearson correlation (``pearsonr_score``) per
    output channel, averaged over channels, and then averaged over
    ``SUBJECTS``.

    Returns:
        ``(acc, metrics)``. ``acc`` is the subject-averaged mean correlation.
        ``metrics`` holds per-subject maps/means (``acc_map_sub-{sub}``,
        ``acc_sub-{sub}``) and their averages (``acc_map_avg``, ``acc_avg``).

    Raises:
        AssertionError: If a batch's subject axis does not match ``SUBJECTS``.
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()

    use_cuda = device.type == "cuda"

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    samples = []
    outputs = []

    end = time.monotonic()
    for _, feat, sample in val_loader:
        feat = feat.to(device)
        sample = sample.to(device)
        batch_size = feat.size(0)
        data_time = time.monotonic() - end

        # forward pass
        output = model(feat)
        loss = F.mse_loss(output, sample)
        loss_item = loss.item()

        # end of iteration timing
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        N, S, L, C = sample.shape
        # The subject axis is labeled with SUBJECTS below, so the counts must agree.
        assert S == len(SUBJECTS), f"expected {len(SUBJECTS)} subjects, got {S}"
        # (N, S, L, C) -> (S, N*L, C): lay the batch's clips end to end in time.
        samples.append(sample.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))
        outputs.append(output.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))

        # Reset timer
        end = time.monotonic()

    if not samples:
        raise ValueError("val_loader yielded no batches")

    # (S, T, C), T = total validation time points
    samples = np.concatenate(samples, axis=1)
    outputs = np.concatenate(outputs, axis=1)

    metrics = {}

    # Encoding accuracy metrics
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(SUBJECTS):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"acc_map_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(SUBJECTS)
        acc += acc_i / len(SUBJECTS)

    metrics["acc_map_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )

    tput = batch_size / step_time_m.avg
    print(
        f"Val: {epoch:>3d}"
        f"  Loss: {loss_m.avg:#.3g}"
        f"  Acc: {accs_fmt} ({acc:.3f})"
        f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
    )

    return acc, metrics


def pearsonr_score(
    y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7
) -> np.ndarray:
    """Column-wise Pearson correlation between two ``(n_samples, n_features)`` arrays.

    Each column is centered and scaled to unit L2 norm (``eps`` is added to
    the norm, so constant columns score 0 instead of dividing by zero). The
    score is the dot product of the matching normalized columns.

    Returns:
        Array of shape ``(n_features,)``.
    """
    assert y_true.ndim == y_pred.ndim == 2

    def _unit_columns(arr: np.ndarray) -> np.ndarray:
        centered = arr - arr.mean(axis=0)
        return centered / (np.linalg.norm(centered, axis=0) + eps)

    return np.sum(_unit_columns(y_true) * _unit_columns(y_pred), axis=0)

In [56]:
# Reproducibility
seed = 3315

# Data
batch_size = 16
sample_length = 64        # training clip length, in time points
n_train_samples = 2000    # training clips drawn per pass over the train loader

# Model
embed_dim = 256
kernel_size = 11
causal = False

# Optimization
lr = 3e-4
weight_decay = 0.001
epochs = 10

In [57]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Running on:", device)

Running on: cuda


In [58]:
def make_data_loaders(
    model_name: str = "meta-llama/Llama-3.2-1B",
    layer: str = "model.layers.11",
    context: str = "short",
    window: int = 16,
    summary: bool = False,
):
    """Build train (seasons 1-5) and validation (season 6) loaders for one feature set.

    Features are read from ``feat_dir/friends/<model>/context-<context>_<tag>.h5``.
    ``<tag>`` is ``window-<window>`` for short context and ``summary-<0|1>``
    otherwise. fMRI comes from the global ``fmri_data``. ``sample_length``,
    ``n_train_samples`` and ``batch_size`` are read from notebook globals.

    Returns:
        ``(train_loader, val_loader)``. Training yields ``n_train_samples``
        random clips (fixed dataset seed 42) per pass, in batches of
        ``batch_size``. Validation yields each season-6 run whole, one run per
        batch.
    """
    model_dir = model_name.replace("/", "__")
    variant = f"window-{window}" if context == "short" else f"summary-{int(summary)}"
    feat_path = feat_dir / "friends" / model_dir / f"context-{context}_{variant}.h5"

    feat_data = Algonauts2025FriendsFeatures(
        path=feat_path,
        layer=layer,
        episode_list=fmri_data.episode_list,
    )

    def build_dataset(seasons, **kwargs):
        return Algonauts2025FriendsDataset(
            feat_data, fmri_data, seasons=seasons, **kwargs
        )

    train_dataset = build_dataset(
        range(1, 6),
        sample_length=sample_length,
        num_samples=n_train_samples,
        shuffle=True,
        seed=42,
    )
    val_dataset = build_dataset([6], sample_length=None, shuffle=False)

    return (
        DataLoader(train_dataset, batch_size=batch_size),
        DataLoader(val_dataset, batch_size=1),
    )

In [59]:
# Loaders for the default configuration, used for the shape check below.
train_loader, val_loader = make_data_loaders(
    model_name="meta-llama/Llama-3.2-1B",
    layer="model.layers.11",
    context="short",
    window=16,
    summary=False,
)

In [60]:
# Peek at one training batch: feat is (N, L, D), sample is (N, S, L, C).
first_batch = next(iter(train_loader))
_, feat, sample = first_batch
print("feat shape:", tuple(feat.shape))
print("Sample shape:", tuple(sample.shape))

feat shape: (16, 64, 2048)
Sample shape: (16, 4, 64, 1000)


In [74]:
def run_experiment(
    model_name: str = "meta-llama/Llama-3.2-1B",
    layer: str = "model.layers.11",
    context: str = "short",
    window: int = 16,
    summary: bool = False,
    overwrite: bool = False,
):
    """Train and evaluate one encoder for a feature configuration, with on-disk caching.

    The checkpoint goes to ``out_dir/results/<model>/<context tag>/<layer>.pt``.
    If that file exists and ``overwrite`` is False, its stored accuracy is
    returned without training. Hyperparameters (``seed``, ``epochs``, ``lr``,
    ...) and ``device`` are read from notebook globals.

    Returns:
        Validation accuracy (subject-averaged mean Pearson r) after the last epoch.

    Raises:
        ValueError: If training is needed and the global ``epochs`` is < 1.
    """
    random_seed(seed)

    model_str = model_name.replace('/', '__')
    if context == "short":
        out_path = f"{model_str}/context-{context}_window-{window}/{layer}.pt"
    else:
        out_path = f"{model_str}/context-{context}_summary-{int(summary)}/{layer}.pt"

    out_path = out_dir / "results" / out_path
    out_path.parent.mkdir(parents=True, exist_ok=True)

    if out_path.exists() and not overwrite:
        # weights_only=False unpickles arbitrary objects; only load checkpoints
        # written by this function.
        ckpt = torch.load(out_path, map_location=device, weights_only=False)
        return ckpt["acc"]

    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    train_loader, val_loader = make_data_loaders(
        model_name=model_name,
        layer=layer,
        context=context,
        window=window,
        summary=summary,
    )

    model = MultiSubjectConvLinearEncoder(
        embed_dim=embed_dim,
        kernel_size=kernel_size,
        causal=causal,
    )
    model = model.to(device)

    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model:", model)
    print(f"Num params: {param_count/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    epoch_batches = n_train_samples // batch_size

    for epoch in range(epochs):
        train_one_epoch(
            epoch=epoch,
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            epoch_batches=epoch_batches,
            device=device,
        )
        acc, metrics = validate(
            epoch=epoch,
            model=model,
            val_loader=val_loader,
            device=device,
        )

    acc = float(acc)

    # Write to a temp file, then rename. An interrupted save then never leaves
    # a truncated checkpoint that a later cached call would try to load.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    torch.save(
        {
            "model": model.state_dict(),
            "metrics": metrics,
            "acc": acc,
        },
        tmp_path,
    )
    tmp_path.replace(out_path)

    return acc

In [71]:
# Retrain the default configuration, ignoring any cached checkpoint.
run_experiment(
    model_name="meta-llama/Llama-3.2-1B",
    layer="model.layers.11",
    context="short",
    window=16,
    summary=False,
    overwrite=True,
)

Model: MultiSubjectConvLinearEncoder(
  (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  (feat_embed): Linear(in_features=2048, out_features=256, bias=True)
  (shared_encoder): ConvLinear(
    (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
    (fc): Linear(in_features=256, out_features=1000, bias=True)
  )
  (subject_encoders): ModuleList(
    (0-3): 4 x ConvLinear(
      (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
      (fc): Linear(in_features=256, out_features=1000, bias=True)
    )
  )
)
Num params: 1.83M
Train:   0 [  0/125][     0]  Loss: 0.367 (0.367)  Time: 0.014,0.059 269/s  Mem: 0.22,0.25 GB
Train:   0 [ 10/125][    10]  Loss: 0.366 (0.375)  Time: 0.018,0.034 471/s  Mem: 0.25,0.27 GB
Train:   0 [ 20/125][    20]  Loss: 0.381 (0.375)  Time: 0.016,0.029 557/s  Mem: 0.25,0.27 GB
Train:   0 [ 30/125][    30]  Loss: 0.359 (0.371)  Time: 0.014,0.025 640/s  Mem: 0.25,0.27 GB
Train:   0 [ 40/12

0.17757554352283478

In [75]:
results = []

model_name = "meta-llama/Llama-3.2-1B"

# Feature variants as (context, window, summary). `window` only affects
# short-context features and `summary` only long-context ones.
cfgs = [
    ("short", 16, False),
    ("short", 32, False),
    ("long", 16, False),
    ("long", 16, True),
]

layers = [f"model.layers.{ii}" for ii in (7, 11, 15)]

for context, window, summary in cfgs:
    for layer in layers:
        feature_cfg = dict(context=context, window=window, summary=summary)
        acc = run_experiment(model_name=model_name, layer=layer, **feature_cfg)

        record = {"model": model_name, **feature_cfg, "layer": layer, "acc": acc}
        results.append(record)
        print(json.dumps(record))

Model: MultiSubjectConvLinearEncoder(
  (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  (feat_embed): Linear(in_features=2048, out_features=256, bias=True)
  (shared_encoder): ConvLinear(
    (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
    (fc): Linear(in_features=256, out_features=1000, bias=True)
  )
  (subject_encoders): ModuleList(
    (0-3): 4 x ConvLinear(
      (conv): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=same, groups=256)
      (fc): Linear(in_features=256, out_features=1000, bias=True)
    )
  )
)
Num params: 1.83M
Train:   0 [  0/125][     0]  Loss: 0.367 (0.367)  Time: 0.037,0.069 232/s  Mem: 0.27,0.31 GB
Train:   0 [ 10/125][    10]  Loss: 0.366 (0.375)  Time: 0.017,0.033 486/s  Mem: 0.30,0.33 GB
Train:   0 [ 20/125][    20]  Loss: 0.380 (0.375)  Time: 0.015,0.028 578/s  Mem: 0.30,0.33 GB
Train:   0 [ 30/125][    30]  Loss: 0.359 (0.371)  Time: 0.013,0.024 673/s  Mem: 0.30,0.33 GB
Train:   0 [ 40/12

In [76]:
import pandas as pd

In [77]:
# One row per (configuration, layer) run.
results_df = pd.DataFrame(results)

In [84]:
# Mean validation accuracy: feature configuration (rows) x LLM layer (columns).
results_table = results_df.drop(columns="model").pivot_table(
    "acc", index=["context", "window", "summary"], columns=["layer"]
)

# Order layers by depth (numeric suffix). Sorting the names as strings would put
# "model.layers.7" after "model.layers.15".
layer_order = sorted(
    results_table.columns, key=lambda name: int(name.rsplit(".", 1)[-1])
)
results_table = results_table[layer_order]
results_table.style.format(precision=3)

Unnamed: 0_level_0,Unnamed: 1_level_0,layer,model.layers.7,model.layers.11,model.layers.15
context,window,summary,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
long,16,False,0.198,0.194,0.185
long,16,True,0.198,0.195,0.185
short,16,False,0.175,0.178,0.174
short,32,False,0.185,0.185,0.179
