# **fMRI Learning Stage Classification with Vision Transformers**

This notebook implements a Vision Transformer model to classify different stages of learning from fMRI data.

## Setup and Dependencies

In [78]:
!pip install lru-dict pywavelets nibabel openneuro-py boto3 nilearn



#### Import libraries

In [79]:
import os
import re
import sys
import time
import random
import math
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import json
import logging
import urllib.request
import zipfile
import tarfile
import pywt
from google.colab import drive
import torch
import numpy as np
import pandas as pd
import nibabel as nib
import nilearn
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from timm.models.layers import DropPath
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score
from dataclasses import dataclass, field
from functools import partial
from collections import defaultdict
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import LambdaLR
from scipy.ndimage import zoom
from scipy import signal
from scipy.interpolate import interp1d
from einops.layers.torch import Rearrange
from transformers import get_cosine_schedule_with_warmup
from lru import LRU as LRUCache
from torch.optim import AdamW
from openneuro import download
import boto3
from pywt import wavedec
import matplotlib.pyplot as plt
import seaborn as sns
from lru import LRU
from itertools import chain
from nilearn import plotting, image
from scipy import stats
import torch.utils.checkpoint as checkpoint

#### System Config

In [80]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

## Configuration

In [81]:
@dataclass
class Config:
    # paths (unchanged)
    ROOT: str = "/content/drive/MyDrive/learnedSpectrum"
    CACHE: str = "/content/fmri_cache"
    CKPT_DIR: str = "/content/checkpoints"

    # training dynamics
    BATCH_SIZE: int = 8  # doubled
    NUM_WORKERS: int = 4
    PIN_MEMORY: bool = True
    PERSISTENT_WORKERS: bool = True
    USE_AMP: bool = True
    GRADIENT_ACCUMULATION_STEPS: int = 4  # halved bc batch doubled

    # optimization
    LEARNING_RATE: float = 1e-4  # increased
    WEIGHT_DECAY: float = 0.1    # doubled
    NUM_EPOCHS: int = 30
    GRAD_CLIP: float = 0.5       # halved
    WARMUP_EPOCHS: int = 3       # reduced
    MIN_LR: float = 1e-6

    # architecture
    VOLUME_SIZE: Tuple[int, int, int] = (64, 64, 30)
    PATCH_SIZE: int = 7
    NUM_PATCHES: int = (64//7 * 64//7 * 30//7)
    TIME_STEPS: int = 32
    EMBED_DIM: int = 192         # wider better than deeper for fmri
    NUM_HEADS: int = 6           # dim/32 ratio optimal
    NUM_LAYERS: int = 4          # retain shallow for spatiotemporal
    DROPOUT: float = 0.15        # sweet spot from latest lit

    # misc
    PATIENCE: int = 15
    MIN_DELTA: float = 1e-3
    TASK_DIM: int = 256         # halved
    KFAC_UPDATE_FREQ: int = 10

    # device props
    @property
    def device(self) -> torch.device:
        return device

    @property
    def fp16(self) -> bool:
        return self.USE_AMP

    # task info (unchanged)
    TASK_INFO: Dict = field(default_factory=lambda: {
        'ds000002': {'type': 'prob_class', 'tr': 2.0},
        'ds000011': {'type': 'det_class', 'tr': 1.5},
        'ds000017': {'type': 'reversal', 'tr': 2.5},
        'ds000052': {'type': 'learning', 'tr': 2.0}
    })

In [82]:
@dataclass(frozen=True)
class DataConfig:
    ROOT: str = "/content/drive/MyDrive/learnedSpectrum"
    CACHE: str = "/content/fmri_cache"
    DATASET_PATHS: Dict[str, str] = field(default_factory=lambda: {
        'ds000002': 'classification/probabilistic',
        'ds000011': 'classification/deterministic',
        'ds000017': 'learning/reversal',
        'ds000052': 'learning/stages'
    })
    TASK_INFO: Dict[str, Dict] = field(default_factory=lambda: {
        'ds000002': {'type': 'prob_class', 'tr': 2.0},
        'ds000011': {'type': 'det_class', 'tr': 1.5},
        'ds000017': {'type': 'reversal', 'tr': 2.5},
        'ds000052': {'type': 'learning', 'tr': 2.0}
    })

Print current GPU memory usage

In [83]:
def print_gpu_memory():
    if torch.cuda.is_available():
        print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f}GB allocated, "
              f"{torch.cuda.max_memory_allocated()/1e9:.2f}GB peak")

In [84]:
config = Config()
print_gpu_memory()

GPU Memory: 0.17GB allocated, 0.32GB peak


## Data Loading

In [85]:
class DatasetManager:
    DATASETS = {
        'ds000002': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000002/ds000002_R2.0.5/compressed/ds000002_R2.0.5_raw.zip',
            'tr': 2.0,
            'stage_map': lambda f: 0.25 if 'run-2' in str(f) else 0.0  # prob learning
        },
        'ds000011': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000011/ds000011_R2.0.1/compressed/ds000011_R2.0.1_raw.zip',
            'tr': 1.5,
            'stage_map': lambda f: 0.5 if 'run-2' in str(f) else 0.25  # det learning
        },
        'ds000017': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000017/ds000017_R2.0.1/compressed/ds000017_R2.0.1.zip',
            'tr': 2.5,
            'stage_map': lambda f: 0.75 if 'run-2' in str(f) else 0.5  # reversal
        },
        'ds000052': {
            'url': 'https://s3.amazonaws.com/openneuro/ds000052/ds000052_R2.0.0/compressed/ds052_R2.0.0_01-14.tgz',
            'tr': 2.0,
            'stage_map': lambda f: 1.0 if ('reversal' in str(f) and 'run-2' in str(f)) else
                                  0.67 if ('reversal' in str(f)) else
                                  0.33 if 'run-2' in str(f) else 0.0  # full spectrum
        }
    }

    def __init__(self, config: DataConfig):
        self.config = config
        self.root = Path(config.ROOT).resolve()
        self._mount_drive()

    def _mount_drive(self):
        if not Path('/content/drive').exists():
            drive.mount('/content/drive')
        self.root.mkdir(parents=True, exist_ok=True)

    def _fetch_dataset(self, ds_id: str):
        path = self.root/ds_id
        if not path.exists():
            print(f"{ds_id} not found.")
            print(f"manual steps:")
            print(f"1. wget {self.DATASETS[ds_id]['url']}")
            print(f"2. extract to {path}")
            raise FileNotFoundError(f"download {ds_id} first")

    def _exists_and_valid(self, ds_id: str) -> bool:
        path = self.root/ds_id
        print(f"checking {ds_id} at: {path}")
        return path.exists() and any(path.glob('**/*bold.nii.gz'))

    def fetch_datasets(self):
        for ds_id in self.DATASETS:
            if not self._exists_and_valid(ds_id):
                self._fetch_dataset(ds_id)

    def get_all_files(self) -> Dict[str, List[Path]]:
        return {
            ds_id: list(self.root.glob(f"{ds_id}/**/*bold.nii.gz"))
            for ds_id in self.DATASETS
            if self._exists_and_valid(ds_id)
        }

    def get_learning_stages(self) -> Dict[str, float]:
        stages = {}
        for ds_id, files in self.get_all_files().items():
            stage_map = self.DATASETS[ds_id]['stage_map']
            for f in files:
                stages[f.parts[-3]] = stage_map(f)
        return stages

