# MRI Image Classification for Alzheimer's Detection
Intermediate-level template using PyTorch (ResNet50 + custom CNN) with Grad-CAM explainability. Follow the cells to load data, train, evaluate, and visualize attention.

## 1) Import Required Libraries
Core imports for data handling, modeling, metrics, and visualization.

In [None]:
# Imports
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as T
import torchvision.models as models

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt

from PIL import Image
import nibabel as nib

# Reproducibility: push one seed into every RNG this notebook relies on
SEED = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
DEVICE

## 2) Load and Explore MRI Dataset
Point the paths to your MRI data (PNG/JPG or NIfTI). Inspect shapes, pixel ranges, and class balance.

In [None]:
import collections

# TODO: point DATA_DIR at the real dataset before running the later cells
DATA_DIR = Path('../data')

# Recognised file suffixes: flat 2-D images and NIfTI volumes
IMG_EXT = ('.png', '.jpg', '.jpeg')
NIFTI_EXT = ('.nii', '.nii.gz')

# Parallel lists: image_paths[i] is a file and labels[i] is its class label.
# NOTE(review): labels are presumably integer class indices produced from a
# {class_name: index} mapping -- confirm when filling in the loader below.
# Placeholder: both lists start empty and are meant to be populated here.
image_paths, labels = [], []

print(f"Found {len(image_paths)} samples")

# Per-class sample counts, as a quick class-balance check
counter = collections.Counter(labels)
print('Class distribution:', counter)


## 3) Preprocess and Normalize Images
Resize to 224x224, normalize, and apply augmentation (random rotation and horizontal flip) to training samples only. For NIfTI volumes, use the middle slice along the third axis.

In [None]:
from models.preprocessing import normalize_image

