# ***Cell 1 — Imports and setup***

In [1]:
# Cell 1: Mount Drive, imports, paths
# !rm -rf /content/drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import nibabel as nib
import zipfile
import tempfile
import timm
import torchvision.transforms as T

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Inputs (metadata CSV + zipped NIfTI volumes) and the output folder, all on Google Drive.
CSV_PATH = "/content/drive/MyDrive/merged_ADNI_dataset.csv"
ZIP_PATH = "/content/drive/MyDrive/ADNI_clean_data.zip"
OUT_DIR = "/content/drive/MyDrive/ADNI_project_outputs"
# Create the output folder if needed; an already existing folder is not an error.
os.makedirs(OUT_DIR, exist_ok=True)


MessageError: Error: credential propagation was unsuccessful

# ***Cell 2 — Load CSV and inspect***

In [None]:
# Cell 2: Load the merged metadata CSV and print a quick summary.
# Fail fast with an explicit message (e.g. when Drive is not mounted).
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(f"CSV not found at {CSV_PATH}. Please check the path!")

df = pd.read_csv(CSV_PATH)
print("CSV rows:", len(df.index))
print("Columns:", list(df.columns))
print("Group counts:\n", df["Group"].value_counts())

# Last expression of the cell: Jupyter renders it as a preview table.
df.head()


# ***Cell 3 — List NIfTI files in ZIP***

In [None]:
# Cell 3: List NIfTI files in ZIP
# Skip macOS resource-fork copies stored under __MACOSX/ (the same filter the
# later cells use). They carry the .nii/.nii.gz suffix but hold macOS
# metadata, not image data, and would otherwise inflate the count.
with zipfile.ZipFile(ZIP_PATH, 'r') as zf:
    nii_files = [n for n in zf.namelist()
                 if n.lower().endswith(('.nii', '.nii.gz')) and '__macosx' not in n.lower()]

print("Total NIfTI files in ZIP:", len(nii_files))
print("First 10 files:\n", nii_files[:10])


# ***Cell 4 — Read NIfTI from ZIP (via a temporary file, no full extraction)***

In [None]:
# Cell 4: Robust NIfTI reader
import tempfile
import zipfile
import nibabel as nib
import os

# Keep only NIfTI members (.nii / .nii.gz), skipping anything under __MACOSX/.
with zipfile.ZipFile(ZIP_PATH, 'r') as zf:
    nii_files = [
        member
        for member in zf.namelist()
        if member.lower().endswith(('.nii', '.nii.gz'))
        and '__macosx' not in member.lower()
    ]

print("Total valid NIfTI files in ZIP:", len(nii_files))
print("First 10 files:\n", nii_files[:10])

def read_nifti_from_zip(zip_path, nii_filename):
    """
    Read a NIfTI volume (.nii or .nii.gz) stored inside a ZIP archive.

    nibabel picks its loader from the file extension, so the member is
    streamed into a temporary file that keeps the original suffix
    ('.nii.gz' for gzip members, '.nii' otherwise) and loaded from there.

    Args:
        zip_path: path to the ZIP archive.
        nii_filename: member name inside the archive.

    Returns:
        The voxel array returned by ``nib.load(...).get_fdata()``.

    Raises:
        KeyError: if ``nii_filename`` is not a member of the archive.
    """
    import shutil

    # Case-insensitive, so '.NII.GZ' members are also treated as gzip.
    suffix = '.nii.gz' if nii_filename.lower().endswith('.gz') else '.nii'

    with zipfile.ZipFile(zip_path, 'r') as zf, \
            zf.open(nii_filename) as src, \
            tempfile.NamedTemporaryFile(suffix=suffix) as tmp:
        # Copy in chunks instead of src.read(): avoids holding a second full
        # in-memory copy of a (possibly large) volume.
        shutil.copyfileobj(src, tmp)
        tmp.flush()
        vol = nib.load(tmp.name).get_fdata()
    return vol

# Smoke test: decode the first listed volume and report its shape.
# Any failure (empty list, unreadable file, ...) is reported, not raised.
try:
    vol = read_nifti_from_zip(ZIP_PATH, nii_files[0])