In [86]:
manager = DatasetManager(config)
manager.fetch_datasets()
files = manager.get_all_files()
stages = manager.get_learning_stages()

checking ds000002 at: /content/drive/MyDrive/learnedSpectrum/ds000002
checking ds000011 at: /content/drive/MyDrive/learnedSpectrum/ds000011
checking ds000017 at: /content/drive/MyDrive/learnedSpectrum/ds000017
checking ds000052 at: /content/drive/MyDrive/learnedSpectrum/ds000052
checking ds000002 at: /content/drive/MyDrive/learnedSpectrum/ds000002
checking ds000011 at: /content/drive/MyDrive/learnedSpectrum/ds000011
checking ds000017 at: /content/drive/MyDrive/learnedSpectrum/ds000017
checking ds000052 at: /content/drive/MyDrive/learnedSpectrum/ds000052
checking ds000002 at: /content/drive/MyDrive/learnedSpectrum/ds000002
checking ds000011 at: /content/drive/MyDrive/learnedSpectrum/ds000011
checking ds000017 at: /content/drive/MyDrive/learnedSpectrum/ds000017
checking ds000052 at: /content/drive/MyDrive/learnedSpectrum/ds000052


## Create datasets

extract region activations using aal atlas

data: [T,H,W,D] fmri timeseries

In [87]:
def extract_aal_regions(data: torch.Tensor) -> torch.Tensor:
    aal = load_aal_atlas()

    regions = torch.zeros(116)
    for i in range(116):
        mask = (aal == i+1)
        if mask.any():
            regions[i] = data[:, mask].mean()

    regions = (regions - regions.mean()) / (regions.std() + 1e-6)
    return regions

extract temporal dynamics using wavelet decomp

data: [T,H,W,D] fmri timeseries

returns: [n_components] frequency features

In [88]:
def extract_temporal_patterns(data: torch.Tensor, n_components: int = 32) -> torch.Tensor:

    signal = data.reshape(data.shape[0], -1).mean(1)

    coeffs = wavedec(signal, 'db4', level=int(np.log2(len(signal))))

    features = torch.cat([torch.from_numpy(c) for c in coeffs])
    return features[:n_components]

In [89]:
class BIDSManager:
    def __init__(self, dataset_id='ds000052', path='data/'):
        self.dataset_id = dataset_id
        self.root = Path(path)
        self.root.mkdir(exist_ok=True)
        self._fetch_dataset()

    def _fetch_dataset(self):
        if not (self.root/self.dataset_id).exists():
            url = "https://s3.amazonaws.com/openneuro/ds000052/ds000052_R2.0.0/compressed/ds052_R2.0.0_01-14.tgz"
            target = self.root/'data.tgz'

            import requests
            r = requests.get(url, stream=True)
            r.raise_for_status()

            with open(target, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            import tarfile
            with tarfile.open(target) as tf:
                tf.extractall(self.root)
            target.unlink()

    def get_task_files(self, task='weatherprediction'):
        return sorted(self.root.glob(f"**/*task-{task}*_bold.nii.gz"))

In [90]:
class FMRIDataset(Dataset):
    def __init__(
        self,
        root='data/',
        tr=2.0,
        n_timepoints=30,
        n_regions=116,
        validate=True,
        cache_size=50
    ):
        self.manager = BIDSManager(path=root)
        self.tr = tr
        self.n_timepoints = n_timepoints
        self.n_regions = n_regions
        self.files = self._get_valid_files() if validate else self.manager.get_task_files()
        self.labels = self._extract_learning_stages()
        self.cache = LRUCache(size=cache_size)

        self.atlas = self._load_aal_atlas()

        self.time_indices = np.array([
            int(re.search(r'run-(\d+)', str(f)).group(1))
            for f in self.files
        ])

    def _get_valid_files(self):
        valid = []
        for f in tqdm(self.manager.get_task_files(), desc='validating niftis'):
            if self._validate_nifti(f):
                valid.append(f)

        if not valid:
            raise ValueError("no valid niftis. check bids structure.")

        return valid

    def _validate_nifti(self, f):
        try:
            img = nib.load(str(f))
            data = img.get_fdata(dtype=np.float32)
            return (data.ndim == 4 and
                   not np.any(np.isnan(data)) and
                   not np.any(np.isinf(data)) and
                   data.shape[-1] >= self.n_timepoints)
        except Exception as e:
            logger.error(f"invalid nifti {f}: {str(e)}")
            return False

    def _load_aal_atlas(self):
        from nilearn.datasets import fetch_atlas_aal
        atlas = fetch_atlas_aal()
        return nib.load(atlas['maps']).get_fdata()

    def _extract_aal_regions(self, data):
        regions = torch.zeros(self.n_regions)
        for i in range(self.n_regions):
            mask = (self.atlas == i+1)
            if mask.any():
                regions[i] = data[:, mask].mean()
        return (regions - regions.mean()) / (regions.std() + 1e-6)

    def _extract_temporal_patterns(self, data):
        signal = data.reshape(data.shape[0], -1).mean(1)
        level = min(int(np.log2(len(signal))-2), 3)
        coeffs = wavedec(signal, 'db4', level=level)
        features = torch.cat([torch.from_numpy(c) for c in coeffs])
        if len(features) < self.n_timepoints:
            features = F.pad(features, (0, self.n_timepoints - len(features)))
        return features[:self.n_timepoints]

    def _extract_learning_stages(self):
        stages = {}
        for f in self.files:
            if 'reversal' in str(f):
                stages[f.parts[-3]] = 1.0 if 'run-2' in str(f) else 0.67
            else:
                stages[f.parts[-3]] = 0.33 if 'run-2' in str(f) else 0.0
        return stages

    def _load_sample(self, idx):
        fpath = self.files[idx]
        if fpath in self.cache:
            return self.cache[fpath]

        img = nib.load(str(fpath))
        data = img.get_fdata(dtype=np.float32)

        data = (data - data.mean()) / (data.std() + 1e-6)
        data = data[...,:self.n_timepoints]

        data = torch.from_numpy(data).float()
        data = data.permute(3,0,1,2)
        regions = self._extract_aal_regions(data)
        temporal = self._extract_temporal_patterns(data)

        label = self.labels[fpath.parts[-3]]

        sample = (data, regions, temporal, label)
        self.cache[fpath] = sample
        return sample

    def __getitem__(self, idx):
        data, regions, temporal, label = self._load_sample(idx)
        task_id = self.task_ids[idx] if hasattr(self, 'task_ids') else 0

        return data, task_id, {
            'learning_stage': float(label),
            'region_activation': regions,
            'temporal_pattern': temporal
        }

        return data, task_id, targets

    def __len__(self):
        return len(self.files)

## Dataloaders

In [91]:
def create_dataloaders(dataset, config, split=0.8):
    indices = np.arange(len(dataset))
    split_idx = int(len(indices) * split)

    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]

    train_set = torch.utils.data.Subset(dataset, train_indices)
    val_set = torch.utils.data.Subset(dataset, val_indices)

    loader_kwargs = {
        'batch_size': max(1, config.BATCH_SIZE // 2),
        'num_workers': config.NUM_WORKERS,
        'pin_memory': True,
        'collate_fn': collate_fn
    }
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)


    return (
        DataLoader(train_set, shuffle=True, **loader_kwargs),
        DataLoader(val_set, shuffle=False, **loader_kwargs)
    )