class MRIDataset(Dataset):
    """Dataset yielding single-channel 224x224 MRI tensors with their labels.

    Files whose name ends with one of ``NIFTI_EXT`` are read with nibabel and
    reduced to the middle slice along the third axis; every other file is
    opened with PIL and converted to 8-bit grayscale.

    Args:
        paths: Sequence of file paths (``str`` or ``Path``).
        labels: Sequence of labels, index-aligned with ``paths``.
        train: When true, random rotation / horizontal flip are applied
            before the resize + tensor + normalize pipeline.
    """

    def __init__(self, paths, labels, train=True):
        self.paths = paths
        self.labels = labels
        self.train = train
        # Always applied: resize, convert to tensor, normalize the one channel
        self.base_transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485], std=[0.229])
        ])
        # Applied to the PIL image first, and only when train=True
        self.aug_transform = T.Compose([
            T.RandomRotation(15),
            T.RandomHorizontalFlip(),
        ])

    def __len__(self):
        """Return the number of samples (one per path)."""
        return len(self.paths)

    def load_image(self, path):
        """Load ``path`` and return it as a mode-'L' PIL image.

        NOTE(review): the NIfTI branch indexes ``vol[:, :, mid]`` and so
        assumes a 3-D volume -- confirm for 4-D series.
        """
        path = str(path)
        # Compare case-insensitively so '.NII' / '.NII.GZ' still reach nibabel
        if path.lower().endswith(NIFTI_EXT):
            vol = nib.load(path).get_fdata()
            mid = vol.shape[2] // 2
            # normalize_image is project code; presumably it rescales to
            # [0, 1] -- confirm. Clipping keeps any out-of-range value from
            # wrapping around in the uint8 cast.
            arr = normalize_image(vol[:, :, mid])
            arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
            return Image.fromarray(arr).convert('L')
        # convert() returns an independent, fully loaded image, so the source
        # file can be closed deterministically as soon as it has been read
        with Image.open(path) as src:
            return src.convert('L')

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``."""
        img = self.load_image(self.paths[idx])
        if self.train:
            img = self.aug_transform(img)
        img = self.base_transform(img)
        label = self.labels[idx]
        return img, label


## 4) Split Data into Train, Validation, and Test Sets
Use a 70/15/15 split, keeping the classes balanced where possible (stratify by label when building `image_paths`/`labels`).

In [None]:
# Stratified 70/15/15 split: each class is shuffled and cut on its own, so
# every subset keeps (approximately) the class proportions of the full set.
if len(labels) != len(image_paths):
    raise ValueError(
        f"image_paths ({len(image_paths)}) and labels ({len(labels)}) must have the same length"
    )

labels_arr = np.asarray(labels)
# Dedicated seeded generator: re-running this cell reproduces the same split
# instead of drawing a fresh one from the global NumPy RNG state, which would
# move samples between train and test across re-runs.
split_rng = np.random.default_rng(SEED)

# Start each list with an empty index array so np.concatenate also works when
# there are no samples yet (placeholder data)
train_parts = [np.empty(0, dtype=np.intp)]
val_parts = [np.empty(0, dtype=np.intp)]
test_parts = [np.empty(0, dtype=np.intp)]
for cls in np.unique(labels_arr):
    cls_idx = np.flatnonzero(labels_arr == cls)
    split_rng.shuffle(cls_idx)
    train_end = int(0.7 * len(cls_idx))
    val_end = int(0.85 * len(cls_idx))
    train_parts.append(cls_idx[:train_end])
    val_parts.append(cls_idx[train_end:val_end])
    test_parts.append(cls_idx[val_end:])

train_idx = np.concatenate(train_parts)
val_idx = np.concatenate(val_parts)
test_idx = np.concatenate(test_parts)

# Augmentation is enabled for the training subset only
train_ds = MRIDataset([image_paths[i] for i in train_idx], [labels[i] for i in train_idx], train=True)
val_ds = MRIDataset([image_paths[i] for i in val_idx], [labels[i] for i in val_idx], train=False)
test_ds = MRIDataset([image_paths[i] for i in test_idx], [labels[i] for i in test_idx], train=False)

BATCH_SIZE = 16
# Only the training loader shuffles; validation and test keep a fixed order
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


## 5) Build Convolutional Neural Network Model
Use transfer learning (ResNet50 adapted to 1-channel) or the provided custom CNN baseline.

In [None]:
from models.cnn_model import ResNetModel, AlzheimersCNN

NUM_CLASSES = 4

# Backbone selector: 'resnet' picks the transfer-learning model; any other
# value falls through to the custom CNN baseline
model_choice = 'resnet'
model = (
    ResNetModel(pretrained=True, num_classes=NUM_CLASSES)
    if model_choice == 'resnet'
    else AlzheimersCNN(num_classes=NUM_CLASSES)
)

# Move the chosen network to the selected device and report which one was built
model = model.to(DEVICE)
print(type(model).__name__)


## 6) Train the CNN Model
Define loss, optimizer, and basic training loop. Monitor validation loss to catch overfitting.

In [None]:
# Multi-class classification loss applied to the raw model outputs
criterion = nn.CrossEntropyLoss()

# Adam over all parameters of the current `model`, with L2 weight decay
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

# Halve the learning rate once the monitored value (validation loss, passed
# to scheduler.step in the training loop) has not improved for 3 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

def train_one_epoch(loader):
    """Run one optimisation pass over ``loader`` and report epoch metrics.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``DEVICE``.

    Args:
        loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        ``(mean_loss, accuracy)`` where both are averaged over samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for imgs, lbls in loader:
        imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, lbls)
        loss.backward()
        optimizer.step()
        batch_size = lbls.size(0)
        # The criterion configured in this notebook returns the batch mean;
        # weight it by batch size so a short final batch is not over-counted
        loss_sum += loss.item() * batch_size
        _, pred = out.max(1)
        total += batch_size
        correct += (pred == lbls).sum().item()
    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / total, correct / total