except Exception as e:
    print("Failed to read first NIfTI:", e)
else:
    print("First volume shape:", vol.shape)


# ***Cell 5 — Convert 3D volume to 2D slices***

In [None]:
# Cell 5: Convert 3D volume to slices
def get_slices_from_volume(vol, n_slices=16, out_size=224, axis=0):
    """
    Sample evenly spaced 2D slices from a 3D volume as z-scored 3-channel images.

    Args:
        vol: 3D array. Slices are taken along ``axis`` (default 0, i.e. a
            (Z, H, W) layout).
        n_slices: number of slices to sample. Indices come from np.linspace,
            so slices repeat when n_slices exceeds the axis length.
        out_size: each slice is resized to (out_size, out_size).
        axis: axis to slice along (SliceDataset, for comparison, slices
            along axis 2).

    Returns:
        float32 array of shape (n_slices, 3, out_size, out_size). Each slice
        is normalised to zero mean / unit std; a constant slice is only
        mean-centred.

    Raises:
        ValueError: if ``vol`` is not 3D, ``n_slices`` < 1 or the slicing
            axis is empty.
    """
    import torch.nn.functional as F

    vol = np.asarray(vol)
    if vol.ndim != 3:
        raise ValueError(f"Expected a 3D volume, got shape {vol.shape}")
    if n_slices < 1:
        raise ValueError("n_slices must be >= 1")

    vol = np.moveaxis(vol, axis, 0)
    z = vol.shape[0]
    if z == 0:
        raise ValueError("Cannot sample slices from an empty axis")
    idxs = np.linspace(0, z - 1, n_slices).astype(int)

    # The previous version called `resize` (skimage), which is never imported
    # in this notebook, so it raised NameError. Resize all sampled slices in
    # one batched torch call instead. Bilinear + antialias approximates an
    # anti-aliased linear resize and keeps the intensity scale
    # (like preserve_range=True).
    stack = torch.from_numpy(np.ascontiguousarray(vol[idxs], dtype=np.float32)).unsqueeze(1)
    resized = F.interpolate(stack, size=(out_size, out_size), mode='bilinear',
                            align_corners=False, antialias=True)
    sl = resized.squeeze(1).numpy().astype(np.float64)  # (n_slices, out_size, out_size)

    # Per-slice z-score; a zero std is replaced by 1 to avoid division by zero.
    mean = sl.mean(axis=(1, 2), keepdims=True)
    std = sl.std(axis=(1, 2), keepdims=True)
    std[std == 0] = 1.0
    sl = (sl - mean) / std

    # Repeat the grey channel 3x for ImageNet-style backbones.
    return np.repeat(sl[:, None], 3, axis=1).astype(np.float32)  # (n_slices, 3, H, W)


# ***Cell 6 — Dataset class (reads slices + tabular features)***

In [None]:
# # Cell 6: Slice-based Dataset
# # Cell 6: safer SliceDataset mapping by CSV filename
import torch
from torch.utils.data import Dataset
import numpy as np
import nibabel as nib
import zipfile
import tempfile
import cv2