In [92]:
def collate_fn(batch):
    data, task_ids, targets = zip(*batch)

    x = torch.stack(data).squeeze(1)
    x = x.permute(0,1,3,4,2)
    task_ids = torch.tensor(task_ids)

    batch_targets = {}
    for k in targets[0].keys():
        if isinstance(targets[0][k], (float, int)):
            try:
                batch_targets[k] = torch.tensor([t[k] for t in targets])
            except ValueError as e:
                print(f"collate fail on {k}: {[t[k] for t in targets]}")
                raise e
        else:
            batch_targets[k] = torch.stack([t[k] for t in targets])

    return x, task_ids, batch_targets

In [93]:
def collate_variable_length(batch):
    data, task_ids, targets = zip(*batch)

    x = torch.stack(data).squeeze(1)
    x = x.permute(0, 1, 3, 4, 2)
    task_ids = torch.tensor(task_ids)

    batch_targets = {}
    for k in targets[0].keys():
        if isinstance(targets[0][k], (float, int)):
            batch_targets[k] = torch.tensor([t[k] for t in targets])
        else:
            batch_targets[k] = torch.stack([t[k] for t in targets])

    return x, task_ids, batch_targets

ensure proper pinned memory cleanup

In [94]:
class PinnedMemoryContext:
    def __enter__(self):
        torch.cuda.empty_cache()

    def __exit__(self, *args):
        torch.cuda.empty_cache()

def get_loader_stats(loader: DataLoader) -> Dict[str, int]:
    """debug utility"""
    return {
        'batches': len(loader),
        'samples': len(loader.dataset),
        'device_batch': loader.batch_size,
        'effective_batch': loader.batch_size * config.GRADIENT_ACCUMULATION_STEPS
    }

synthetic task labels, 4 stages w/ proper progression

In [95]:
def get_mock_labels():
    return {
        'sub-01': 0.0,    # naive performance
        'sub-02': 0.25,   # early learning
        'sub-03': 0.50,   # intermediate mastery
        'sub-04': 0.75,   # advanced competence
        'sub-05': 1.0     # expert performance
    }

## Data Preprocessing

In [96]:
def preprocess_volume(vol: np.ndarray, config: Config) -> torch.Tensor:
    if vol.ndim not in (4,5):
        raise ValueError(f"expect 4d/5d vol, got {vol.ndim}d")

    if vol.ndim == 4:
        vol = vol[None]

    b,t,h,w,d = vol.shape
    target_h, target_w, target_d = config.VOLUME_SIZE

    vol = zoom(vol, (
        1,
        1,
        target_h/h,
        target_w/w,
        target_d/d
    ), order=1)

    vol = (vol - vol.mean((1,2,3,4), keepdims=True)) / (vol.std((1,2,3,4), keepdims=True) + 1e-8)

    return torch.from_numpy(vol).float()

enforce temporal consistency

In [97]:
def normalize_temporal_resolution(data: np.ndarray, orig_tr: float = 2.0, target_tr: float = 2.0) -> np.ndarray:
    if orig_tr == target_tr:
        return data

    old_times = np.arange(0, data.shape[-1]) * orig_tr
    new_times = np.arange(0, old_times[-1], target_tr)

    orig_shape = data.shape
    data_2d = data.reshape(-1, orig_shape[-1])

    interp_data = np.array([
        interp1d(old_times, ts, kind='cubic', bounds_error=False, fill_value='extrapolate')(new_times)
        for ts in data_2d
    ])

    return interp_data.reshape(*orig_shape[:-1], len(new_times))

In [98]:
class FMRIAugmentor:
    def __init__(self, p=0.5):
        self.p = p
        self.augs = [
            self._temporal_shift,
            self._gaussian_noise,
            self._dropout_voxels,
            self._mixup,
            self._intensity_scale
        ]

    def __call__(self, x):
        if random.random() < self.p:
            aug = random.choice(self.augs)
            x = aug(x)
        return x

    def _temporal_shift(self, x, max_shift=3):
        shift = random.randint(-max_shift, max_shift)
        return torch.roll(x, shifts=shift, dims=1)

    def _gaussian_noise(self, x, std=0.01):
        return x + torch.randn_like(x) * std

    def _dropout_voxels(self, x, p=0.1):
        mask = torch.rand_like(x) > p
        return x * mask

    def _mixup(self, x, alpha=0.2):
        lam = np.random.beta(alpha, alpha)
        idx = torch.randperm(x.size(0))
        return lam * x + (1-lam) * x[idx]

    def _intensity_scale(self, x, range=(0.9, 1.1)):
        scale = random.uniform(*range)

## 3D Vision Transformer

In [99]:
class HierarchicalAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.local_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.global_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.merge = nn.Linear(dim * 2, dim)

        self.task_gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

    def forward(self, x, task_embed=None):
        b, n, _ = x.shape
        local_out = self.local_attn(x, x, x)[0]

        if task_embed is not None:
            task_gate = self.task_gate(task_embed).unsqueeze(1)
            x = x * task_gate

        global_out = self.global_attn(x, x, x)[0]
        return self.merge(torch.cat([local_out, global_out], dim=-1))

In [100]:
class TaskConditionedMLP(nn.Module):
    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4

        self.task_gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x, task_embed):
        gate = self.task_gate(task_embed).unsqueeze(1)
        return self.net(x * gate)

gating between local/global temporal patterns

In [101]:
class AdaptiveTokenMixer(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.local_proj = nn.Linear(dim, dim)
        self.global_proj = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, 1)

        self.temp = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.rel_pos = nn.Parameter(torch.randn(2 * dim - 1, head_dim))
        pos_index = torch.arange(dim)
        rel_pos_index = pos_index[None, :] - pos_index[:, None]
        rel_pos_index += dim - 1
        self.register_buffer('rel_pos_index', rel_pos_index)

    def forward(self, x, mask=None):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        rel_pos_bias = self.rel_pos[self.rel_pos_index].permute(2, 0, 1)

        attn = ((q @ k.transpose(-2, -1)) * self.scale + rel_pos_bias) * self.temp
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)

        local_out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        global_out = self.global_proj(x)

        gate = self.gate(x).sigmoid()
        x = gate * local_out + (1 - gate) * global_out

        return x

In [102]:
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

In [103]:
def get_3d_sincos_pos_embed(embed_dim, dims, cls_token=True):
    h, w, d = dims
    total_pos = h * w * d
    pos = get_1d_sincos_pos_embed_from_grid(embed_dim, np.arange(total_pos))

    if cls_token:
        return np.concatenate([np.zeros((1, embed_dim)), pos])
    return pos

In [104]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = 1e-6

    def forward(self, pred, target):
        freqs = torch.bincount(target.long(), minlength=2).float()
        freqs = freqs / freqs.sum()
        gamma = self.gamma * (1 - freqs[target.long()])

        ce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        p = torch.sigmoid(pred)
        p_t = p * target + (1 - p) * (1 - target)
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)

        loss = alpha_t * ((1 - p_t) ** gamma) * ce
        return loss.mean()

