In [5]:


from __future__ import annotations
import os, re, json
from glob import glob
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import numpy as np
import pandas as pd
import nibabel as nib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# XAI
from captum.attr import IntegratedGradients, LayerGradCam

# Optional resampling utilities
from nilearn.image import resample_to_img

In [44]:
# --------------------------
# Config
# --------------------------
@dataclass
class Config:
    # Root directory containing all subjects' first‑level outputs
    beta_root: str = "/local/anpa/ds003242-1.0.0/derivatives/firstlevel_separate_runs"

    # Which files to use (glob relative to beta_root). Default: effect‑size maps.
    # Examples: "**/*_effsize.nii.gz" or "**/*_zmap.nii.gz" or "**/*_tstat.nii.gz"
    file_glob: str = "**/*_effsize.nii.gz"

    # Regex for extracting labels from file names
    # This captures tokens like 'Food_1', 'Social_3', 'Control_2' and collapses to base class
    label_regex: str = r"_(Food|Social|Control)_(?:[123])_"

    # Extract run id (first token like '0_' or '1_'). Used for LORO CV or metadata only.
    run_regex: str = r"/(\d+)_"

    # Extract subject id from path (directories named sub-XXXX)
    subj_regex: str = r"/(sub-[^/]+)/"

    # If not None, resample all images to this reference NIfTI (common space).
    resample_ref_nii: Optional[str] = None

    # Mask: if provided, will be resampled to ref and applied; else auto‑mask = nonzero voxels
    mask_path: Optional[str] = None

    # Training
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    epochs: int = 12
    batch_size: int = 4
    lr: float = 1e-3
    weight_decay: float = 1e-4
    num_workers: int = 4
    seed: int = 42

    # CV mode: 'loso' (leave‑one‑subject‑out) or 'loro' (leave‑one‑run‑out within a chosen subject)
    cv_mode: str = 'loso'

    # Output
    out_dir: str = "./beta_cnn_outputs"

cfg = Config()
os.makedirs(cfg.out_dir, exist_ok=True)

In [67]:
# --------------------------
# Index beta maps and build a dataframe
# --------------------------

def index_beta_maps(cfg: Config) -> pd.DataFrame:
    paths = sorted(glob(os.path.join(cfg.beta_root, cfg.file_glob), recursive=True))
    if len(paths) == 0:
        raise FileNotFoundError("No beta maps found. Check cfg.beta_root and cfg.file_glob.")

    rows = []
    for p in paths:
        m_label = re.search(cfg.label_regex, p)
        if not m_label:
            continue
        label = m_label.group(1)  # 'Food'/'Social'/'Control'
        m_run = re.search(cfg.run_regex, p)
        run_id = m_run.group(1) 
        m_sub = re.search(cfg.subj_regex, p)
        subj = m_sub.group(1) 
        condition = 'baseline' if subj.endswith('b') else 'isolation' if subj.endswith('s') else 'fasting'
        rows.append({"path": p, "label": label, "run": run_id, "subject": subj, 'Condition': condition})

    df = pd.DataFrame(rows)
    if df.empty:
        raise RuntimeError("No files matched the label regex; inspect cfg.label_regex and filenames.")
    return df

In [68]:
df = index_beta_maps(cfg)
df

Unnamed: 0,path,label,run,subject,Condition
0,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Control,0,sub-SAXSISO01b,baseline
1,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Control,0,sub-SAXSISO01b,baseline
2,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Control,0,sub-SAXSISO01b,baseline
3,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Food,0,sub-SAXSISO01b,baseline
4,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Food,0,sub-SAXSISO01b,baseline
...,...,...,...,...,...
5152,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Food,5,sub-SAXSISO42s,isolation
5153,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Food,5,sub-SAXSISO42s,isolation
5154,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Social,5,sub-SAXSISO42s,isolation
5155,/local/anpa/ds003242-1.0.0/derivatives/firstle...,Social,5,sub-SAXSISO42s,isolation


