In [1]:
# =========================
# Cell 0: Install dependencies (run once)
# =========================
!pip install --quiet nibabel tqdm scikit-image opencv-python matplotlib scikit-learn torch torchvision snntorch timm

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m122.9/125.6 kB[0m [31m4.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
# =========================
# Cell 1: Mount Google Drive (if needed)
# =========================
from google.colab import drive
drive.mount('/content/drive')   # follow prompts

# Example structure expected (recommended):
# /content/drive/MyDrive/MM_NII/
#   ├── subj_0001/
#   │     ├── T1.nii.gz
#   │     ├── T2.nii.gz
#   │     └── FLAIR.nii.gz
#   ├── subj_0002/
#   │     ├── T1.nii.gz
#   │     ├── T2.nii.gz
#   │     └── FLAIR.nii.gz
#   └── ...

# If your data is flat (all files in one folder), see the NOTE inside Cell 3 to adapt matching by filename substrings.


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# =========================
# Cell 2: Imports + device + global settings
# =========================
import os, glob, math, random, time, re
from tqdm import tqdm
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import cv2
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from sklearn.metrics.pairwise import cosine_similarity

# Repro
random.seed(42); np.random.seed(42); torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# ============ EDIT ME ============
DATA_DIR = '/content/drive/MyDrive/MRI Extracted'   # root folder containing subject subfolders
MODALITIES = ['T1', 'T2', 'FLAIR']           # choose the modalities you want to fuse as channels
BASE_MODALITY_FOR_VIS = 'T1'                 # which modality to use for overlays/visuals
# =================================

# Model/runtime sizes (shrink if OOM)
IMG_SIZE = 224
PATCH_SIZE = 16
EMBED_DIM = 128      # try 64 to save RAM
NUM_LAYERS = 4
NUM_HEADS = 8
BATCH_SIZE = 4       # reduce if OOM
TTFS_T = 30
SNN_OUT_NEURONS = 128

assert IMG_SIZE % PATCH_SIZE == 0, "IMG_SIZE must be divisible by PATCH_SIZE"


Device: cpu


In [5]:
# =========================
# Cell 3: Index subjects and their modality files
# =========================
def find_subjects_with_modalities(root, modalities):
    """
    Expected: root/subj_xxx/<files containing modality names>
    We match files by the presence of modality keyword (case-insensitive) in filename.
    """
    subjects = []
    subdirs = sorted([d for d in glob.glob(os.path.join(root, '*')) if os.path.isdir(d)])
    for sdir in subdirs:
        sid = os.path.basename(sdir)
        mod2file = {}
        for m in modalities:
            # match file that contains the modality token (case-insensitive)
            cand = sorted(glob.glob(os.path.join(sdir, f'*{m}*.nii*')))  # flexible: .nii or .nii.gz
            if len(cand) == 0:
                mod2file[m] = None
            else:
                mod2file[m] = cand[0]
        # keep subject only if all required modalities exist
        if all(mod2file[m] is not None for m in modalities):
            subjects.append((sid, mod2file))
    return subjects

subjects = find_subjects_with_modalities(DATA_DIR, MODALITIES)
print(f"Found {len(subjects)} subjects with modalities {MODALITIES}")

# NOTE (flat folder case):
# If your data is all in one folder, create a CSV mapping or use filename rules to group by subject_id.
# Example heuristic: subject id is the prefix before first '_' and modality appears in name.
# Then, build {subject_id: {mod: path}} and keep only subjects containing all required modalities.


Found 0 subjects with modalities ['T1', 'T2', 'FLAIR']


In [6]:
# =========================
# Cell 4: Utilities — load center slice aligned across modalities + preprocessing
# =========================
def robust_minmax(x, lo=1, hi=99):
    """Percentile min-max -> [0,1]"""
    x = np.nan_to_num(x)
    lo_v, hi_v = np.percentile(x, [lo, hi])
    if hi_v - lo_v < 1e-8:
        return np.zeros_like(x, dtype=np.float32)
    x = np.clip(x, lo_v, hi_v)
    x = (x - lo_v) / (hi_v - lo_v + 1e-8)
    return x.astype(np.float32)

def load_center_slice_volume(path):
    vol = nib.load(path).get_fdata()
    z = vol.shape[2] // 2
    sl = vol[:, :, z]
    return sl  # 2D slice

def resize_2d(img2d, size=IMG_SIZE):
    return cv2.resize(img2d, (size, size), interpolation=cv2.INTER_CUBIC).astype(np.float32)

def build_multimodal_center_slice(mod2file, modalities, ref_modal):
    """
    Returns a stacked array shape (C, H, W) with C=len(modalities).
    Uses the z-center slice index from the reference modality for alignment.
    """
    # reference z from ref modality
    ref_vol = nib.load(mod2file[ref_modal]).get_fdata()
    z_ref = ref_vol.shape[2] // 2

    chans = []
    for m in modalities:
        vol = nib.load(mod2file[m]).get_fdata()
        z = min(z_ref, vol.shape[2]-1)  # clamp if depths differ
        sl = vol[:, :, z]
        sl = robust_minmax(sl)          # [0,1]
        sl = resize_2d(sl, IMG_SIZE)    # HxW
        chans.append(sl)
    mm = np.stack(chans, axis=0)        # CxHxW
    return mm

# Quick sanity visualization on the first subject
if len(subjects) > 0:
    sid0, mfiles0 = subjects[0]
    mm0 = build_multimodal_center_slice(mfiles0, MODALITIES, BASE_MODALITY_FOR_VIS)
    print("Sample multi-modal tensor shape (C,H,W):", mm0.shape)
    # show each modality
    n = mm0.shape[0]
    plt.figure(figsize=(4*n, 4))
    for i in range(n):
        plt.subplot(1, n, i+1)
        plt.imshow(mm0[i], cmap='gray')
        plt.title(f'{sid0} - {MODALITIES[i]} (center slice)'); plt.axis('off')
    plt.show()


In [7]:
# =========================
# Cell 5: Dataset for multimodal center slices
# =========================
class MultiModalCenterSliceDataset(Dataset):
    def __init__(self, subjects, modalities, ref_modal, transform=None):
        self.subjects = subjects
        self.modalities = modalities
        self.ref_modal = ref_modal
        self.transform = transform  # not used (we pre-resize), but kept for API symmetry
    def __len__(self):
        return len(self.subjects)
    def __getitem__(self, idx):
        sid, mod2file = self.subjects[idx]
        mm = build_multimodal_center_slice(mod2file, self.modalities, self.ref_modal)  # CxHxW in [0,1]
        # Normalize to [-1,1] per channel to match earlier code
        mm = (mm * 2.0 - 1.0).astype(np.float32)
        x = torch.from_numpy(mm)  # CxHxW
        return {'img': x, 'sid': sid, 'mod2file': mod2file}

dataset = MultiModalCenterSliceDataset(subjects, MODALITIES, BASE_MODALITY_FOR_VIS)
print("Dataset size (subjects):", len(dataset))

# Visualize one subject stacked as RGB-like (if exactly 3 modalities)
if len(dataset) > 0:
    sample = dataset[0]
    x = sample['img'].numpy()  # CxHxW in [-1,1]
    # show channels individually (de-normalize to [0,1] for display)
    n = x.shape[0]
    plt.figure(figsize=(4*n, 4))
    for i in range(n):
        plt.subplot(1, n, i+1)
        plt.imshow((x[i]*0.5+0.5), cmap='gray'); plt.axis('off'); plt.title(MODALITIES[i])
    plt.suptitle(f"Subject: {sample['sid']}  (multi-modal center slices)")
    plt.show()


Dataset size (subjects): 0


In [10]:
import os

root = "/content/drive/MyDrive/MRI Extracted"   # change to your actual path
all_files = []
for path, dirs, files in os.walk(root):
    for f in files:
        if f.endswith(".nii") or f.endswith(".nii.gz"):
            all_files.append(os.path.join(path, f))

print("Found files:", len(all_files))
print(all_files[:10])

Found files: 180
['/content/drive/MyDrive/MRI Extracted/patient6mri.nii', '/content/drive/MyDrive/MRI Extracted/patient8mri.nii', '/content/drive/MyDrive/MRI Extracted/patient2mri.nii', '/content/drive/MyDrive/MRI Extracted/patient4mri.nii', '/content/drive/MyDrive/MRI Extracted/patient5mri.nii', '/content/drive/MyDrive/MRI Extracted/patient1mri.nii', '/content/drive/MyDrive/MRI Extracted/patient7mri.nii', '/content/drive/MyDrive/MRI Extracted/patient3mri.nii', '/content/drive/MyDrive/MRI Extracted/patient26mri.nii', '/content/drive/MyDrive/MRI Extracted/patient12mri.nii']


In [15]:
class MultiModalDataset(torch.utils.data.Dataset):
    def __init__(self, subjects, modalities, ref_modal):
        self.subjects = subjects
        self.modalities = modalities
        self.ref_modal = ref_modal

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, idx):
        if len(self.subjects) == 0:
            raise ValueError("❌ No subjects found! Check dataset path and preprocessing.")
        sid, mod2file = self.subjects[idx]
        mm = build_multimodal_center_slice(mod2file, self.modalities, self.ref_modal)
        mm = (mm - 0.5) / 0.5
        return {'id': sid, 'img': mm}

In [16]:
# =========================
# Cell 6: Conv Tokenizer (multi-channel) + patch grid visualization
# =========================
class ConvTokenizer(nn.Module):
    def __init__(self, in_ch, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, img_size=IMG_SIZE):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)
        n_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    def forward(self, x):
        x = self.proj(x)                   # B x E x H' x W'
        B, C, H, W = x.shape
        x_flat = x.flatten(2).transpose(1,2)   # B x N x E
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, x_flat], dim=1)
        tokens = tokens + self.pos_embed[:, :tokens.size(1), :]
        return tokens, x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM, num_heads=NUM_HEADS, mlp_dim=EMBED_DIM*2, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim), nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(embed_dim)
    def forward(self, x):
        y, attn = self.attn(x, x, x, need_weights=True, average_attn_weights=False)
        x = self.norm1(x + y)
        x = self.norm2(x + self.mlp(x))
        return x, attn