def validate(loader):
    """Evaluate the module-level ``model`` on ``loader`` without gradients.

    Uses the module-level ``model``, ``criterion`` and ``DEVICE``.

    Args:
        loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``; loss and accuracy
        are averaged over samples, and the two lists hold the per-sample
        predicted and true labels in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
            out = model(imgs)
            loss = criterion(out, lbls)
            batch_size = lbls.size(0)
            # The criterion configured in this notebook returns the batch
            # mean; weight it by batch size so a short final batch is not
            # over-counted in the value that drives the LR scheduler
            loss_sum += loss.item() * batch_size
            _, pred = out.max(1)
            total += batch_size
            correct += (pred == lbls).sum().item()
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(lbls.cpu().numpy())
    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError
        raise ValueError("validate: loader yielded no samples")
    return loss_sum / total, correct / total, all_preds, all_labels

# Short demonstration run; raise EPOCHS for real training
EPOCHS = 5
for epoch in range(EPOCHS):
    tr_loss, tr_acc = train_one_epoch(train_loader)

    # validate() also returns per-sample predictions/labels; only the two
    # scalar metrics are needed here
    val_metrics = validate(val_loader)
    val_loss, val_acc = val_metrics[0], val_metrics[1]

    # The plateau scheduler monitors validation loss
    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}: train_loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={val_loss:.4f} acc={val_acc:.3f}")


## 7) Evaluate Model Performance
Compute accuracy, precision, recall, F1, ROC-AUC, and confusion matrix on the test set.

In [None]:
def evaluate(loader):
    """Print classification metrics for the module-level ``model`` on ``loader``.

    Reports accuracy, weighted precision / recall / F1, the confusion matrix
    and ROC-AUC computed from the softmax probabilities. Uses the
    module-level ``model`` and ``DEVICE``.

    Args:
        loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        None. All results are printed.
    """
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
            out = model(imgs)
            probs = torch.softmax(out, dim=1)
            _, pred = out.max(1)
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(lbls.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    rec = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    cm = confusion_matrix(all_labels, all_preds)
    print(f"Test: acc={acc:.3f} prec={prec:.3f} rec={rec:.3f} f1={f1:.3f}")
    print('Confusion matrix:\n', cm)

    # ROC-AUC from the collected class probabilities. scikit-learn raises
    # ValueError when the score is undefined (e.g. a class is missing from
    # the evaluated labels), so report that instead of aborting the cell.
    try:
        prob_matrix = np.asarray(all_probs)
        if prob_matrix.ndim == 2 and prob_matrix.shape[1] == 2:
            # Binary problems take the positive-class probability only
            scores = prob_matrix[:, 1]
        else:
            scores = prob_matrix
        auc = roc_auc_score(all_labels, scores, multi_class='ovr', average='weighted')
    except ValueError as err:
        print(f"ROC-AUC unavailable: {err}")
    else:
        print(f"ROC-AUC (one-vs-rest, weighted): {auc:.3f}")

evaluate(test_loader)


## 8) Compare Multiple Models
Swap between ResNet50, custom CNN, or other backbones (VGG, ResNet18). Track metrics side by side.

In [None]:
# Registry of candidate architectures. Each value is a zero-argument factory,
# so a network is only constructed when its entry is called in the loop below.
candidate_models = {
    'resnet50': lambda: ResNetModel(pretrained=True, num_classes=NUM_CLASSES),
    'custom_cnn': lambda: AlzheimersCNN(num_classes=NUM_CLASSES)
    # Add 'resnet18' or VGG variants as needed
}

# TODO: loop over candidate_models, train briefly, and log metrics per model
# Maps candidate name -> {'val_acc': ..., 'val_f1': ...}; values stay None
# until the training/validation step is filled in.
model_scores = {}
for name, builder in candidate_models.items():
    print(f"\n[Candidate] {name}")
    # NOTE: for brevity, re-use loaders; for real runs, re-init model/optimizer each loop
    # NOTE(review): this rebinds the module-level `model` that
    # train_one_epoch/validate/evaluate read, so after the loop `model` is the
    # last, untrained candidate. The existing `optimizer` was built from the
    # earlier model's parameters and would not update the new one.
    model = builder().to(DEVICE)
    # train/validate as above, store best val acc/F1
    model_scores[name] = {'val_acc': None, 'val_f1': None}

print('Model comparison (fill after running):', model_scores)


## 9) Visualize Predictions and Feature Maps
Plot correct vs incorrect predictions and inspect activation maps/Grad-CAM overlays.

In [None]:
from models.grad_cam import GradCAM

def show_grad_cam(img_tensor, model, target_class=None, target_layer='layer4'):
    """Plot an input image next to its Grad-CAM overlay.

    Args:
        img_tensor: Image tensor passed to ``GradCAM.generate`` and plotted
            after ``squeeze()``. NOTE(review): presumably a single grayscale
            image that squeezes to 2-D -- confirm.
        model: Network to explain.
        target_class: Forwarded to ``GradCAM.generate``; ``None`` leaves the
            choice of class to that method.
        target_layer: Name of the layer handed to ``GradCAM``. Defaults to
            ``'layer4'``; pass a different name for models that do not have
            a layer with that name.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    grad_cam = GradCAM(model, target_layer=target_layer)
    cam = grad_cam.generate(img_tensor, target_class)

    # Detach before plotting: matplotlib converts the tensor to NumPy, which
    # fails for tensors that require grad. Computed once, reused by both panels.
    img = img_tensor.detach().cpu().squeeze()

    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title('Input')
    plt.subplot(1, 2, 2)
    plt.imshow(img, cmap='gray')
    # Heatmap drawn semi-transparently on top of the input.
    # NOTE(review): assumes `cam` is an array imshow can render -- confirm
    # against GradCAM.generate.
    plt.imshow(cam, alpha=0.5)
    plt.axis('off')
    plt.title('Grad-CAM')
    plt.show()

# TODO: grab a batch from test_loader and visualize


## 10) Interpret Model Results
Summarize findings, note biases/limitations, and pick the best model. Export metrics/figures for the report and model card.