In [105]:
class RegressionFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = 1e-6

    def forward(self, pred, target):
        pred = torch.clamp(pred, self.eps, 1-self.eps)

        base_loss = F.mse_loss(pred, target, reduction='none')
        alpha_t = self.alpha + (1-self.alpha) * target

        pt = torch.exp(-base_loss)
        focal_weight = (1 - pt) ** self.gamma * alpha_t

        return (focal_weight * base_loss).mean()

In [108]:
class WeightedMultiTaskLoss(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.stage_loss = OrdinalFocalLoss()
        self.contrast_loss = ContrastiveLoss(temperature=0.1)

    def forward(self, outputs, targets, epoch=None):
        stage_out = outputs['learning_stage'].squeeze()
        stage_target = targets['learning_stage']

        if stage_out.ndim == 1:
            stage_out = stage_out.unsqueeze(-1)
        if stage_target.ndim == 1:
            stage_target = stage_target.unsqueeze(-1)

        # scale losses
        stage_loss = self.stage_loss(stage_out, stage_target)
        contrast_loss = 0.0

        if epoch and epoch > 5:
            x1, x2 = pretrain_transform(outputs['features'])
            contrast_loss = 0.01 * self.contrast_loss(x1, x2)

        loss = stage_loss + contrast_loss
        return loss, {
            'stage_loss': stage_loss.item(),
            'contrast_loss': contrast_loss if isinstance(contrast_loss, float) else contrast_loss.item()
        }

In [109]:
class EntropyRegularizer(nn.Module):
    def __init__(self, alpha=0.1, temperature=1.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, stage_preds):
        # normalize predictions
        probs = stage_preds / self.temperature

        # compute entropy
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()

        # target entropy (uniform dist)
        target_entropy = -torch.log(torch.tensor(1.0 / probs.size(1)))

        return self.alpha * (target_entropy - entropy).abs()

In [110]:
class StagePredictor(nn.Module):
    def __init__(self, dim, num_stages=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim*2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(dim*2, num_stages)
        )

        # initialization for better gradient flow
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1/np.sqrt(2))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        logits = self.net(x)
        return F.softmax(logits, dim=1)

In [111]:
class TemporalAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.EMBED_DIM
        self.num_heads = 8
        self.head_dim = self.embed_dim // self.num_heads

        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

        # relative positional bias
        self.rel_pos_bias = nn.Parameter(
            torch.zeros(2 * config.TIME_STEPS - 1, self.num_heads)
        )
        pos_index = torch.arange(config.TIME_STEPS)
        rel_pos_index = pos_index[None, :] - pos_index[:, None]
        rel_pos_index += config.TIME_STEPS - 1
        self.register_buffer('rel_pos_index', rel_pos_index)

    def forward(self, x):
        B, T, C = x.shape

        # qkv projection
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # attention w/relative pos bias
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = attn + self.rel_pos_bias[self.rel_pos_index]
        attn = attn.softmax(dim=-1)

        # aggregate and project
        x = (attn @ v).transpose(1, 2).reshape(B, T, C)
        x = self.proj(x)
        return x