class SmallViT(nn.Module):
    def __init__(self, in_ch, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, img_size=IMG_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS):
        super().__init__()
        self.tokenizer = ConvTokenizer(in_ch, embed_dim, patch_size, img_size)
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)])
    def forward(self, x, return_attn=False):
        tokens, conv_feat = self.tokenizer(x)
        attn_maps = []
        for blk in self.blocks:
            tokens, attn = blk(tokens)
            attn_maps.append(attn)
        pooled = tokens[:, 1:, :].mean(1)
        if return_attn:
            return pooled, tokens, attn_maps, conv_feat
        return pooled, tokens, None, conv_feat

model = SmallViT(in_ch=len(MODALITIES)).to(device)
model.eval()

# one forward + visuals
item0 = dataset[0]
img0 = item0['img'].unsqueeze(0).to(device)   # 1 x C x H x W
with torch.no_grad():
    pooled0, tokens0, attn_maps0, conv_feat0 = model(img0, return_attn=True)

print('tokens shape:', tokens0.shape, 'conv_feat shape:', conv_feat0.shape, 'pooled shape:', pooled0.shape)

# visualize conv feature map (avg channels)
feat_map = conv_feat0[0].mean(0).detach().cpu().numpy()
plt.figure(figsize=(4,4)); plt.imshow(feat_map, cmap='inferno'); plt.title('Conv feature map (avg over embed channels)'); plt.axis('off'); plt.show()