class SliceDataset(Dataset):
    """
    Slice-based MRI dataset that pairs sampled 2D slices with tabular features.

    Rows whose 'Group' is not CN / MCI / AD are dropped on construction.
    Each item works as follows:
      - reads the NIfTI member named in ``filename_col`` from ``zip_path``;
      - min-max scales the volume to [0, 1];
      - samples ``n_slices`` evenly spaced slices along axis 2;
      - resizes each slice to 224x224 and repeats it into 3 channels.

    Items are dicts with:
        'slices': float32 tensor (n_slices, 3, 224, 224)
        'tab':    float32 tensor (len(tabular_cols),)
        'label':  int64 scalar tensor (CN=0, MCI=1, AD=2)
    """

    def __init__(self, df, zip_path, n_slices=16, tabular_cols=('Age', 'Sex', 'MMSCORE'), filename_col='nii_filename'):
        self.df = df.reset_index(drop=True)
        self.zip_path = zip_path
        self.n_slices = n_slices
        # Copied into a fresh list: no shared mutable default, no aliasing of the caller's list.
        self.tabular_cols = list(tabular_cols)
        self.filename_col = filename_col

        # Map groups to integer labels; rows with any other group are dropped.
        self.label_map = {'CN': 0, 'MCI': 1, 'AD': 2}
        self.df = self.df[self.df['Group'].isin(self.label_map.keys())].reset_index(drop=True)

        # Coerce Age / MMSCORE to numbers and fill gaps with the median of the
        # *coerced* column. On the raw column, a single non-numeric entry
        # (e.g. 'unknown') makes .median() raise.
        # NOTE(review): medians are computed per dataset instance, so a
        # validation split is imputed with its own statistics, not train's.
        for col in ('Age', 'MMSCORE'):
            numeric = pd.to_numeric(self.df[col], errors='coerce')
            self.df[col] = numeric.fillna(numeric.median())

        # Encode sex as F=0 / M=1, tolerating case, surrounding whitespace and
        # spelled-out values. Anything unrecognised (including missing) becomes 0.
        sex = self.df['Sex'].astype(str).str.strip().str.upper()
        self.df['Sex'] = sex.map({'F': 0, 'M': 1, 'FEMALE': 0, 'MALE': 1}).fillna(0)

        from collections import Counter
        print("Label counts in dataset:", Counter(self.df['Group']))

    def __len__(self):
        return len(self.df)

    def read_nifti(self, nii_filename):
        """Read a NIfTI member of ``self.zip_path`` via a suffix-preserving temp file.

        Returns the array from ``nib.load(...).get_fdata()``; raises KeyError
        if the member does not exist.
        """
        import shutil

        # Case-insensitive suffix check, matching read_nifti_from_zip.
        suffix = '.nii.gz' if nii_filename.lower().endswith('.gz') else '.nii'
        with zipfile.ZipFile(self.zip_path, 'r') as zf, \
                zf.open(nii_filename) as src, \
                tempfile.NamedTemporaryFile(suffix=suffix) as tmp:
            # Stream in chunks instead of loading the whole member into memory first.
            shutil.copyfileobj(src, tmp)
            tmp.flush()
            img = nib.load(tmp.name).get_fdata()
        return img

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        nii_path = row[self.filename_col]
        vol = self.read_nifti(nii_path)

        # Drop trailing singleton dims (e.g. an (H, W, D, 1) volume) so that
        # indexing along axis 2 yields 2D slices; fail clearly otherwise.
        while vol.ndim > 3 and vol.shape[-1] == 1:
            vol = vol[..., 0]
        if vol.ndim != 3:
            raise ValueError(f"Expected a 3D volume in {nii_path!r}, got shape {vol.shape}")

        # Normalize to 0-1 (the epsilon guards constant volumes).
        vol = (vol - np.min(vol)) / (np.ptp(vol) + 1e-8)

        # Take n_slices evenly spaced indices across axis 2.
        z_slices = np.linspace(0, vol.shape[2] - 1, self.n_slices, dtype=int)
        # cv2.resize takes (width, height); the target here is square.
        slices = np.stack([cv2.resize(vol[:, :, z], (224, 224)) for z in z_slices], axis=0)

        # (n_slices, 224, 224) -> (n_slices, 3, 224, 224) by repeating the grey channel.
        slices = torch.from_numpy(slices.astype(np.float32)).unsqueeze(1).repeat(1, 3, 1, 1)

        # Tabular data.
        # NOTE(review): features are fed unscaled; consider standardising them
        # with training-fold statistics. Also confirm MMSCORE is not a near-proxy
        # for the diagnostic label (possible target leakage).
        tab = torch.tensor([row[c] for c in self.tabular_cols], dtype=torch.float32)

        label = torch.tensor(self.label_map[row['Group']], dtype=torch.long)
        return {'slices': slices, 'tab': tab, 'label': label}



# ***Cell 7 — Model Definition (Slice CNN + Tabular Fusion)***