In [69]:
df.groupby(["subject", "label"]).size().unstack(fill_value=0)

label,Control,Food,Social
subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
sub-SAXSISO01b,18,18,18
sub-SAXSISO01f,18,18,18
sub-SAXSISO01s,18,18,18
sub-SAXSISO02b,18,18,18
sub-SAXSISO02f,18,18,18
...,...,...,...
sub-SAXSISO41f,18,18,18
sub-SAXSISO41s,18,18,18
sub-SAXSISO42b,18,18,18
sub-SAXSISO42f,18,18,18


In [72]:
# --------------------------
# Mask utilities
# --------------------------

def load_mask(ref_img: nib.Nifti1Image, mask_path: Optional[str]) -> np.ndarray:
    if mask_path is None:
        data = ref_img.get_fdata()
        if data.ndim == 4:
            m = data.mean(axis=-1) != 0
        else:
            m = data != 0
        return m.astype(bool)
    else:
        mask_img = nib.load(mask_path)
        if mask_img.shape != ref_img.shape[:3]:
            mask_img = resample_to_img(mask_img, ref_img, interpolation='nearest')
        return (mask_img.get_fdata() > 0).astype(bool)

In [73]:
# --------------------------
# Dataset: lazy load + per‑sample normalization
# --------------------------
class BetaDataset(Dataset):
    def __init__(self, df: pd.DataFrame, label_to_idx: Dict[str,int], mask: Optional[np.ndarray]=None,
                 resample_ref_img: Optional[nib.Nifti1Image]=None):
        self.df = df.reset_index(drop=True)
        self.label_to_idx = label_to_idx
        self.mask = mask
        self.resample_ref_img = resample_ref_img

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = nib.load(row.path)
        if self.resample_ref_img is not None and img.shape != self.resample_ref_img.shape:
            img = resample_to_img(img, self.resample_ref_img)
        vol = img.get_fdata().astype(np.float32)
        # apply mask (auto or provided)
        if self.mask is None:
            mask = (vol != 0)
        else:
            mask = self.mask
        vol = np.where(mask, vol, 0)
        # per‑sample z‑score within mask
        m = vol[mask]
        mu, sigma = float(m.mean()), float(m.std() + 1e-6)
        vol = (vol - mu) / sigma
        vol = np.expand_dims(vol, 0)  # (1, D, H, W)
        y = self.label_to_idx[row.label]
        meta = {"subject": row.subject, "run": row.run, "path": row.path, "label": row.label}
        return torch.from_numpy(vol).float(), torch.tensor(y).long(), meta

In [74]:
# --------------------------
# 3D‑CNN (tiny)
# --------------------------
class Tiny3DCNN(nn.Module):
    def __init__(self, n_classes: int):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 8, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(8)
        self.conv2 = nn.Conv3d(8, 16, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(16)
        self.conv3 = nn.Conv3d(16, 32, 3, padding=1)
        self.bn3 = nn.BatchNorm3d(32)
        self.pool = nn.MaxPool3d(2)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(32, n_classes)
        )
    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        return self.head(x)


In [None]:


# --------------------------
# 3D‑CNN (deeper, residual)
# --------------------------
class BasicBlock3D(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, dropout_p=0.0):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm3d(out_ch)
        self.relu  = nn.ReLU(inplace=True)
        self.do    = nn.Dropout3d(p=dropout_p) if dropout_p > 0 else nn.Identity()
        self.down  = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_ch)
            )
    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.do(out)
        out = self.bn2(self.conv2(out))
        if self.down is not None:
            identity = self.down(identity)
        out += identity
        return self.relu(out)