In [112]:
class WeightedMSE(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.MSELoss(reduction='none')

    def forward(self, pred, true):
        conf = 1 - torch.abs(true - 0.5) * 2
        loss = self.base(pred, true)
        return (loss * conf).mean()

In [113]:
def shape_hook(name):
    def hook(module, input, output):
        if not hasattr(hook, 'count'):
            hook.count = 0
        hook.count += 1
        if hook.count <= 2:
            print(f"{name} in:", [x.shape if isinstance(x, torch.Tensor) else None for x in input])
            print(f"{name} out:", output[0].shape if isinstance(output, tuple) else output.shape)
    return hook

In [114]:
class StageHead(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim)
        self.proj1 = nn.Linear(dim, dim)
        self.mid_norm = nn.LayerNorm(dim)
        self.proj2 = nn.Linear(dim, dim//2)
        self.out_norm = nn.LayerNorm(dim//2)
        self.out = nn.Linear(dim//2, 1)
        self.act = nn.GELU()

        nn.init.orthogonal_(self.proj1.weight, gain=1/np.sqrt(2))
        nn.init.orthogonal_(self.proj2.weight, gain=1/np.sqrt(2))
        nn.init.orthogonal_(self.out.weight)

    def forward(self, x):
        h = self.in_norm(x)
        h = self.act(self.proj1(h))
        h = self.mid_norm(h + x)
        h = self.act(self.proj2(h))
        h = self.out_norm(h)

        return torch.sigmoid(self.out(h))

In [115]:
class TaskGate(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Sequential(
            nn.Linear(dim, dim//2),
            nn.GELU(),
            nn.Linear(dim//2, dim),
            nn.Sigmoid()
        )
        self.scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x, task):
        g = self.gate(self.norm(task))
        return x * (g * self.scale)

In [116]:
class WaveletDecomp(nn.Module):
    def __init__(self, n_levels=4, wavelet='db4'):
        super().__init__()
        self.n_levels = n_levels
        self.wavelet = wavelet

    def forward(self, x):
        x_np = x.detach().cpu().numpy()
        coeffs = []

        for i in range(len(x_np)):
            c = pywt.wavedec(
                x_np[i], self.wavelet,
                level=self.n_levels
            )
            coeffs.append(np.concatenate(c))

        return torch.from_numpy(
            np.stack(coeffs)
        ).float().to(x.device)

In [117]:
class ScaleNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1) * dim ** 0.5)
        self.eps = eps

    def forward(self, x):
        norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        return x * norm

In [118]:
class WaveletTemporal(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.EMBED_DIM

        # spatial -> temporal flow
        self.spatial_proj = nn.Conv3d(1, config.EMBED_DIM, 1).to(config.device)
        self.temporal_proj = nn.Conv3d(
            config.EMBED_DIM,
            config.EMBED_DIM,
            (3,1,1),
            padding=(1,0,0)
        ).to(config.device)

        self.pool = nn.AdaptiveAvgPool3d((15, 32, 32)).to(config.device)

    def forward(self, x):
        # reshape for conv
        b, t, h, d, w = x.shape
        x = x.reshape(b, 1, t, h, w*d)

        # projections
        x = self.spatial_proj(x)
        x = self.temporal_proj(x)

        # force dims
        x = self.pool(x)
        return x

In [119]:
def debug_shapes(x):
    # input
    b,t,h,d,w = x.shape
    print(f"input: {x.shape}")

    # reshape attempt 1
    x1 = x.reshape(b, 1, t, h, d*w)
    print(f"reshape1: {x1.shape}")

    # reshape attempt 2
    x2 = x.reshape(b, 1, t, h, w*d)
    print(f"reshape2: {x2.shape}")

    # reshape + permute
    x3 = x.permute(0,1,2,4,3).reshape(b, 1, t, h, w*d)
    print(f"reshape3: {x3.shape}")

    return x3  # return best version

In [120]:
def get_3d_pos_embedding(num_patches, embed_dim, cls_token=True):
    # basic sincos but w/ proper dims
    pos = torch.arange(num_patches).unsqueeze(1)
    omega = torch.exp(
        torch.arange(embed_dim//2, dtype=torch.float32) *
        (-math.log(10000.0) / (embed_dim//2))
    )
    pos_emb = pos * omega
    pos_emb = torch.cat([torch.sin(pos_emb), torch.cos(pos_emb)], dim=-1)

    return pos_emb

In [121]:
class SequentialBrainViT(nn.Module):
    def __init__(self, config):
        super().__init__()

        # aggressive downsampling
        self.spatial_proj = nn.Sequential(
            nn.Conv3d(30, 32, 3, stride=2, padding=1),  # halve
            nn.LayerNorm([32, 32, 11, 32]),
            nn.GELU(),

            nn.Conv3d(32, 64, 3, stride=2, padding=1),  # quarter
            nn.LayerNorm([64, 16, 6, 16]),
            nn.GELU(),

            nn.AdaptiveAvgPool3d((8, 4, 8))  # force smaller dims
        )

        # ~256 tokens vs 11k before
        h, w, d = 8, 4, 8
        self.embed_dim = 64
        self.num_tokens = h * w * d

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # fewer layers + heads
        self.encoder = nn.ModuleList([
            TransformerBlock(self.embed_dim, num_heads=4, mlp_ratio=2)
            for _ in range(3)
        ])

        self.stage_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, 1),
            nn.Sigmoid()
        )

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x, task_ids=None):
        x = self.spatial_proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = x + self.pos_embed

        for block in self.encoder:
            x = block(x)

        x = x.mean(1)
        return {
            'learning_stage': self.stage_head(x),
            'features': x
        }

In [122]:
def verify_model_devices(model):
    devices = {}
    for name, param in model.named_parameters():
        devices[name] = param.device
    return devices

In [123]:
def pretrain_transform(x, sigma=0.1):
    """generate two augmented views"""
    x1 = x + torch.randn_like(x) * sigma
    x2 = torch.roll(
        x + torch.randn_like(x) * sigma,
        shifts=random.randint(-3, 3),
        dims=1
    )
    return x1, x2

In [124]:
def mixup(x, y, alpha=0.2):
    """mixup augmentation"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]

    return mixed_x, mixed_y

In [125]:
class StageHead(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 2, dim // 2),
            nn.LayerNorm(dim // 2),
            nn.GELU(),
            nn.Linear(dim // 2, 1),
            nn.Sigmoid()
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight, gain=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)

In [126]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4, drop_path=0.):
        super().__init__()

        # attention
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )

        # ffn
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(dim * mlp_ratio, dim)
        )

        # stochastic depth
        self.drop_path = DropPath(drop_path)

    def forward(self, x):
        # residual attention
        x = x + self.drop_path(
            self.attn(*[self.norm1(x)]*3)[0]
        )

        # residual ffn
        x = x + self.drop_path(
            self.mlp(self.norm2(x))
        )
        return x

## Train Setup

In [127]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert isinstance(params, (list, tuple)) and len(params) > 0, "params must be non-empty list"
        defaults = dict(rho=rho, adaptive=adaptive)
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]
        self.base_optimizer.step()
        if zero_grad: self.zero_grad()

    def step(self, closure=None):
        assert closure is not None, "SAM requires closure"
        with torch.enable_grad():
            loss = closure()
        self.first_step(zero_grad=True)
        with torch.enable_grad():
            closure()
        self.second_step()
        return loss

    def _grad_norm(self):
        grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
        if not grads:
            return torch.tensor(0.0, device=self.param_groups[0]["params"][0].device)
        return torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2)

In [128]:
class KFAC(torch.optim.Optimizer):
    def __init__(self, model, lr=1e-3, momentum=0.9, damping=1e-4):
        params = list(model.parameters())
        super().__init__(params, dict(lr=lr, momentum=momentum))

        self.known_modules = {'Linear', 'Conv2d'}
        self.modules = []
        self.grad_outputs = {}
        self.acc_stats = False

        self._prepare_model(model)

    def _save_input(self, module, input):
        if self.acc_stats and module.training:
            self.inputs[module] = input[0].data

    def _save_grad_output(self, module, grad_input, grad_output):
        if self.acc_stats and module.training:
            self.grad_outputs[module] = grad_output[0].data

    def _register_hook(self, module):
        handle1 = module.register_forward_pre_hook(self._save_input)
        handle2 = module.register_backward_hook(self._save_grad_output)
        self.handles[module] = (handle1, handle2)

    def _prepare_model(self, model):
        self.modules = []
        self.handles = {}

        for module in model.modules():
            class_name = module.__class__.__name__
            if class_name in self.known_modules:
                self.modules.append(module)
                self._register_hook(module)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for module in self.modules:
            if module not in self.grad_outputs:
                continue

            grad_out = self.grad_outputs[module]
            grad_in = self.inputs[module]

            g_g = grad_out.t() @ grad_out
            a_a = grad_in.t() @ grad_in

            if module not in self.state:
                self.state[module] = {}
            state = self.state[module]

            mom = self.defaults['momentum']
            damp = self.defaults['damping']

            if 'g_g_mom' not in state:
                state['g_g_mom'] = g_g.clone()
                state['a_a_mom'] = a_a.clone()
            else:
                state['g_g_mom'].mul_(mom).add_(g_g)
                state['a_a_mom'].mul_(mom).add_(a_a)

            g_g_mom = state['g_g_mom']
            a_a_mom = state['a_a_mom']

            g_g_inv = (g_g_mom + damp * torch.eye(
                g_g_mom.size(0), device=g_g_mom.device
            )).inverse()
            a_a_inv = (a_a_mom + damp * torch.eye(
                a_a_mom.size(0), device=a_a_mom.device
            )).inverse()

            if module.weight.grad is not None:
                weight_grad = module.weight.grad.reshape(
                    -1, module.weight.size(0)
                )
                module.weight.grad = (
                    g_g_inv @ weight_grad @ a_a_inv
                ).reshape_as(module.weight)

            if module.bias is not None and module.bias.grad is not None:
                module.bias.grad = (g_g_inv @ module.bias.grad)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = p.grad.clone()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(group['momentum']).add_(p.grad)

                p.data.add_(buf, alpha=-group['lr'])

        return loss

In [129]:
def get_optimizer(model, config):
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_params = [
        {
            'params': [p for n, p in model.named_parameters()
                      if not any(nd in n for nd in no_decay)],
            'weight_decay': config.WEIGHT_DECAY
        },
        {
            'params': [p for n, p in model.named_parameters()
                      if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]

    opt = torch.optim.AdamW(grouped_params, lr=config.LEARNING_RATE)
    return opt

In [130]:
class OrdinalFocalLoss(nn.Module):
    def __init__(self, alpha=2.0):
        super().__init__()
        self.alpha = alpha
        self.eps = 1e-6

    def forward(self, pred, target):
        # clamp for num stability
        pred = torch.clamp(pred, self.eps, 1-self.eps)

        # base l2 loss
        diff = (pred - target) ** 2

        # focal weight w/softmax normalization
        pt = torch.exp(-diff) / torch.exp(-diff).sum()
        focal_weight = (1 - pt) ** self.alpha

        # entropy reg (bounded)
        entropy = -(pred * torch.log(pred + self.eps)).mean()
        entropy_reg = torch.clamp(0.5 - entropy, 0, 1)

        # combine + scale
        loss = (focal_weight * diff).mean() + 0.1 * entropy_reg
        return torch.clamp(loss, 0, 100.0)  # prevent explosion

In [131]:
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)

        # compute similarity matrix
        logits = torch.mm(z1, z2.t()) / self.temperature

        # positive pairs on diagonal
        labels = torch.arange(z1.shape[0]).to(z1.device)

        loss = F.cross_entropy(logits, labels)
        return loss

In [132]:
def pretrain_transform(x):
    aug1 = torch.roll(x, shifts=random.randint(1, x.shape[-1]//4), dims=-1)
    aug2 = x + torch.randn_like(x) * 0.01
    return aug1, aug2

In [133]:
class SimpleTemporal(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, true):
        pred = F.normalize(pred, dim=-1)
        true = F.normalize(true, dim=-1)

        sim = F.cosine_similarity(pred, true)
        loss = F.smooth_l1_loss(sim, torch.ones_like(sim))

        loss = loss + 0.01 * (pred.pow(2).sum() + true.pow(2).sum())

        return loss

In [134]:
class SlowCurriculum:
    def __init__(self, epochs):
        self.epochs = epochs
        self.task_weights = {
            'learning_stage': lambda e: min(1.0, e/10),
            'region_activation': lambda e: min(0.5, e/10),
            'temporal_pattern': lambda e: min(0.3, e/15)
        }

    def get_weights(self, epoch):
        return {k: f(epoch) for k,f in self.task_weights.items()}

In [135]:
class TemporalTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.time_rnn = nn.LSTM(
            config.EMBED_DIM,
            config.EMBED_DIM//2,
            bidirectional=True,
            batch_first=True
        )
        self.time_norm = nn.LayerNorm(config.EMBED_DIM)

    def forward(self, x):
        b,t,n,d = x.shape
        x = x.reshape(b, t, -1)
        x = self.time_rnn(x)[0]
        x = self.time_norm(x)
        return x.reshape(b, t, n, d)

map bids fname -> learning stage

In [136]:
def get_learning_stage(f: Path) -> float:
    fname = f.name

    is_reversal = 'reversal' in fname
    is_run2 = 'run-2' in fname

    if not is_reversal:
        return 0.25 if is_run2 else 0.0
    else:
        return 1.0 if is_run2 else 0.75

patience-based early stopping

In [137]:
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return False

temporal cv w/purge+embargo periods

In [138]:
class TimeSeriesCV:
    def __init__(self, n_splits=5, purge=10, embargo=5):
        self.n_splits = n_splits
        self.purge = purge
        self.embargo = embargo

    def split(self, X, groups):
        unique_groups = np.sort(np.unique(groups))
        n_groups = len(unique_groups)
        group_size = n_groups // self.n_splits

        for i in range(self.n_splits):
            test_start = i * group_size
            test_end = test_start + group_size

            train_idx = np.where(
                (groups < unique_groups[max(0, test_start - self.embargo)]) |
                (groups > unique_groups[min(n_groups-1, test_end + self.embargo)])
            )[0]
            test_idx = np.where(
                (groups >= unique_groups[test_start]) &
                (groups < unique_groups[test_end])
            )[0]

            yield train_idx, test_idx

persist model state w/ validation metrics

In [139]:
def save_checkpoint(model, trainer, config, epoch):
    ckpt_path = Path(config.CKPT_DIR) / f"brain_vit_e{epoch}.pt"

    metrics = {
        'epoch': epoch,
        'learning_stage_mse': np.mean(trainer.metrics.get('val_mse', [])),
        'region_mae': np.mean(trainer.metrics.get('val_mae', [])),
        'temporal_corr': np.mean(trainer.metrics.get('val_corr', [])),
    }

    try:
        torch.save({
            'model': model.state_dict(),
            'metrics': metrics,
            'config': {k:v for k,v in config.__dict__.items()
                      if not k.startswith('_')}
        }, ckpt_path)
        print(f"saved checkpoint to {ckpt_path}")

        latest = Path(config.CKPT_DIR) / "brain_vit_latest.pt"
        if latest.exists():
            latest.unlink()
        latest.symlink_to(ckpt_path)

    except Exception as e:
        print(f"failed to save checkpoint: {str(e)}")
        return None

    return ckpt_path

huber w/ dynamic threshold

In [140]:
class AdaptiveHuber(nn.Module):
    def __init__(self, beta=0.1):
        super().__init__()
        self.beta = beta

    def forward(self, pred, true):
        error = torch.abs(pred - true)
        c = self.beta * error.detach().median()
        quad = 0.5 * error.pow(2) / c
        linear = error - 0.5 * c
        return torch.where(error <= c, quad, linear).mean()

In [141]:
class FocalRegion(nn.Module):
    def __init__(self, alpha=2.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, true):
        diff = torch.abs(pred - true)
        weight = (1 - torch.exp(-diff)).pow(self.alpha)
        return (weight * diff).mean()

In [142]:
class GaussianTemporal(nn.Module):
    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, pred, true):
        pred = F.normalize(pred, dim=-1)
        true = F.normalize(true, dim=-1)

        dist = torch.cdist(pred.unsqueeze(0), true.unsqueeze(0)).squeeze()
        kernel = torch.exp(-dist.pow(2) / (2 * self.sigma**2))

        return -torch.log(kernel.diag() + 1e-6).mean()

In [143]:
def time_split(dataset, n_splits=5):
    indices = np.arange(len(dataset))
    split_size = len(indices) // n_splits

    for i in range(n_splits):
        mask = np.ones(len(dataset), dtype=bool)
        start_idx = i * split_size
        end_idx = start_idx + split_size
        mask[start_idx:end_idx] = False

        train_idx = indices[mask]
        val_idx = indices[~mask]
        yield train_idx, val_idx

In [144]:
from scipy.stats import entropy as scipy_entropy
from scipy.stats import ks_2samp
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict

class Trainer:
    def __init__(self, model, config):
        self.model = model.cuda()
        self.config = config

        self.opt = torch.optim.AdamW(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=0.1
        )

        self.sched = get_cosine_schedule_with_warmup(
            self.opt,
            config.WARMUP_EPOCHS * 5,
            config.NUM_EPOCHS * 5
        )

        self.criterion = WeightedMultiTaskLoss(device='cuda')
        self.scaler = torch.cuda.amp.GradScaler()

        self.best_val_loss = float('inf')
        self.plateau_counter = 0
        self.metrics = defaultdict(list)

    @torch.no_grad()
    def validate(self, val_loader):
        self.model.eval()
        stats = defaultdict(list)
        pred_list, label_list = [], []

        for batch in val_loader:
            x = batch[0].cuda(non_blocking=True)
            task_ids = batch[1].cuda(non_blocking=True)
            targets = {k: v.cuda(non_blocking=True)
                      for k,v in batch[2].items()}

            outputs = self.model(x, task_ids)
            loss, batch_stats = self.criterion(outputs, targets)

            stats['val_loss'].append(loss.item())
            for k,v in batch_stats.items():
                stats[f'val_{k}'].append(v)

            pred_list.append(outputs['learning_stage'].cpu().squeeze())
            label_list.append(targets['learning_stage'].cpu())

        preds = torch.cat(pred_list)
        labels = torch.cat(label_list)

        # distribution stats
        pred_hist = torch.histc(preds, bins=4, min=0, max=1)
        label_hist = torch.histc(labels, bins=4, min=0, max=1)

        stats['pred_entropy'] = scipy_entropy(pred_hist/len(preds))
        stats['label_entropy'] = scipy_entropy(label_hist/len(labels))
        stats['ks_stat'] = ks_2samp(preds, labels).statistic

        return {k: np.mean(v) if isinstance(v, list) else v
                for k,v in stats.items()}

    def train_epoch(self, train_loader, epoch):
        self.model.train()
        stats = defaultdict(list)

        with tqdm(train_loader) as pbar:
            for batch in pbar:
                x = batch[0].cuda(non_blocking=True)
                task_ids = batch[1].cuda(non_blocking=True)
                targets = {k: v.squeeze().cuda(non_blocking=True)
                          for k,v in batch[2].items()}

                if epoch > 5:
                    x, targets['learning_stage'] = mixup(x, targets['learning_stage'])

                with torch.amp.autocast(device_type='cuda'):
                    outputs = self.model(x, task_ids)
                    loss, batch_stats = self.criterion(outputs, targets)

                self.opt.zero_grad()
                self.scaler.scale(loss).backward()

                self.scaler.unscale_(self.opt)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    0.5  # hardcoded clip for stability
                )

                self.scaler.step(self.opt)
                self.scaler.update()
                self.sched.step()

                stats['loss'].append(loss.item())
                for k,v in batch_stats.items():
                    stats[k].append(v)

                pbar.set_postfix({
                    'loss': f"{loss.item():.3f}",
                    'lr': f"{self.sched.get_last_lr()[0]:.2e}"
                })

        return {k: np.mean(v) for k,v in stats.items()}

    def train(self, train_loader, val_loader):
        for epoch in range(self.config.NUM_EPOCHS):
            train_stats = self.train_epoch(train_loader, epoch)
            val_stats = self.validate(val_loader)

            print(f"\nepoch {epoch}:")
            print("train:", " | ".join(f"{k}: {v:.3f}" for k,v in train_stats.items()))
            print("val:", " | ".join(f"{k}: {v:.3f}" for k,v in val_stats.items()))

            if val_stats['val_loss'] < self.best_val_loss:
                self.best_val_loss = val_stats['val_loss']
                self.plateau_counter = 0
                if val_stats['ks_stat'] < 0.3:
                    torch.save({
                        'epoch': epoch,
                        'model': self.model.state_dict(),
                        'opt': self.opt.state_dict(),
                        'stats': val_stats
                    }, f"{self.config.CKPT_DIR}/model_e{epoch}.pt")
            else:
                self.plateau_counter += 1

            if self.plateau_counter >= self.config.PATIENCE:
                print(f"stopping @ epoch {epoch}")
                break

            for k,v in {**train_stats, **val_stats}.items():
                self.metrics[k].append(v)

## Train Pipeline

In [145]:
def init_dirs(config):
    """ensure cache exists"""
    Path(config.CACHE).mkdir(parents=True, exist_ok=True)
    Path(config.CKPT_DIR).mkdir(parents=True, exist_ok=True)
    return config

In [146]:
config = init_dirs(Config())
data_config = DataConfig()

In [147]:
manager = DatasetManager(data_config)
manager.fetch_datasets()
files = manager.get_all_files()

checking ds000002 at: /content/drive/MyDrive/learnedSpectrum/ds000002
checking ds000011 at: /content/drive/MyDrive/learnedSpectrum/ds000011
checking ds000017 at: /content/drive/MyDrive/learnedSpectrum/ds000017
checking ds000052 at: /content/drive/MyDrive/learnedSpectrum/ds000052
checking ds000002 at: /content/drive/MyDrive/learnedSpectrum/ds000002
checking ds000011 at: /content/drive/MyDrive/learnedSpectrum/ds000011
checking ds000017 at: /content/drive/MyDrive/learnedSpectrum/ds000017
checking ds000052 at: /content/drive/MyDrive/learnedSpectrum/ds000052


In [148]:
learning_stage_labels = {}
for ds_id, ds_files in files.items():
    stage_map = manager.DATASETS[ds_id]['stage_map']
    learning_stage_labels.update({
        f.parts[-3]: float(stage_map(f))
        for f in ds_files
    })

In [149]:
model = SequentialBrainViT(config).cuda()
criterion = WeightedMultiTaskLoss(device='cuda')
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.1
)

In [150]:
mock_labels = get_mock_labels()
print(f"mock stages: {mock_labels}")

mock stages: {'sub-01': 0.0, 'sub-02': 0.25, 'sub-03': 0.5, 'sub-04': 0.75, 'sub-05': 1.0}


In [151]:
dataset = FMRIDataset()
train_loader, val_loader = create_dataloaders(dataset, config)

KeyboardInterrupt: 

In [152]:
trainer = Trainer(model, config)

  self.scaler = torch.cuda.amp.GradScaler()


In [153]:
for batch in train_loader:
    x, task_ids, targets = batch
    x = x.to(device)
    task_ids = task_ids.to(device)
    outputs = model(x, task_ids)

TypeError: SequentialBrainViT.forward() takes 2 positional arguments but 3 were given

In [None]:
trainer.train(train_loader, val_loader)

## Inference Utils & Visualization

In [None]:
def extract_training_stats(model, val_loader, config):
    param_count = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    layer_norms = defaultdict(list)
    grad_norms = defaultdict(list)

    for name, param in model.named_parameters():
        if param.grad is not None:
            layer_norms[name.split('.')[0]].append(param.norm().item())
            grad_norms[name.split('.')[0]].append(param.grad.norm().item())

    model.eval()
    metrics = defaultdict(list)

    with torch.no_grad():
        for batch in val_loader:
            x = batch[0].to(config.device)
            task_ids = batch[1].to(config.device)
            targets = batch[2]
            outputs = model(x, task_ids)

            stage_pred = outputs['learning_stage'].cpu().numpy()
            stage_true = targets['learning_stage'].numpy()
            metrics['stage_mse'].append(np.mean((stage_pred - stage_true)**2))

            region_pred = outputs['region_activation'].cpu().numpy()
            region_true = targets['region_activation'].numpy()
            metrics['region_mae'].append(np.mean(np.abs(region_pred - region_true)))

            temp_pred = outputs['temporal_pattern'].cpu().numpy()
            temp_true = targets['temporal_pattern'].numpy()
            temp_corr = np.array([np.corrcoef(p,t)[0,1] for p,t in zip(temp_pred, temp_true)])
            metrics['temp_corr'].append(np.mean(temp_corr))

    print("\nmodel architecture:")
    print(f"total params: {param_count:,}")
    print(f"trainable params: {trainable:,}")

    print("\nlayer statistics:")
    for layer in layer_norms:
        print(f"{layer:20} weight_norm: {np.mean(layer_norms[layer]):.3f}  grad_norm: {np.mean(grad_norms[layer]):.3e}")

    print("\nvalidation metrics:")
    for k,v in metrics.items():
        print(f"{k:15} {np.mean(v):.3f} ± {np.std(v):.3f}")

    return {
        'architecture': {
            'total_params': param_count,
            'trainable_params': trainable
        },
        'layer_stats': {
            k: {
                'weight_norm': np.mean(v),
                'grad_norm': np.mean(grad_norms[k])
            } for k,v in layer_norms.items()
        },
        'val_metrics': {
            k: {'mean': np.mean(v), 'std': np.std(v)}
            for k,v in metrics.items()
        }
    }

stats = extract_training_stats(model, val_loader, config)

In [None]:
def plot_learning_stage_confusion(model, val_loader, config):
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

    true_stages = []
    pred_stages = []

    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            x = batch[0].to(config.device)
            task_ids = batch[1].to(config.device)
            targets = batch[2]

            outputs = model(x, task_ids)

            pred = outputs['learning_stage'].cpu().numpy()
            true = targets['learning_stage'].numpy()

            pred_disc = np.digitize(pred, bins=[0.25, 0.5, 0.75])
            true_disc = np.digitize(true, bins=[0.25, 0.5, 0.75])

            true_stages.extend(true_disc)
            pred_stages.extend(pred_disc)

    cm = confusion_matrix(true_stages, pred_stages, normalize='true')

    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='.2f', cmap='RdYlBu_r',
                xticklabels=['naive', 'early', 'intermediate', 'advanced'],
                yticklabels=['naive', 'early', 'intermediate', 'advanced'])
    plt.title('learning stage confusion matrix (normalized)')
    plt.xlabel('predicted stage')
    plt.ylabel('true stage')

    diag = np.diag(cm)
    acc = diag.mean()
    misclass = 1 - acc

    print(f"\nconfusion matrix analysis:")
    print(f"accuracy: {acc:.3f}")
    print(f"misclassification: {misclass:.3f}")
    print(f"per-stage accuracy:", {stage: f"{acc:.3f}" for stage, acc in
                                 zip(['naive', 'early', 'inter', 'adv'], diag)})

    plt.tight_layout()
    plt.show()

    return {
        'confusion_matrix': cm,
        'accuracy': acc,
        'misclassification': misclass,
        'per_stage_accuracy': dict(zip(['naive', 'early', 'inter', 'adv'], diag))
    }

results = plot_learning_stage_confusion(model, val_loader, config)

In [None]:
def plot_comprehensive_analysis(model, val_loader):
    def temporal_analysis():
        x, task_ids, targets = next(iter(val_loader))
        x = x.to(model.device)
        task_ids = task_ids.to(model.device)

        with torch.no_grad():
            preds = model(x, task_ids)

        plt.figure(figsize=(15,10))

        plt.subplot(2,2,1)
        for i in range(min(4, len(targets['temporal_pattern']))):
            true = targets['temporal_pattern'][i].numpy()
            pred = preds['temporal_pattern'][i].cpu().numpy()
            if not np.any(np.isnan(true)) and not np.any(np.isnan(pred)):
                plt.plot(true, f'C{i}-', label=f'true_{i}', alpha=0.7)
                plt.plot(pred, f'C{i}--', label=f'pred_{i}', alpha=0.7)
        plt.title('temporal dynamics')
        plt.legend()

        plt.subplot(2,2,2)
        true_flat = targets['temporal_pattern'].numpy().reshape(-1)
        pred_flat = preds['temporal_pattern'].cpu().numpy().reshape(-1)
        mask = ~np.isnan(true_flat) & ~np.isnan(pred_flat)
        if mask.any():
            plt.scatter(true_flat[mask], pred_flat[mask], alpha=0.5)
            plt.plot([-2,2], [-2,2], 'r--', alpha=0.5)  # standardized range
        plt.xlabel('true')
        plt.ylabel('pred')
        plt.title('temporal correlation')

        plt.subplot(2,2,3)
        errors = pred_flat[mask] - true_flat[mask] if mask.any() else np.array([])
        if len(errors):
            sns.histplot(errors, kde=True)
            plt.axvline(0, color='r', linestyle='--', alpha=0.5)
            plt.title(f'error dist (μ={errors.mean():.3f}, σ={errors.std():.3f})')

        plt.subplot(2,2,4)
        if mask.any():
            f_true, Pxx_true = signal.welch(true_flat[mask], fs=1.0/2.0)
            f_pred, Pxx_pred = signal.welch(pred_flat[mask], fs=1.0/2.0)
            plt.semilogy(f_true, Pxx_true, label='true', alpha=0.7)
            plt.semilogy(f_pred, Pxx_pred, label='pred', alpha=0.7)
            plt.title('power spectra')
            plt.xlabel('freq (hz)')
            plt.legend()

        plt.tight_layout()
        plt.show()

    def attention_analysis():
        x, task_ids, _ = next(iter(val_loader))
        x = x.to(model.device)
        task_ids = task_ids.to(model.device)

        attn_maps = []
        def hook_fn(module, input, output):
            attn = output[1].detach().cpu()[0]
            if not torch.isnan(attn).any():
                attn_maps.append(attn)

        hooks = []
        for block in model.blocks:
            hooks.append(block.attn.register_forward_hook(hook_fn))

        with torch.no_grad():
            model(x, task_ids)
        for h in hooks:
            h.remove()

        if not attn_maps:
            print("no valid attention maps found")
            return

        plt.figure(figsize=(20,5))
        for i, attn in enumerate(attn_maps):
            plt.subplot(1, len(attn_maps), i+1)
            attn_viz = attn.reshape(1281, -1) if len(attn.shape) == 1 else attn
            plt.imshow(attn_viz.numpy(), aspect='auto', cmap='RdBu_r')
            plt.title(f'L{i+1}')
            if i == 0:
                plt.ylabel('query')
            plt.xlabel('key')
        plt.tight_layout()
        plt.show()

        if len(attn_maps) > 1:
            plt.figure(figsize=(15,5))
            plt.subplot(131)
            means = [attn.mean().item() for attn in attn_maps]
            plt.plot(means, 'o-')
            plt.title('mean attention')
            plt.xlabel('layer')

            plt.subplot(132)
            sparsity = [(attn < 0.1).float().mean().item() for attn in attn_maps]
            plt.plot(sparsity, 'o-')
            plt.title('sparsity')
            plt.xlabel('layer')

            plt.subplot(133)
            valid_eigs = []
            for attn in attn_maps:
                try:
                    eigs = torch.linalg.eigvalsh(attn)[-5:]
                    if not torch.isnan(eigs).any():
                        valid_eigs.append(eigs)
                except:
                    continue
            if valid_eigs:
                plt.plot(torch.stack(valid_eigs).cpu().numpy())
                plt.title('top-5 eigenvalues')
                plt.xlabel('layer')

            plt.tight_layout()
            plt.show()

    def network_analysis():
        param_norms = []
        grad_norms = []
        names = []

        for name, p in model.named_parameters():
            if p.requires_grad and not torch.isnan(p).any():
                param_norms.append(p.norm().item())
                if p.grad is not None and not torch.isnan(p.grad).any():
                    grad_norms.append(p.grad.norm().item())
                    names.append(name.split('.')[0])

        if not names:
            print("no valid params found")
            return

        plt.figure(figsize=(10,5))

        plt.subplot(121)
        if len(param_norms) > 1:
            sns.boxplot(x=names, y=param_norms)
            plt.xticks(rotation=45)
            plt.title('param norms')

        plt.subplot(122)
        if len(grad_norms) > 1:
            sns.boxplot(x=names, y=grad_norms)
            plt.xticks(rotation=45)
            plt.title('grad norms')

        plt.tight_layout()
        plt.show()

    print("temporal:")
    temporal_analysis()
    print("\nattention:")
    attention_analysis()
    print("\nnetwork:")
    network_analysis()

plot_comprehensive_analysis(model, val_loader)