In [None]:
# Cell 7: Slice-based EfficientNet + Tabular Fusion
class SliceEfficientNetFusion(nn.Module):
    """
    Per-slice 2D CNN encoder fused with a tabular MLP.

    Architecture:
      - Every slice goes through a shared timm backbone (classification head
        removed, average-pooled features).
      - Slice features are averaged per sample and projected to ``fuse_dim``.
      - The image embedding is concatenated with a ``fuse_dim // 2`` tabular
        embedding.
      - The result passes through a fusion MLP and a linear classifier with
        ``n_classes`` outputs.
    """

    def __init__(self, backbone_name='efficientnet_b0', pretrained=True, n_slices=16, tabular_dim=3, fuse_dim=256, n_classes=3):
        super().__init__()
        # Kept for reference / backward compatibility; forward() uses the
        # actual slice count of its input instead.
        self.n_slices = n_slices
        # Pretrained 2D CNN backbone (num_classes=0 -> returns a feature vector)
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        feat_dim = self.backbone.num_features
        self.img_fc = nn.Linear(feat_dim, fuse_dim)

        # Tabular features MLP
        self.tab_fc = nn.Sequential(nn.Linear(tabular_dim, fuse_dim // 2), nn.ReLU())

        # Fusion MLP over [image embedding, tabular embedding]
        self.fuse = nn.Sequential(
            nn.Linear(fuse_dim + fuse_dim // 2, fuse_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Final classifier (raw logits)
        self.classifier = nn.Linear(fuse_dim, n_classes)

    def forward(self, slices, tab):
        """
        Args:
            slices: (B, S, C, H, W) tensor; S need not equal ``self.n_slices``.
            tab: (B, tabular_dim) tabular features.

        Returns:
            (B, n_classes) logits.
        """
        B, S = slices.shape[:2]
        # reshape (not view): works for non-contiguous inputs, and uses the
        # real slice count S rather than the constructor's n_slices.
        x = slices.reshape(B * S, *slices.shape[2:])
        feats = self.backbone(x)                         # (B*S, feat_dim)
        feats = feats.reshape(B, S, -1).mean(dim=1)      # average across slices
        img_emb = self.img_fc(feats)
        tab_emb = self.tab_fc(tab)
        f = torch.cat([img_emb, tab_emb], dim=1)
        f = self.fuse(f)
        return self.classifier(f)

# Build one model up front to sanity-check construction and show the architecture.
tabular_cols = ['Age', 'Sex', 'MMSCORE']  # adjust based on your CSV
model = SliceEfficientNetFusion(
    tabular_dim=len(tabular_cols),
    n_slices=16,
    n_classes=3,
).to(device)
print(model)


In [None]:
# --- Build filename mapping (added later): CSV 'Image Data ID' -> NIfTI path inside the ZIP ---

import re
from collections import Counter

with zipfile.ZipFile(ZIP_PATH, 'r') as zf:
    nii_files = [
        n for n in zf.namelist()
        if n.lower().endswith(('.nii', '.nii.gz')) and '__macosx' not in n.lower()
    ]

print(f"Clean NIfTI count: {len(nii_files)}")

# Extract the image ID (e.g. I73937) from each path. The look-behind rejects
# an 'I' glued to a preceding letter/digit, so a token such as "ADNI2" does
# not produce a bogus ID like "I2".
pattern = re.compile(r'(?<![A-Za-z0-9])I\d+')
id_to_file = {}
id_counts = Counter()

for f in nii_files:
    match = pattern.search(f)
    if match:
        id_counts[match.group()] += 1
        id_to_file[match.group()] = f  # for duplicated IDs the last path wins

# Duplicated IDs used to be overwritten silently; report them.
duplicate_ids = [k for k, v in id_counts.items() if v > 1]
if duplicate_ids:
    print(f"Warning: {len(duplicate_ids)} image IDs occur in several files "
          f"(kept the last path for each), e.g. {duplicate_ids[:5]}")

# Match CSV rows; stray whitespace around the IDs would otherwise miss the lookup.
df['nii_filename'] = df['Image Data ID'].astype(str).str.strip().map(id_to_file)

# Show mapping stats
print("Rows with matched NIfTI:", df['nii_filename'].notna().sum())
print("Unmatched rows:", df['nii_filename'].isna().sum())

# Show examples
print("\nSample matched rows:")
print(df[['Image Data ID', 'nii_filename', 'Group']].head(10))


In [None]:
# --- Filter to usable rows and verify class balance ---
# (Working with the ADNI subset that was available; a larger dataset may be needed later.)

# Keep only rows with a matched NIfTI path AND a CN / MCI / AD label, with a
# clean 0..n-1 index (previously the second filter left gaps in the index).
usable = df['nii_filename'].notna() & df['Group'].isin(['AD', 'CN', 'MCI'])
df_matched = df[usable].reset_index(drop=True)

print(f"Final dataset size: {len(df_matched)}")
print("Label distribution:")
print(df_matched['Group'].value_counts())

# Spot-check a few rows. The seed is fixed so reruns show the same rows, and
# the request never exceeds the available rows (sample(5) raised on < 5 rows).
print("\nRandom sample rows:")
n_show = min(5, len(df_matched))
print(df_matched.sample(n_show, random_state=42)[['Subject', 'Group', 'Age', 'Sex', 'MMSCORE', 'nii_filename']])


# ***Cell 8 — Training and Evaluation (5-Fold CV, Metrics, Plots)***

In [None]:
# Cell 8: Training + Evaluation with Class-Weighted Loss
#
# 5-fold cross-validation. When a 'Subject' column exists, folds are grouped by
# subject, so no person's scans appear in both training and validation.
# Splitting scans independently leaks subject identity and inflates the
# validation scores.

from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from torch.optim import AdamW
import torch.nn as nn
from tqdm import tqdm

CLASS_NAMES = ['CN', 'MCI', 'AD']
CLASS_IDS = [0, 1, 2]

torch.manual_seed(42)  # reproducible weight init / shuffling / sampling

labels = df_matched['Group'].map({'CN': 0, 'MCI': 1, 'AD': 2}).to_numpy(dtype=np.int64)

# Class counts and inverse-frequency weights
unique, counts = np.unique(labels, return_counts=True)
class_counts = {int(k): int(v) for k, v in zip(unique, counts)}
print("Class counts (derived):", class_counts)

cw = np.zeros(3, dtype=np.float32)
for c in range(3):
    cw[c] = 1.0 / class_counts.get(c, 1)
class_weights = torch.tensor(cw, dtype=torch.float32).to(device)
print("Class weights used for criterion:", class_weights.cpu().numpy())

# Weighted loss; also used to report a class-balanced validation loss.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Class-imbalance strategy. Combining the inverse-frequency weighted loss with
# an inverse-frequency WeightedRandomSampler (as before) compensates twice and
# over-weights minority classes. Pick exactly one:
#   "loss"    -> weighted loss + plain shuffling (matches this cell's title)
#   "sampler" -> balanced WeightedRandomSampler + unweighted training loss
BALANCING = "loss"
train_criterion = criterion if BALANCING == "loss" else nn.CrossEntropyLoss()

# Hyper-parameters (still to be tuned)
EPOCHS = 30
BATCH_SIZE = 8
LR = 1e-4
patience = 6

if 'Subject' in df_matched.columns:
    skf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
    cv_splits = skf.split(np.zeros(len(labels)), labels,
                          groups=df_matched['Subject'].astype(str).to_numpy())
else:
    print("Warning: no 'Subject' column; using scan-level StratifiedKFold (subject leakage possible).")
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    cv_splits = skf.split(np.zeros(len(labels)), labels)

fold_results = []


def _evaluate(model, loader, loss_fn=None):
    """Run `model` over `loader` without gradients.

    Returns (true labels, predicted labels, mean loss); the loss is None when
    `loss_fn` is not given.
    """
    model.eval()
    y_true, y_pred = [], []
    total_loss = 0.0
    with torch.no_grad():
        for batch in loader:
            slices = batch['slices'].to(device)
            tab = batch['tab'].to(device)
            targets = batch['label'].to(device)
            outputs = model(slices, tab)
            if loss_fn is not None:
                total_loss += loss_fn(outputs, targets).item() * targets.size(0)
            y_pred.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            y_true.extend(targets.cpu().tolist())
    mean_loss = total_loss / len(loader.dataset) if loss_fn is not None else None
    return y_true, y_pred, mean_loss


for fold, (train_idx, val_idx) in enumerate(cv_splits):
    print(f"\n=== Fold {fold+1} ===")
    train_df = df_matched.iloc[train_idx].reset_index(drop=True)
    val_df = df_matched.iloc[val_idx].reset_index(drop=True)

    # datasets (df_matched is already restricted to CN/MCI/AD, so the
    # dataset rows stay aligned with `labels[train_idx]`)
    train_dataset = SliceDataset(train_df, ZIP_PATH, n_slices=16, tabular_cols=tabular_cols, filename_col='nii_filename')
    val_dataset = SliceDataset(val_df, ZIP_PATH, n_slices=16, tabular_cols=tabular_cols, filename_col='nii_filename')

    if BALANCING == "sampler":
        # Inverse-frequency weights from THIS fold's training labels.
        train_labels_arr = labels[train_idx]
        fold_counts = np.bincount(train_labels_arr, minlength=3)
        sample_weights = 1.0 / fold_counts[train_labels_arr]
        sampler = WeightedRandomSampler(weights=torch.as_tensor(sample_weights, dtype=torch.double),
                                        num_samples=len(sample_weights), replacement=True)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
    else:
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # model & optimizer
    model = SliceEfficientNetFusion(tabular_dim=len(tabular_cols), n_slices=16, n_classes=3).to(device)
    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    ckpt_local = f"best_model_fold{fold+1}.pt"
    ckpt_drive = f"{OUT_DIR}/best_model_fold{fold+1}.pt"

    best_macro_f1 = -1.0
    bad_epochs = 0

    for epoch in range(1, EPOCHS + 1):
        # --- Train ---
        model.train()
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Fold{fold+1} Train E{epoch}"):
            slices = batch['slices'].to(device)
            tab = batch['tab'].to(device)
            labels_t = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = model(slices, tab)
            loss = train_criterion(outputs, labels_t)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * labels_t.size(0)

        avg_train_loss = running_loss / len(train_loader.dataset)

        # --- Validation ---
        all_labels, all_preds, avg_val_loss = _evaluate(model, val_loader, criterion)
        val_macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        print(f"Epoch {epoch}: TrainLoss={avg_train_loss:.4f}, ValLoss={avg_val_loss:.4f}, ValMacroF1={val_macro_f1:.4f}")

        # early stopping & checkpoint
        if val_macro_f1 > best_macro_f1:
            best_macro_f1 = val_macro_f1
            bad_epochs = 0
            torch.save(model.state_dict(), ckpt_local)
            torch.save(model.state_dict(), ckpt_drive)
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(f"Early stopping: no improvement for {patience} epochs.")
                break

    # --- Final evaluation of the best checkpoint ---
    # NOTE: the checkpoint was selected on this same validation fold, so these
    # scores are optimistically biased. A held-out test split is needed for an
    # unbiased estimate.
    model.load_state_dict(torch.load(ckpt_local, map_location=device))
    all_labels, all_preds, _ = _evaluate(model, val_loader)

    # labels= keeps the matrix 3x3 and target_names valid even when a class is
    # absent from both the fold's labels and the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=CLASS_IDS)
    report_kwargs = dict(labels=CLASS_IDS, target_names=CLASS_NAMES, zero_division=0)
    print("Fold Confusion Matrix:\n", cm)
    print("Fold Classification Report:\n", classification_report(all_labels, all_preds, **report_kwargs))

    fold_results.append({
        'fold': fold + 1,
        'cm': cm,
        'report': classification_report(all_labels, all_preds, output_dict=True, **report_kwargs),
        'best_val_macro_f1': best_macro_f1,
    })

# --- Cross-fold summary ---
if fold_results:
    fold_f1 = np.array([r['report']['macro avg']['f1-score'] for r in fold_results])
    print(f"\nMacro-F1 across {len(fold_f1)} folds: {fold_f1.mean():.4f} ± {fold_f1.std():.4f}")