class Bigger3DResNet(nn.Module):
    """A compact ResNet‑style 3D CNN: stem → stages [16,32,64,128] → GAP → FC.
    Much larger capacity than Tiny3DCNN but still light enough for beta volumes.
    """
    def __init__(self, n_classes: int, dropout_p: float = 0.2):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        )
        # stages
        self.layer1 = nn.Sequential(
            BasicBlock3D(16, 16, stride=1, dropout_p=dropout_p),
            BasicBlock3D(16, 16, stride=1, dropout_p=dropout_p),
        )
        self.layer2 = nn.Sequential(
            BasicBlock3D(16, 32, stride=2, dropout_p=dropout_p),
            BasicBlock3D(32, 32, stride=1, dropout_p=dropout_p),
        )
        self.layer3 = nn.Sequential(
            BasicBlock3D(32, 64, stride=2, dropout_p=dropout_p),
            BasicBlock3D(64, 64, stride=1, dropout_p=dropout_p),
        )
        self.layer4 = nn.Sequential(
            BasicBlock3D(64, 128, stride=2, dropout_p=dropout_p),
            BasicBlock3D(128, 128, stride=1, dropout_p=dropout_p),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Dropout(p=dropout_p),
            nn.Linear(128, n_classes)
        )
    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.head(x)



In [75]:
# --------------------------
# Train / eval helpers
# --------------------------

def train(model, train_loader, val_loader, device, epochs=12, lr=1e-3, weight_decay=1e-4):
    torch.manual_seed(cfg.seed)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    crit = nn.CrossEntropyLoss()
    best, best_path = -np.inf, os.path.join(cfg.out_dir, 'best_model.pt')

    for ep in range(1, epochs+1):
        model.train(); tot=0; cor=0; loss_sum=0.0
        for xb, yb, _ in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(); logits = model(xb); loss = crit(logits, yb); loss.backward(); opt.step()
            loss_sum += float(loss.item()) * xb.size(0)
            cor += (logits.argmax(1) == yb).sum().item(); tot += xb.size(0)
        tr_acc, tr_loss = cor/max(tot,1), loss_sum/max(tot,1)

        model.eval(); tot=0; cor=0
        with torch.no_grad():
            for xb, yb, _ in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                logits = model(xb)
                cor += (logits.argmax(1) == yb).sum().item(); tot += xb.size(0)
        va_acc = cor/max(tot,1)
        print(f"Epoch {ep:02d} | train loss {tr_loss:.4f} acc {tr_acc:.3f} | val acc {va_acc:.3f}")
        if va_acc > best:
            best = va_acc; torch.save(model.state_dict(), best_path)
    print(f"Best val acc: {best:.3f}; saved {best_path}")
    return best_path


def evaluate(model, loader, device, label_names: List[str]):
    model.eval()
    ys, ps = [], []
    metas = []
    with torch.no_grad():
        for xb, yb, meta in loader:
            xb = xb.to(device)
            logits = model(xb)
            ps.extend(logits.argmax(1).cpu().numpy().tolist())
            ys.extend(yb.numpy().tolist())
            metas.extend(meta)
    cm = confusion_matrix(ys, ps)
    print("Confusion matrix:\n", cm)
    print(classification_report(ys, ps, target_names=label_names, digits=3))
    return cm, ys, ps, metas


In [76]:
# --------------------------
# XAI: Integrated Gradients & Layer Grad‑CAM
# --------------------------

def explain_IG(model, volume_3d: np.ndarray, target_class: int, device: str) -> np.ndarray:
    model.eval()
    tens = torch.from_numpy(volume_3d[None, None].astype(np.float32)).to(device)
    ig = IntegratedGradients(model)
    attr = ig.attribute(tens, target=target_class, baselines=torch.zeros_like(tens), n_steps=64)
    return attr.detach().cpu().numpy()[0,0]


def explain_GradCAM(model, volume_3d: np.ndarray, target_class: int, device: str, layer: nn.Module) -> np.ndarray:
    model.eval()
    tens = torch.from_numpy(volume_3d[None, None].astype(np.float32)).to(device)
    lgc = LayerGradCam(model, layer)
    attr = lgc.attribute(tens, target=target_class)
    up = F.interpolate(attr, size=volume_3d.shape, mode='trilinear', align_corners=False)
    return up.detach().cpu().numpy()[0,0]