# overlay patch grid on BASE modality
base_idx = MODALITIES.index(BASE_MODALITY_FOR_VIS)
base_img_vis = (item0['img'][base_idx].numpy()*0.5+0.5)
plt.figure(figsize=(4,4)); plt.imshow(base_img_vis, cmap='gray')
p = IMG_SIZE // PATCH_SIZE
for i in range(1, p):
    plt.axvline(i*PATCH_SIZE, color='cyan', linewidth=0.6)
    plt.axhline(i*PATCH_SIZE, color='cyan', linewidth=0.6)
plt.title(f'Patch grid ({BASE_MODALITY_FOR_VIS})'); plt.axis('off'); plt.show()


IndexError: list index out of range

In [None]:
# =========================
# Cell 7: Attention visualization + rollout overlay (on base modality)
# =========================
def attention_rollout(attn_maps):
    """
    attn_maps: list of (B, heads, S, S)
    returns: (S, S) rollout for batch 0
    """
    result = None
    for att in attn_maps:
        a = att.mean(dim=1).detach().cpu().numpy()  # B x S x S
        a = a + np.eye(a.shape[-1])[None]
        a = a / (a.sum(axis=-1, keepdims=True) + 1e-8)
        result = a if result is None else np.matmul(result, a)
    return result[0]  # SxS

roll0 = attention_rollout(attn_maps0)
cls_to_patch = roll0[0, 1:]
patch_count = cls_to_patch.shape[0]
pgrid = int(math.sqrt(patch_count))
if pgrid*pgrid == patch_count:
    heat = cls_to_patch.reshape(pgrid, pgrid)
    heat_resized = cv2.resize(heat, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)

    plt.figure(figsize=(6,6))
    plt.imshow(base_img_vis, cmap='gray')
    plt.imshow(heat_resized, cmap='jet', alpha=0.45)
    plt.axis('off'); plt.title('Attention rollout overlay (base modality)')
    plt.show()
else:
    print("Warning: non-square patch grid; adjust IMG_SIZE/PATCH_SIZE.")


In [None]:
# =========================
# Cell 8: TTFS encoding (pooled embedding -> spike raster) + visualization
# =========================
def ttfs_encode_vector(vec, T=TTFS_T):
    vec = vec.astype(np.float32)
    mn, mx = vec.min(), vec.max()
    norm = np.zeros_like(vec) if (mx-mn)<1e-8 else (vec-mn)/(mx-mn)
    spike_times = np.round((1.0 - norm) * (T-1)).astype(int)
    spike_train = np.zeros((T, vec.shape[0]), dtype=np.uint8)
    for i, t in enumerate(spike_times):
        spike_train[int(t), i] = 1
    return spike_train, spike_times

pooled_np0 = pooled0[0].cpu().numpy()
spikes0, stimes0 = ttfs_encode_vector(pooled_np0)

plt.figure(figsize=(8,5))
plt.imshow(spikes0.T, aspect='auto', cmap='gray_r')
plt.xlabel('Time'); plt.ylabel('Neuron index'); plt.title('TTFS raster (pooled embedding)')
plt.show()
