# Alzheimer's Disease Stage Classification from Brain MRI Scans

## Abstract

This notebook presents a deep learning framework for classifying brain MRI scans into three clinical stages: Cognitively Normal (CN), Mild Cognitive Impairment (MCI), and Alzheimer's Disease (AD). The system uses transfer learning from an ImageNet-pretrained EfficientNet-B3 backbone, combined with attention-based spatial pooling and Grad-CAM saliency maps for explainability; uncertainty quantification is left to future work. The model is designed as a clinical decision-support tool to assist radiologists in staging cognitive decline, not as a standalone diagnostic system.

## Medical Background

Alzheimer's Disease (AD) is a progressive neurodegenerative disorder characterized by cognitive decline and brain atrophy. The disease progression follows a continuum from Cognitively Normal (CN) through Mild Cognitive Impairment (MCI) to full Alzheimer's Disease (AD).

Structural MRI reveals characteristic patterns:
- **CN**: Normal brain volume and structure
- **MCI**: Early hippocampal atrophy, mild cortical thinning
- **AD**: Significant hippocampal and entorhinal cortex atrophy, ventricular enlargement, widespread cortical thinning

Automated classification can assist clinicians by providing quantitative assessments and highlighting regions of interest, but requires careful validation and clinical oversight.

## Dataset Description

The dataset consists of brain MRI scans organized into three classes:
- `data/CN/`: Cognitively Normal subjects
- `data/MCI/`: Mild Cognitive Impairment subjects
- `data/AD/`: Alzheimer's Disease subjects

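Each scan may be provided as:
- PNG or JPEG slices (axial plane)
- NIfTI (`.nii` or `.nii.gz`) volumes, from which five slices around the center of the third axis (axial for standard orientations) are extracted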

The dataset exhibits class imbalance, which the training loop addresses with a class-weighted focal loss; weighted sampling is an alternative or complementary option.
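Weighted sampling draws each training example with probability inversely proportional to its class frequency. Below is a minimal pure-Python sketch of the per-sample weights, using the same `n_samples / (n_classes * count)` "balanced" formula that scikit-learn's `compute_class_weight('balanced', ...)` applies. The function name `balanced_sample_weights` and the toy labels are hypothetical and only illustrate the arithmetic.

```python
from collections import Counter


def balanced_sample_weights(labels):
    """Per-sample weights that give every class equal total mass when sampling.

    Class weight follows the 'balanced' heuristic: n_samples / (n_classes * count).
    """
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    class_weight = {c: n_samples / (n_classes * k) for c, k in counts.items()}
    return [class_weight[y] for y in labels], class_weight


# Toy imbalanced labels: 6 CN (0), 3 MCI (1), 1 AD (2)
toy_labels = [0] * 6 + [1] * 3 + [2]
sample_weights, class_weight = balanced_sample_weights(toy_labels)
print(class_weight)
```

With real training-fold labels, the weights could be passed to `torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)` and given to the `DataLoader` as `sampler=` (a sampler and `shuffle=True` are mutually exclusive). Combining weighted sampling with a class-weighted loss corrects for imbalance twice, so usually only one of the two is used.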

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import nibabel as nib
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

## Data Preprocessing

In [None]:
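class MRIDataset(Dataset):
    """2D MRI slices loaded from the class subfolders (CN, MCI, AD) of `data_dir`."""
    def __init__(self, data_dir, transform=None, mode='train'):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.mode = mode
        self.samples = []  # image paths (str) or in-memory PIL slices from NIfTI volumes
        self.labels = []
        self.groups = []   # source file of each sample, for grouped cross-validation
        
        class_map = {'CN': 0, 'MCI': 1, 'AD': 2}
        
        for class_name, label in class_map.items():
            class_dir = self.data_dir / class_name
            if not class_dir.exists():
                continue
                
            for item in sorted(class_dir.iterdir()):  # sorted for a deterministic order
                name = item.name.lower()
                if name.endswith(('.png', '.jpg')):
                    # Each 2D file is its own group; if filenames encode a subject ID,
                    # using that ID here would keep a subject's slices in one fold.
                    self.samples.append(str(item))
                    self.labels.append(label)
                    self.groups.append(str(item))
                elif name.endswith(('.nii', '.nii.gz')):  # Path.suffix is only '.gz' for .nii.gz
                    for slice_img in self._extract_slices(str(item)):
                        self.samples.append(slice_img)
                        self.labels.append(label)
                        self.groups.append(str(item))
    
    def _extract_slices(self, nii_path, num_slices=5):
        """Return `num_slices` normalized slices centred on the middle of the third axis."""
        try:
            data = nib.load(nii_path).get_fdata()
        except Exception as exc:
            # Warnings are globally silenced above, so report skipped files explicitly
            print(f'Skipping unreadable file {nii_path}: {exc}')
            return []
        
        if data.ndim != 3:
            print(f'Skipping {nii_path}: expected a 3D volume, got shape {data.shape}')
            return []
        
        mid_idx = data.shape[2] // 2
        slice_indices = np.linspace(mid_idx - num_slices // 2, mid_idx + num_slices // 2, num_slices, dtype=int)
        slice_indices = np.clip(slice_indices, 0, data.shape[2] - 1)
        
        slices = []
        for idx in slice_indices:
            slice_normalized = self._normalize_slice(data[:, :, idx])
            slices.append(Image.fromarray((slice_normalized * 255).astype(np.uint8)))
        return slices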
    
    def _normalize_slice(self, slice_data):
        slice_data = slice_data.astype(np.float32)
        p2, p98 = np.percentile(slice_data, [2, 98])
        slice_data = np.clip(slice_data, p2, p98)
        if p98 > p2:
            slice_data = (slice_data - p2) / (p98 - p2)
        else:
            slice_data = np.zeros_like(slice_data)
        return slice_data
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        if isinstance(self.samples[idx], str):
            img = Image.open(self.samples[idx]).convert('RGB')
        else:
            img = self.samples[idx].convert('RGB')
        
        label = self.labels[idx]
        
        if self.transform:
            img = self.transform(img)
        
        return img, label
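The two NIfTI helpers in `MRIDataset` can be checked in isolation. This NumPy sketch mirrors the slice-index selection of `_extract_slices` and the 2nd/98th-percentile windowing of `_normalize_slice` on a synthetic random volume (not an MRI); the standalone function names are hypothetical.

```python
import numpy as np


def central_slice_indices(depth, num_slices=5):
    """Indices of `num_slices` consecutive slices centred on the middle of an axis."""
    mid = depth // 2
    idx = np.linspace(mid - num_slices // 2, mid + num_slices // 2, num_slices, dtype=int)
    return np.clip(idx, 0, depth - 1)


def normalize_slice(slice_data):
    """Clip to the 2nd-98th intensity percentiles, then rescale to [0, 1]."""
    slice_data = slice_data.astype(np.float32)
    p2, p98 = np.percentile(slice_data, [2, 98])
    slice_data = np.clip(slice_data, p2, p98)
    if p98 > p2:
        return (slice_data - p2) / (p98 - p2)
    return np.zeros_like(slice_data)  # constant slice: no contrast to rescale


rng = np.random.default_rng(0)
volume = rng.normal(loc=100.0, scale=20.0, size=(64, 64, 40))  # synthetic (x, y, z) volume
slice_idx = central_slice_indices(volume.shape[2])
slices = [normalize_slice(volume[:, :, i]) for i in slice_idx]
print(slice_idx)  # [18 19 20 21 22]
```

For very thin volumes the clipping produces duplicate indices (for a depth of 3, the result is `[0, 0, 1, 2, 2]`), so the same slice can appear more than once in the dataset.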

In [None]:
train_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
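`transforms.Normalize` computes `(x - mean) / std` per channel, and the Grad-CAM cell later inverts it with `x * std + mean` before display. A NumPy sketch of that round trip on a random image; a channel-last (H, W, C) layout is assumed here for simplicity, whereas `ToTensor` produces (C, H, W).

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])  # ImageNet channel means used in the transforms
STD = np.array([0.229, 0.224, 0.225])   # ImageNet channel standard deviations


def normalize(img_hwc):
    """Per-channel standardization, broadcast over height and width."""
    return (img_hwc - MEAN) / STD


def unnormalize(img_hwc):
    """Inverse used for display, clipped to the valid [0, 1] range."""
    return np.clip(img_hwc * STD + MEAN, 0.0, 1.0)


rng = np.random.default_rng(0)
img = rng.random((8, 8, 3))  # values in [0, 1), like ToTensor output
restored = unnormalize(normalize(img))
print(np.abs(restored - img).max() < 1e-12)  # True
```

Because each grayscale MRI slice is converted to RGB by repeating one channel, the three normalized channels differ only by these per-channel constants.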

## Model Architecture

In [None]:
class AttentionPooling(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 4),
            nn.ReLU(),
            nn.Linear(feature_dim // 4, 1),
            nn.Sigmoid()
        )
    
    def forward(self, features):
        if len(features.shape) == 2:
            features = features.unsqueeze(0)
        attention_weights = self.attention(features)
        weighted_features = features * attention_weights
        pooled = weighted_features.sum(dim=1)
        return pooled, attention_weights.squeeze(-1)
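For a 384×384 input, the EfficientNet-B3 feature extractor downsamples by a factor of 32, so the backbone output is 1536 × 12 × 12 and the attention pool sees 144 spatial positions per image. The NumPy sketch below traces the shape flow of `AttentionPooling`; random matrices stand in for the learned `nn.Linear` layers, and the function names are illustrative only.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def attention_pool(features, w1, b1, w2, b2):
    """Sigmoid-gated sum pooling over spatial positions.

    features: (B, N, C) with N = h * w positions; returns (B, C) pooled and (B, N) gates.
    """
    hidden = np.maximum(features @ w1 + b1, 0.0)  # Linear(C, C // 4) + ReLU
    gates = sigmoid(hidden @ w2 + b2)              # Linear(C // 4, 1) + Sigmoid -> (B, N, 1)
    pooled = (features * gates).sum(axis=1)        # gated sum over positions
    return pooled, gates.squeeze(-1)


rng = np.random.default_rng(0)
B, C, h, w = 2, 1536, 12, 12
fmap = rng.standard_normal((B, C, h, w)).astype(np.float32)
# Same reshape as the model: (B, C, h, w) -> (B, C, h*w) -> (B, h*w, C)
flat = fmap.reshape(B, C, h * w).transpose(0, 2, 1)
w1 = rng.standard_normal((C, C // 4)) * 0.01
b1 = np.zeros(C // 4)
w2 = rng.standard_normal((C // 4, 1)) * 0.01
b2 = np.zeros(1)
pooled, gates = attention_pool(flat, w1, b1, w2, b2)
print(pooled.shape, gates.shape)  # (2, 1536) (2, 144)
```

Because the sigmoid gates are not normalized across positions, the pooled vector is a sum over up to 144 positions rather than an average; the `BatchNorm1d` in the classifier head rescales it. A softmax over positions is the common alternative when a weighted average is preferred.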

In [None]:
class AlzheimerClassifier(nn.Module):
    def __init__(self, num_classes=3, dropout_rate=0.5, use_attention=True):
        super().__init__()
        
        backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(backbone.features.children()))
        
        self.feature_dim = 1536
        self.use_attention = use_attention
        
        if use_attention:
            self.attention_pool = AttentionPooling(self.feature_dim)
        else:
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x, return_attention=False):
        features = self.backbone(x)
        
        if self.use_attention:
            b, c, h, w = features.shape
            features_flat = features.view(b, c, h * w).permute(0, 2, 1)
            pooled, attention_weights = self.attention_pool(features_flat)
        else:
            pooled = self.global_pool(features).view(features.size(0), -1)
            attention_weights = None
        
        logits = self.classifier(pooled)
        
        if return_attention:
            return logits, attention_weights, features
        return logits

## Training Strategy

In [None]:
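class FocalLoss(nn.Module):
    """Multi-class focal loss: FL = -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # optional per-class weight tensor of shape (num_classes,)
        self.gamma = gamma
    
    def forward(self, inputs, targets):
        # Unweighted CE, so that p_t = exp(-CE) is the true-class probability
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.alpha is not None:
            # Apply class weights only after p_t has been computed
            focal_loss = self.alpha[targets] * focal_loss
        return focal_loss.mean()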
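The focal term `(1 - p_t)**gamma` shrinks the loss of confidently correct predictions, and `gamma = 0` recovers ordinary cross-entropy. A NumPy sketch of the per-sample computation with class weights `alpha` applied after `p_t` is taken from the unweighted cross-entropy; the toy logits are made up for illustration.

```python
import numpy as np


def focal_loss(logits, targets, alpha=None, gamma=2.0):
    """Mean multi-class focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(targets)), targets]  # log-probability of the true class
    pt = np.exp(log_pt)
    loss = -((1.0 - pt) ** gamma) * log_pt
    if alpha is not None:
        loss = alpha[targets] * loss
    return loss.mean()


logits = np.array([[4.0, 0.0, 0.0],    # confident and correct (target 0)
                   [0.5, 0.4, 0.3]])   # uncertain (target 2)
targets = np.array([0, 2])

ce = focal_loss(logits, targets, gamma=0.0)  # gamma = 0 reduces to cross-entropy
fl = focal_loss(logits, targets, gamma=2.0)
print(fl < ce)  # True: the focal term down-weights every sample
```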

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(dataloader, desc='Training'):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    return running_loss / len(dataloader), 100 * correct / total

In [None]:
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validating'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    accuracy = 100 * np.mean(np.array(all_preds) == np.array(all_labels))
    
    all_probs = np.array(all_probs)
    roc_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')
    
    return running_loss / len(dataloader), accuracy, roc_auc, all_preds, all_labels, all_probs

In [None]:
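from sklearn.model_selection import StratifiedGroupKFold

data_dir = 'data'
# Load the sample list once without transforms; each fold applies its own transforms.
full_dataset = MRIDataset(data_dir, transform=None, mode='train')
labels = np.array(full_dataset.labels)  # avoids loading every image just to read its label
# Keep all slices of one source scan in the same fold to prevent leakage; fall back to
# one group per sample if the dataset does not record source files.
groups = np.array(getattr(full_dataset, 'groups', np.arange(len(full_dataset))))


class TransformSubset(Dataset):
    """Index subset that applies its own transform (DataLoader has no `transform` argument)."""
    def __init__(self, base, indices, transform):
        self.base, self.indices, self.transform = base, list(indices), transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        img, label = self.base[self.indices[i]]  # PIL image, since base has transform=None
        return self.transform(img), label


skf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
indices = np.arange(len(full_dataset))

fold_results = []

for fold, (train_idx, val_idx) in enumerate(skf.split(indices, labels, groups)):
    print(f'\nFold {fold + 1}/5')
    
    # drop_last avoids a size-1 final batch, which BatchNorm1d rejects in training mode
    train_loader = DataLoader(TransformSubset(full_dataset, train_idx, train_transform),
                              batch_size=16, shuffle=True, drop_last=True, num_workers=2)
    val_loader = DataLoader(TransformSubset(full_dataset, val_idx, val_transform),
                            batch_size=16, shuffle=False, num_workers=2)
    
    # Class weights are computed from the training fold only
    class_weights = compute_class_weight('balanced', classes=np.unique(labels), y=labels[train_idx])
    class_weights = torch.FloatTensor(class_weights).to(device)
    
    model = AlzheimerClassifier(num_classes=3, dropout_rate=0.5, use_attention=True).to(device)
    
    criterion = FocalLoss(alpha=class_weights, gamma=2.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    # `verbose` is omitted because it is deprecated in recent PyTorch releases
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)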
    
    best_roc_auc = 0.0
    patience_counter = 0
    max_patience = 15
    
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    val_rocs = []
    
    for epoch in range(50):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_roc, _, _, _ = validate(model, val_loader, criterion, device)
        
        scheduler.step(val_loss)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        val_rocs.append(val_roc)
        
        if val_roc > best_roc_auc:
            best_roc_auc = val_roc
            patience_counter = 0
            torch.save(model.state_dict(), f'best_model_fold_{fold}.pth')
        else:
            patience_counter += 1
        
        if patience_counter >= max_patience:
            break
    
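    # Note: the fold's AUC is measured on the same fold used for checkpoint selection,
    # so it is an optimistic estimate of performance on unseen scans.
    model.load_state_dict(torch.load(f'best_model_fold_{fold}.pth', map_location=device))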
    _, _, final_roc, final_preds, final_labels, final_probs = validate(model, val_loader, criterion, device)
    
    fold_results.append({
        'fold': fold + 1,
        'roc_auc': final_roc,
        'predictions': final_preds,
        'labels': final_labels,
        'probabilities': final_probs
    })
    
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Fold {fold + 1} - Loss')
    
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs, label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.title(f'Fold {fold + 1} - Accuracy')
    plt.tight_layout()
    plt.show()

mean_roc_auc = np.mean([r['roc_auc'] for r in fold_results])
std_roc_auc = np.std([r['roc_auc'] for r in fold_results])
print(f'\nCross-Validation ROC-AUC: {mean_roc_auc:.4f} ± {std_roc_auc:.4f}')
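Because each NIfTI volume contributes five slices, a slice-level split can place neighbouring slices of one scan in both the training and validation folds, inflating the reported AUC. A small, hypothetical helper that checks a split for such overlap, given one group identifier (for example, the source file or a subject ID) per sample:

```python
def group_overlap(train_idx, val_idx, groups):
    """Return the set of group IDs that appear in both the training and validation indices."""
    train_groups = {groups[i] for i in train_idx}
    val_groups = {groups[i] for i in val_idx}
    return train_groups & val_groups


# Toy example: two scans ('a', 'b'), each contributing three slices
toy_groups = ['a', 'a', 'a', 'b', 'b', 'b']
leaky = group_overlap([0, 1, 3, 4], [2, 5], toy_groups)   # slice-level split
clean = group_overlap([0, 1, 2], [3, 4, 5], toy_groups)   # scan-level split
print(sorted(leaky), sorted(clean))  # ['a', 'b'] []
```

Inside the fold loop, this check would be called with the dataset's per-sample group IDs; a non-empty result signals leakage between training and validation data.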

## Evaluation

In [None]:
all_preds = np.concatenate([r['predictions'] for r in fold_results])
all_labels = np.concatenate([r['labels'] for r in fold_results])
all_probs = np.concatenate([r['probabilities'] for r in fold_results])

cm = confusion_matrix(all_labels, all_preds)
class_names = ['CN', 'MCI', 'AD']

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.show()

print(classification_report(all_labels, all_preds, target_names=class_names))

In [None]:
plt.figure(figsize=(10, 8))

for i, class_name in enumerate(class_names):
    y_true_binary = (all_labels == i).astype(int)
    y_score = all_probs[:, i]
    
    fpr, tpr, _ = roc_curve(y_true_binary, y_score)
    roc_auc = roc_auc_score(y_true_binary, y_score)
    
    plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.3f})', linewidth=2)

plt.plot([0, 1], [0, 1], 'k--', linewidth=1)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curves per Class', fontsize=14)
plt.legend(loc='lower right', fontsize=11)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

### Clinical Tradeoffs

When interpreting the confusion matrix, several clinical considerations apply:

1. **CN vs MCI distinction**: Early-stage MCI may be misclassified as CN because the structural changes are subtle. This error may be partly tolerable in practice, since an MCI diagnosis often relies on longitudinal follow-up.

2. **MCI vs AD boundary**: The model may struggle with the MCI-AD transition, which reflects the continuum nature of disease progression. False positives in AD classification require careful clinical correlation.

3. **Sensitivity vs Specificity**: High sensitivity for AD detection is critical for early intervention, but must be balanced against false positives that could cause unnecessary patient anxiety.

4. **Class imbalance impact**: Performance on minority classes may be lower, and MCI, which lies between CN and AD, is often the hardest class to separate; results for these classes need careful interpretation in clinical settings.
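The sensitivity/specificity tradeoff above can be quantified per class by reading the confusion matrix one-vs-rest. A NumPy sketch; the counts below are invented for illustration and are not results from this notebook.

```python
import numpy as np


def one_vs_rest_rates(cm):
    """Per-class sensitivity (recall) and specificity from a square confusion matrix.

    Rows are true labels and columns are predictions, as in sklearn's confusion_matrix.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp  # true class i, predicted as another class
    fp = cm.sum(axis=0) - tp  # predicted class i, true class is another class
    tn = cm.sum() - tp - fn - fp
    return tp / (tp + fn), tn / (tn + fp)


# Hypothetical counts for CN, MCI, AD
toy_cm = [[40, 8, 2],
          [10, 25, 5],
          [1, 6, 23]]
sensitivity, specificity = one_vs_rest_rates(toy_cm)
for name, se, sp in zip(['CN', 'MCI', 'AD'], sensitivity, specificity):
    print(f'{name}: sensitivity={se:.2f}, specificity={sp:.2f}')
```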

## Explainability

In [None]:
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]
        
        def forward_hook(module, input, output):
            self.activations = output
        
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)
    
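    def generate(self, input_tensor, class_idx=None):
        """Return a [0, 1]-normalized Grad-CAM map for a single (C, H, W) image."""
        self.model.eval()
        input_tensor = input_tensor.unsqueeze(0).to(device)
        
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        
        # Backpropagate the target-class logit to populate the hooked gradients
        self.model.zero_grad()
        output[0, class_idx].backward()
        
        gradients = self.gradients[0]      # (C, h, w)
        activations = self.activations[0]  # (C, h, w)
        
        # Channel weights are spatially averaged gradients; the CAM is the ReLU of the weighted sum
        weights = torch.mean(gradients, dim=(1, 2), keepdim=True)
        cam = F.relu(torch.sum(weights * activations, dim=0))
        # Upsample to the input resolution instead of a hard-coded size
        cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=input_tensor.shape[-2:],
                            mode='bilinear', align_corners=False)
        # detach() is required: the hooked activations are still part of the autograd graph
        cam = cam.squeeze().detach().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam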
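The `generate` method follows the standard Grad-CAM weighting: each channel's weight is its spatially averaged gradient, and the map is the ReLU of the weighted channel sum, rescaled to [0, 1]. A NumPy sketch of just that arithmetic on a tiny hand-made activation/gradient pair (the arrays are illustrative, not model outputs):

```python
import numpy as np


def gradcam_map(activations, gradients, eps=1e-8):
    """activations, gradients: (C, h, w) arrays for one image; returns an (h, w) map in [0, 1]."""
    weights = gradients.mean(axis=(1, 2), keepdims=True)         # global-average-pooled gradients
    cam = np.maximum((weights * activations).sum(axis=0), 0.0)   # ReLU of the weighted channel sum
    return (cam - cam.min()) / (cam.max() - cam.min() + eps)


# Two channels on a 2x2 grid: channel 0 supports the class, channel 1 opposes it
activations = np.array([[[1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 1.0]]])
gradients = np.array([np.full((2, 2), 0.5),
                      np.full((2, 2), -0.5)])
demo_cam = gradcam_map(activations, gradients)
print(demo_cam.round(3))
```

The negative-evidence location (bottom right) is zeroed by the ReLU, which is why Grad-CAM highlights only regions that increase the target-class score.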

In [None]:
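model.load_state_dict(torch.load('best_model_fold_0.pth', map_location=device))
model.eval()

target_layer = model.backbone[-1]
gradcam = GradCAM(model, target_layer)

# Note: this dataset contains every sample, including those the fold-0 model was trained on.
# For an unbiased qualitative check, restrict it to fold 0's validation indices.
val_dataset = MRIDataset(data_dir, transform=val_transform, mode='val')
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)

num_examples = 3
fig, axes = plt.subplots(num_examples, 3, figsize=(15, 5 * num_examples))  # one row per example

for idx, (image, label) in enumerate(val_loader):
    if idx >= num_examples:
        break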
    
    image = image.to(device)
    output = model(image)
    pred_class = output.argmax(dim=1).item()
    prob = F.softmax(output, dim=1)[0, pred_class].item()
    
    cam = gradcam.generate(image.squeeze(0), pred_class)
    
    img_unnorm = image.squeeze(0).cpu().permute(1, 2, 0)
    img_unnorm = img_unnorm * torch.tensor([0.229, 0.224, 0.225]) + torch.tensor([0.485, 0.456, 0.406])
    img_unnorm = torch.clamp(img_unnorm, 0, 1)
    
    axes[idx, 0].imshow(img_unnorm.numpy())
    axes[idx, 0].set_title(f'Original\nTrue: {class_names[label.item()]}')
    axes[idx, 0].axis('off')
    
    axes[idx, 1].imshow(cam, cmap='jet')
    axes[idx, 1].set_title('Grad-CAM Heatmap')
    axes[idx, 1].axis('off')
    
    overlay = img_unnorm.numpy().copy()
    cam_resized = np.stack([cam] * 3, axis=-1)
    overlay = 0.6 * overlay + 0.4 * cam_resized
    
    axes[idx, 2].imshow(overlay)
    axes[idx, 2].set_title(f'Overlay\nPred: {class_names[pred_class]} ({prob:.2f})')
    axes[idx, 2].axis('off')

plt.tight_layout()
plt.show()

### Clinical Relevance of Saliency Maps

Grad-CAM visualizations highlight regions the model uses for classification decisions. In Alzheimer's Disease:

1. **Hippocampal regions**: If the model relies on disease-relevant anatomy, high activation would be expected in medial temporal lobe structures, consistent with known AD pathology.

2. **Ventricular enlargement**: Attention to the lateral ventricles would align with markers of disease progression.

3. **Cortical regions**: Activation in parietal and frontal cortices may reflect cortical thinning patterns.

4. **Validation**: Overlap with known neuroanatomical AD markers would provide face validity, but Grad-CAM maps are coarse (12×12 before upsampling for a 384×384 input) and clinical correlation remains essential.

These maps assist radiologists by directing attention to regions of interest, but should not replace comprehensive image review.

## Ethics & Limitations

### Dataset Bias

The model is trained on a specific dataset that may not represent global population diversity. Potential biases include:

- **Demographic bias**: Age, sex, ethnicity, and socioeconomic factors may not be representative
- **Geographic bias**: Data from specific regions or healthcare systems
- **Scanner bias**: MRI acquisition protocols, field strengths, and manufacturers vary
- **Selection bias**: Inclusion/exclusion criteria may favor certain patient populations

### Scanner Variability

MRI scans vary significantly across:

- Field strength (1.5T vs 3T)
- Acquisition protocols (slice thickness, resolution, contrast)
- Manufacturer-specific image characteristics
- Preprocessing pipelines

The model's performance may degrade when applied to scans from different scanners or protocols not represented in training data.

### Generalization Limits

1. **Temporal stability**: Model performance may change as imaging technology evolves
2. **Comorbidity**: Performance may degrade with concurrent neurological conditions
3. **Early-stage detection**: Limited sensitivity for very early MCI or pre-symptomatic AD
4. **Population shifts**: Performance and calibration may change in populations with a different disease prevalence

### Non-Diagnostic Nature

This model is a **decision-support tool**, not a diagnostic system. Key limitations:

- AD diagnosis requires comprehensive clinical evaluation, not imaging alone
- Model outputs are probabilities, not definitive diagnoses
- False positives/negatives can have significant clinical consequences
- Model cannot account for patient history, symptoms, or other diagnostic information

### Need for Clinician Oversight

1. **Mandatory review**: All model predictions must be reviewed by qualified radiologists or neurologists
2. **Clinical context**: Model outputs must be integrated with patient history, symptoms, and other diagnostic tests
3. **Quality assurance**: Regular monitoring of model performance in clinical deployment
4. **Continuous validation**: Ongoing evaluation against ground truth and clinical outcomes
5. **Regulatory compliance**: Deployment must comply with medical device regulations (FDA, CE marking, etc.)

### Ethical Considerations

- **Patient autonomy**: Patients should be informed about AI-assisted analysis
- **Privacy**: MRI data contains sensitive health information requiring strict privacy protection
- **Equity**: Ensure model benefits are accessible across diverse populations
- **Transparency**: Clinicians and patients should understand model limitations
- **Accountability**: Clear responsibility for clinical decisions remains with healthcare providers

## Future Work

### Model Improvements

1. **Multi-modal fusion**: Integrate structural MRI with functional MRI (fMRI), PET scans, or cerebrospinal fluid biomarkers
2. **Longitudinal modeling**: Incorporate temporal information from follow-up scans to track disease progression
3. **3D architectures**: Full volumetric analysis using 3D CNNs or vision transformers
4. **Self-supervised pretraining**: Domain-specific pretraining on large unlabeled medical imaging datasets
5. **Transformer architectures**: Vision transformers (ViT) or medical-specific transformer designs

### Clinical Integration

1. **Prospective validation**: Large-scale multi-center prospective studies
2. **Real-world deployment**: Integration into clinical PACS systems and radiology workflows
3. **Outcome prediction**: Extend beyond classification to predict disease progression rates
4. **Treatment response**: Model patients' responses to therapeutic interventions

### Technical Enhancements

1. **Uncertainty quantification**: Bayesian neural networks or ensemble methods for reliable confidence estimates
2. **Federated learning**: Train across institutions while preserving data privacy
3. **Domain adaptation**: Techniques to improve generalization across scanner types and protocols
4. **Interpretability**: Enhanced explainability methods tailored for medical imaging
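The five fold models trained in the cross-validation loop already form a small ensemble. Below is a NumPy sketch of one common ensemble-based uncertainty estimate. It assumes each model's softmax outputs for the same images are stacked into an array of shape (n_models, n_images, n_classes), averages the probabilities, and reports the predictive entropy. The probabilities are invented for illustration.

```python
import numpy as np


def ensemble_uncertainty(probs):
    """probs: (n_models, n_images, n_classes) softmax outputs.

    Returns the mean probabilities and the predictive entropy (in nats) per image.
    """
    mean_probs = probs.mean(axis=0)
    entropy = -(mean_probs * np.log(np.clip(mean_probs, 1e-12, 1.0))).sum(axis=1)
    return mean_probs, entropy


# Hypothetical outputs of 3 models for 2 images (classes CN, MCI, AD)
probs = np.array([
    [[0.90, 0.08, 0.02], [0.40, 0.35, 0.25]],
    [[0.85, 0.10, 0.05], [0.20, 0.50, 0.30]],
    [[0.92, 0.06, 0.02], [0.30, 0.30, 0.40]],
])
mean_probs, entropy = ensemble_uncertainty(probs)
print(mean_probs.argmax(axis=1), entropy.round(3))
```

Higher entropy flags images on which the models disagree or are jointly unsure; such cases could be routed for closer radiologist review. For three classes the entropy is bounded by log(3) ≈ 1.099 nats.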

### Dataset Expansion

1. **Diverse populations**: Inclusion of underrepresented demographic groups
2. **Multi-center data**: Aggregation of data from diverse healthcare systems
3. **Rare variants**: Training on atypical AD presentations and variants
4. **Comorbidity cases**: Inclusion of patients with multiple neurological conditions

### Regulatory & Validation

1. **Regulatory approval**: Pursue FDA 510(k) or De Novo pathways for clinical use
2. **Clinical trials**: Randomized controlled trials assessing impact on diagnostic accuracy and patient outcomes
3. **Cost-effectiveness**: Health economic analyses of AI-assisted diagnosis
4. **Standardization**: Development of standardized evaluation protocols and benchmarks