def save_as_nii(arr: np.ndarray, like_path: str, out_path: str):
    ref = nib.load(like_path)
    nib.save(nib.Nifti1Image(arr.astype(np.float32), affine=ref.affine, header=ref.header), out_path)

In [77]:
# --------------------------
# Build splits (LOSO or LORO)
# --------------------------

def make_splits(df: pd.DataFrame, mode: str='loso', seed: int=42):
    if mode == 'loso':
        subjects = sorted(df.subject.unique())
        # pick the last subject as test (example); loop over subjects for full CV
        test_subj = subjects[-1]
        tr = df[df.subject != test_subj].reset_index(drop=True)
        te = df[df.subject == test_subj].reset_index(drop=True)
        # split train into train/val stratified by label
        tr_idx = np.arange(len(tr))
        tr_idx, va_idx = train_test_split(tr_idx, test_size=0.2, stratify=tr.label, random_state=seed)
        return tr.iloc[tr_idx], tr.iloc[va_idx], te, {"test_subject": test_subj}
    elif mode == 'loro':
        # assume single subject; hold out one run id
        runs = sorted(df.run.unique())
        test_run = runs[-1]
        tr = df[df.run != test_run].reset_index(drop=True)
        te = df[df.run == test_run].reset_index(drop=True)
        tr_idx = np.arange(len(tr))
        tr_idx, va_idx = train_test_split(tr_idx, test_size=0.2, stratify=tr.label, random_state=seed)
        return tr.iloc[tr_idx], tr.iloc[va_idx], te, {"test_run": test_run}
    else:
        raise ValueError("mode must be 'loso' or 'loro'")

In [78]:
# --------------------------
# MAIN
# --------------------------
if __name__ == "__main__":
    np.random.seed(cfg.seed); torch.manual_seed(cfg.seed)

    df = index_beta_maps(cfg)

    # Create label mapping
    labels = sorted(df.label.unique().tolist())
    label_to_idx = {c:i for i,c in enumerate(labels)}
    idx_to_label = {i:c for c,i in label_to_idx.items()}
    with open(os.path.join(cfg.out_dir, "label_mapping.json"), "w") as f:
        json.dump(idx_to_label, f, indent=2)

    # Choose reference image for resampling/masking
    ref_path = df.path.iloc[0]
    ref_img = nib.load(ref_path)
    mask = load_mask(ref_img, cfg.mask_path)
    resample_ref_img = None
    if cfg.resample_ref_nii is not None:
        resample_ref_img = nib.load(cfg.resample_ref_nii)
        mask = load_mask(resample_ref_img, cfg.mask_path)

    # Splits
    train_df, val_df, test_df, split_meta = make_splits(df, mode=cfg.cv_mode, seed=cfg.seed)
    print("Split info:", split_meta)
    print("Train/Val/Test sizes:", len(train_df), len(val_df), len(test_df))

    # Datasets
    ds_tr = BetaDataset(train_df, label_to_idx, mask=mask, resample_ref_img=resample_ref_img)
    ds_va = BetaDataset(val_df,   label_to_idx, mask=mask, resample_ref_img=resample_ref_img)
    ds_te = BetaDataset(test_df,  label_to_idx, mask=mask, resample_ref_img=resample_ref_img)

    # Loaders
    dl_tr = DataLoader(ds_tr, batch_size=cfg.batch_size, shuffle=True,  num_workers=cfg.num_workers)
    dl_va = DataLoader(ds_va, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
    dl_te = DataLoader(ds_te, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

    # Model
    model = Tiny3DCNN(n_classes=len(labels)).to(cfg.device)

    # Train
    best_path = train(model, dl_tr, dl_va, cfg.device, epochs=cfg.epochs, lr=cfg.lr, weight_decay=cfg.weight_decay)
    model.load_state_dict(torch.load(best_path, map_location=cfg.device))

    # Evaluate
    cm, ys, ps, metas = evaluate(model, dl_te, cfg.device, labels)

    # --------------
    # XAI on a few test samples (one per class if available)
    # --------------
    os.makedirs(cfg.out_dir, exist_ok=True)
    # Collect one example path per predicted class
    examples: Dict[str, Dict] = {}
    for y, p, meta in zip(ys, ps, metas):
        cls = labels[p]
        if cls not in examples:
            examples[cls] = {"meta": meta, "y": y, "p": p}
        if len(examples) == len(labels):
            break

    # Generate and save attributions
    for cls, ex in examples.items():
        # Load the original volume as the model saw it (with mask + zscore)
        img = nib.load(ex["meta"]["path"])
        if resample_ref_img is not None and img.shape != resample_ref_img.shape:
            img = resample_to_img(img, resample_ref_img)
        vol = img.get_fdata().astype(np.float32)
        vol = np.where(mask, vol, 0)
        m = vol[mask]; vol = (vol - float(m.mean()))/(float(m.std())+1e-6)

        target_idx = ex["p"]  # explain the predicted class; change to ex["y"] to explain true class
        ig_map = explain_IG(model, vol, target_idx, cfg.device)
        gc_map = explain_GradCAM(model, vol, target_idx, cfg.device, layer=model.conv3)

        base = os.path.splitext(os.path.basename(ex["meta"]["path"]))[0]
        ig_path = os.path.join(cfg.out_dir, f"xai_IG_{cls}_{base}.nii.gz")
        gc_path = os.path.join(cfg.out_dir, f"xai_GradCAM_{cls}_{base}.nii.gz")
        save_as_nii(ig_map, ex["meta"]["path"], ig_path)
        save_as_nii(gc_map, ex["meta"]["path"], gc_path)
        print(f"Saved XAI maps for class {cls}:\n  {ig_path}\n  {gc_path}")

    # Save split metadata for provenance
    with open(os.path.join(cfg.out_dir, "split_meta.json"), "w") as f:
        json.dump(split_meta, f, indent=2)

    print("Done.")

Split info: {'test_subject': 'sub-SAXSISO42s'}
Train/Val/Test sizes: 4082 1021 54
Epoch 01 | train loss 1.1044 acc 0.338 | val acc 0.344
Epoch 02 | train loss 1.1023 acc 0.339 | val acc 0.332
Epoch 03 | train loss 1.0977 acc 0.361 | val acc 0.356
Epoch 04 | train loss 1.0959 acc 0.353 | val acc 0.338
Epoch 05 | train loss 1.0967 acc 0.368 | val acc 0.324
Epoch 06 | train loss 1.0958 acc 0.351 | val acc 0.331
Epoch 07 | train loss 1.0962 acc 0.357 | val acc 0.350
Epoch 08 | train loss 1.0934 acc 0.371 | val acc 0.341
Epoch 09 | train loss 1.0934 acc 0.368 | val acc 0.341
Epoch 10 | train loss 1.0940 acc 0.361 | val acc 0.346
Epoch 11 | train loss 1.0922 acc 0.370 | val acc 0.343
Epoch 12 | train loss 1.0929 acc 0.373 | val acc 0.339
Best val acc: 0.356; saved ./beta_cnn_outputs/best_model.pt
Confusion matrix:
 [[ 9  0  9]
 [10  0  8]
 [ 7  0 11]]
              precision    recall  f1-score   support

     Control      0.346     0.500     0.409        18
        Food      0.000     0.000

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TypeError: string indices must be integers

In [81]:
examples

{'Control': {'meta': 'subject', 'y': 0, 'p': 0},
 'Social': {'meta': 'path', 'y': 0, 'p': 2}}