# Alzheimer's Disease Stage Classification from Brain MRI Scans

## Abstract

This notebook presents a deep learning framework for classifying brain MRI scans into three clinical stages: Cognitively Normal (CN), Mild Cognitive Impairment (MCI), and Alzheimer's Disease (AD). The system applies transfer learning with an ImageNet-pretrained EfficientNet-B3 backbone and an attention-pooling head. Uncertainty quantification is not implemented in the cells below and is listed under Future Work. The model is intended as a clinical decision-support tool to assist radiologists in staging cognitive decline, not as a standalone diagnostic system.

## Medical Background

Alzheimer's Disease (AD) is a progressive neurodegenerative disorder characterized by cognitive decline and brain atrophy. The disease progression follows a continuum from Cognitively Normal (CN) through Mild Cognitive Impairment (MCI) to full Alzheimer's Disease (AD).

Structural MRI reveals characteristic patterns:
- **CN**: Normal brain volume and structure
- **MCI**: Early hippocampal atrophy, mild cortical thinning
- **AD**: Significant hippocampal and entorhinal cortex atrophy, ventricular enlargement, widespread cortical thinning

Automated classification can assist clinicians by providing quantitative assessments and highlighting regions of interest, but requires careful validation and clinical oversight.

## Dataset Description

The dataset consists of brain MRI scans from Kaggle, organized into training and test sets with four impairment levels, which this notebook collapses into three classes:
- `No Impairment` → mapped to **CN** (Cognitively Normal)
- `Very Mild Impairment` → mapped to **MCI** (Mild Cognitive Impairment)
- `Mild Impairment` and `Moderate Impairment` → mapped to **AD** (Alzheimer's Disease)

Images are provided as PNG/JPG slices in the axial plane. The dataset exhibits class imbalance, which is addressed here with a class-weighted focal loss (`WeightedRandomSampler` is imported but not used in the cells below).
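
The class weights used later come from scikit-learn's `'balanced'` heuristic, `n_samples / (n_classes * count)`. The sketch below reproduces that arithmetic in plain Python so the effect of imbalance is easy to see. The label counts are hypothetical and chosen only for illustration; they are not the real dataset sizes.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Return {class: n_samples / (n_classes * class_count)}."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {c: n_samples / (n_classes * n) for c, n in sorted(counts.items())}


# Hypothetical labels: 6 CN (0), 3 MCI (1), 1 AD (2). Not the real class counts.
toy_labels = [0] * 6 + [1] * 3 + [2] * 1
weights = balanced_class_weights(toy_labels)

for cls, w in weights.items():
    print(f"class {cls}: weight {w:.3f}")
```

Rarer classes receive larger weights, so their errors contribute more to the loss.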

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import shutil
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

## Preprocessing

In [None]:
def verify_structure(data_dir):
    data_path = Path(data_dir)
    train_path = data_path / 'train'
    test_path = data_path / 'test'
    
    if not train_path.exists() or not test_path.exists():
        raise ValueError(f'Expected train/ and test/ directories in {data_dir}')
    
    expected_classes = ['No Impairment', 'Very Mild Impairment', 'Mild Impairment', 'Moderate Impairment']
    
    for split_path in [train_path, test_path]:
        found_classes = [d.name for d in split_path.iterdir() if d.is_dir()]
        for exp_class in expected_classes:
            if exp_class not in found_classes:
                raise ValueError(f'Missing class directory: {exp_class} in {split_path}')
    
    print('Dataset structure verified')
    return train_path, test_path

In [None]:
def process_images(source_dir, target_dir, target_size=(224, 224)):
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    
    class_mapping = {
        'No Impairment': 'CN',
        'Very Mild Impairment': 'MCI',
        'Mild Impairment': 'AD',
        'Moderate Impairment': 'AD'
    }
    
    for split in ['train', 'test']:
        split_source = source_path / split
        split_target = target_path / split
        
        for old_class, new_class in class_mapping.items():
            old_class_dir = split_source / old_class
            new_class_dir = split_target / new_class
            new_class_dir.mkdir(parents=True, exist_ok=True)
            
            image_files = list(old_class_dir.glob('*.png')) + list(old_class_dir.glob('*.jpg')) + list(old_class_dir.glob('*.jpeg'))
            
            for img_file in tqdm(image_files, desc=f'Processing {split}/{old_class}'):
                try:
                    img = Image.open(img_file)
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    img = img.resize(target_size, Image.Resampling.LANCZOS)
                    
                    new_filename = f'{new_class}_{img_file.stem}.png'
                    img.save(new_class_dir / new_filename, 'PNG')
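                except Exception as e:
                    # Report unreadable files instead of dropping them silently
                    tqdm.write(f'Skipping {img_file}: {e}')
                    continue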
    
    print('Image preprocessing completed')

In [None]:
train_path, test_path = verify_structure('./data')

if Path('data_processed').exists():
    shutil.rmtree('data_processed')

process_images('./data', 'data_processed', target_size=(224, 224))

In [None]:
class MRIDataset(Dataset):
    def __init__(self, data_dir, transform=None, split='train'):
        self.data_dir = Path(data_dir) / split
        self.transform = transform
        self.samples = []
        self.labels = []
        
        class_map = {'CN': 0, 'MCI': 1, 'AD': 2}
        
        for class_name, label in class_map.items():
            class_dir = self.data_dir / class_name
            if not class_dir.exists():
                continue
            
            # Sort so sample order (and the seeded split built on it) is reproducible
            for img_file in sorted(class_dir.glob('*.png')):
                self.samples.append(str(img_file))
                self.labels.append(label)
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img = Image.open(self.samples[idx]).convert('RGB')
        label = self.labels[idx]
        
        if self.transform:
            img = self.transform(img)
        
        return img, label

In [None]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Model Architecture

In [None]:
class AttentionPooling(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 4),
            nn.ReLU(),
            nn.Linear(feature_dim // 4, 1),
            nn.Sigmoid()
        )
    
    def forward(self, features):
        if len(features.shape) == 2:
            features = features.unsqueeze(0)
        attention_weights = self.attention(features)
        weighted_features = features * attention_weights
        pooled = weighted_features.sum(dim=1)
        return pooled, attention_weights.squeeze(-1)
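
To make the pooling step concrete, here is a small plain-Python sketch of sigmoid-gated sum pooling, the operation `AttentionPooling.forward` performs per image. The feature values and scores are hypothetical; `scores` stands in for the scoring MLP's output before the sigmoid.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_sum_pool(features, scores):
    """Scale each position's feature vector by sigmoid(score), then sum over positions."""
    gates = [sigmoid(s) for s in scores]
    dim = len(features[0])
    pooled = [sum(g * f[d] for g, f in zip(gates, features)) for d in range(dim)]
    return pooled, gates


# Three hypothetical spatial positions with 2-dimensional features.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [4.0, -4.0, 0.0]  # assumed pre-sigmoid scores, for illustration only

pooled, gates = gated_sum_pool(features, scores)
print("gates:", [round(g, 3) for g in gates])
print("pooled:", [round(p, 3) for p in pooled])
```

Because sigmoid gates are independent per position, they do not sum to 1, so the pooled magnitude grows with the number of active positions. A softmax over positions would be one alternative if a normalized weighted average is preferred.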

In [None]:
class AlzheimerClassifier(nn.Module):
    def __init__(self, num_classes=3, dropout_rate=0.5, use_attention=True):
        super().__init__()
        
        backbone = efficientnet_b3(weights=EfficientNet_B3_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(backbone.features.children()))
        
        self.feature_dim = 1536
        self.use_attention = use_attention
        
        if use_attention:
            self.attention_pool = AttentionPooling(self.feature_dim)
        else:
            self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x, return_attention=False):
        features = self.backbone(x)
        
        if self.use_attention:
            b, c, h, w = features.shape
            features_flat = features.view(b, c, h * w).permute(0, 2, 1)
            pooled, attention_weights = self.attention_pool(features_flat)
        else:
            pooled = self.global_pool(features).view(features.size(0), -1)
            attention_weights = None
        
        logits = self.classifier(pooled)
        
        if return_attention:
            return logits, attention_weights, features
        return logits

## Training Strategy

In [None]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, inputs, targets):
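        # Unweighted CE, so that pt = exp(-CE) is the true-class probability
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        # Apply per-class alpha after the focal term so it does not distort pt
        if self.alpha is not None:
            focal_loss = self.alpha.to(inputs.device)[targets] * focal_loss
        return focal_loss.mean()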
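
As a worked illustration of the focal term, the sketch below evaluates `(1 - p)^gamma * CE` for a single example at a few true-class probabilities. It omits the per-class `alpha` weights, and the probabilities are arbitrary illustrative values.

```python
import math


def cross_entropy(p_true):
    """Cross-entropy for one example, given the probability assigned to the true class."""
    return -math.log(p_true)


def focal_loss(p_true, gamma=2.0):
    """Focal loss for one example: the CE scaled by (1 - p_true) ** gamma."""
    return (1.0 - p_true) ** gamma * cross_entropy(p_true)


for p in (0.9, 0.5, 0.1):
    print(f"p_true={p:.1f}  CE={cross_entropy(p):.4f}  focal={focal_loss(p):.4f}")
```

With `gamma=2.0`, a confidently correct example (`p_true=0.9`) keeps about 1% of its cross-entropy, while a badly misclassified one (`p_true=0.1`) keeps 81%. Setting `gamma=0.0` recovers plain cross-entropy.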

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(dataloader, desc='Training', leave=False):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    return running_loss / len(dataloader), 100 * correct / total

In [None]:
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Validating', leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    
    accuracy = 100 * np.mean(np.array(all_preds) == np.array(all_labels))
    
    all_probs = np.array(all_probs)
    roc_auc = roc_auc_score(all_labels, all_probs, multi_class='ovr', average='macro')
    
    return running_loss / len(dataloader), accuracy, roc_auc, all_preds, all_labels, all_probs

In [None]:
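train_dataset = MRIDataset('data_processed', transform=train_transform, split='train')
# Same files as train_dataset, but without augmentation, for the validation split
val_dataset = MRIDataset('data_processed', transform=val_transform, split='train')
test_dataset = MRIDataset('data_processed', transform=val_transform, split='test')
assert val_dataset.samples == train_dataset.samples, 'train/val file order must match'

# Use the stored labels; indexing the dataset would decode every image
all_labels = np.array(train_dataset.labels)
train_indices, val_indices = train_test_split(
    np.arange(len(train_dataset)),
    test_size=0.2,
    stratify=all_labels,
    random_state=42
)

train_subset = torch.utils.data.Subset(train_dataset, train_indices)
val_subset = torch.utils.data.Subset(val_dataset, val_indices)

train_labels = all_labels[train_indices]
labels_array = np.array(train_labels, dtype=np.int64)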
unique_classes = np.unique(labels_array)
class_weights = compute_class_weight('balanced', classes=unique_classes, y=labels_array)
class_weights_dict = dict(zip(unique_classes, class_weights))
class_weights_tensor = torch.FloatTensor([class_weights_dict[i] for i in range(3)]).to(device)

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

print(f'Train samples: {len(train_subset)}, Val samples: {len(val_subset)}, Test samples: {len(test_dataset)}')

In [None]:
model = AlzheimerClassifier(num_classes=3, dropout_rate=0.5, use_attention=True).to(device)

criterion = FocalLoss(alpha=class_weights_tensor, gamma=2.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

best_roc_auc = 0.0
patience_counter = 0
max_patience = 15

train_losses = []
val_losses = []
train_accs = []
val_accs = []
val_rocs = []

for epoch in range(50):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_roc, _, _, _ = validate(model, val_loader, criterion, device)
    
    scheduler.step(val_loss)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    val_rocs.append(val_roc)
    
    print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%, Val ROC-AUC={val_roc:.4f}')
    
    if val_roc > best_roc_auc:
        best_roc_auc = val_roc
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
    
    if patience_counter >= max_patience:
        print(f'Early stopping at epoch {epoch+1}')
        break

model.load_state_dict(torch.load('best_model.pth', map_location=device))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Training and Validation Accuracy')
plt.tight_layout()
plt.show()

## Evaluation

In [None]:
test_loss, test_acc, test_roc, test_preds, test_labels, test_probs = validate(model, test_loader, criterion, device)

print(f'Test Accuracy: {test_acc:.2f}%')
print(f'Test Macro ROC-AUC: {test_roc:.4f}')

In [None]:
cm = confusion_matrix(test_labels, test_preds)
class_names = ['CN', 'MCI', 'AD']

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.show()

print(classification_report(test_labels, test_preds, target_names=class_names))

In [None]:
plt.figure(figsize=(10, 8))

for i, class_name in enumerate(class_names):
    y_true_binary = (np.array(test_labels) == i).astype(int)
    y_score = test_probs[:, i]
    
    fpr, tpr, _ = roc_curve(y_true_binary, y_score)
    roc_auc = roc_auc_score(y_true_binary, y_score)
    
    plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.3f})', linewidth=2)

plt.plot([0, 1], [0, 1], 'k--', linewidth=1)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curves per Class', fontsize=14)
plt.legend(loc='lower right', fontsize=11)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

### Clinical Tradeoffs

The confusion matrix reveals important clinical considerations:

1. **CN vs MCI distinction**: Early-stage MCI may be misclassified as CN because the structural changes are subtle. Such errors are less damaging when longitudinal follow-up is in place, since an MCI diagnosis often depends on it, but they should still be tracked as false negatives.

2. **MCI vs AD boundary**: The model may struggle with the MCI-AD transition, which reflects the continuum nature of disease progression. False positives in AD classification require careful clinical correlation.

3. **Sensitivity vs Specificity**: High sensitivity for AD detection is critical for early intervention, but must be balanced against false positives that could cause unnecessary patient anxiety.

4. **Class imbalance impact**: The model's performance on minority classes (typically MCI) may be lower, necessitating careful interpretation in clinical settings.
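
The sensitivity/specificity tradeoff can be read directly from the confusion matrix. The following sketch computes one-vs-rest sensitivity and specificity per class from a matrix in scikit-learn's layout (rows are true labels, columns are predictions). The counts in `toy_cm` are hypothetical and are not results from this model.

```python
def per_class_rates(cm):
    """One-vs-rest sensitivity and specificity for each class.

    cm[i][j] is the number of samples with true class i predicted as class j.
    Assumes every class has at least one positive and one negative sample.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    rates = {}
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp
        fp = sum(cm[i][k] for i in range(n)) - tp
        tn = total - tp - fn - fp
        rates[k] = {
            "sensitivity": tp / (tp + fn),
            "specificity": tn / (tn + fp),
        }
    return rates


# Hypothetical counts for CN, MCI, AD (illustration only).
toy_cm = [
    [80, 15, 5],
    [20, 60, 20],
    [5, 10, 85],
]
rates = per_class_rates(toy_cm)
for name, k in zip(["CN", "MCI", "AD"], range(3)):
    r = rates[k]
    print(f"{name}: sensitivity={r['sensitivity']:.3f}, specificity={r['specificity']:.3f}")
```

In the notebook, the same function could be applied to `cm.tolist()` from the confusion-matrix cell.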

## Explainability

In [None]:
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]
        
        def forward_hook(module, input, output):
            self.activations = output
        
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)
    
    def generate(self, input_tensor, class_idx=None):
        self.model.eval()
        input_tensor = input_tensor.unsqueeze(0).to(device) if len(input_tensor.shape) == 3 else input_tensor.to(device)
        
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()

        self.model.zero_grad()
        output[0, class_idx].backward()

        # Detach so the CAM can be converted to NumPy below
        gradients = self.gradients[0].detach()
        activations = self.activations[0].detach()
        
        weights = torch.mean(gradients, dim=(1, 2), keepdim=True)
        cam = torch.sum(weights * activations, dim=0)
        cam = F.relu(cam)
        cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)
        cam = cam.squeeze().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        
        return cam

In [None]:
target_layer = model.backbone[-1]
gradcam = GradCAM(model, target_layer)

# One row per example: original, heatmap, overlay
n_examples = 3
fig, axes = plt.subplots(n_examples, 3, figsize=(15, 5 * n_examples))

# Spread the examples across the test set, which is ordered by class
example_indices = np.linspace(0, len(test_dataset) - 1, n_examples, dtype=int)

model.eval()
for idx, sample_idx in enumerate(example_indices):
    # test_dataset already applies val_transform, so load the raw image separately for display
    image = Image.open(test_dataset.samples[sample_idx]).convert('RGB')
    image_tensor, label = test_dataset[sample_idx]
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0))
    pred_class = output.argmax(dim=1).item()
    prob = F.softmax(output, dim=1)[0, pred_class].item()

    cam = gradcam.generate(image_tensor, pred_class)

    img_display = np.array(image)
    
    axes[idx, 0].imshow(img_display)
    axes[idx, 0].set_title(f'Original\nTrue: {class_names[label]}')
    axes[idx, 0].axis('off')
    
    axes[idx, 1].imshow(cam, cmap='jet')
    axes[idx, 1].set_title('Grad-CAM Heatmap')
    axes[idx, 1].axis('off')
    
    overlay = img_display.copy() / 255.0 if img_display.max() > 1 else img_display.copy()
    cam_resized = np.stack([cam] * 3, axis=-1)
    overlay = 0.6 * overlay + 0.4 * cam_resized
    
    axes[idx, 2].imshow(overlay)
    axes[idx, 2].set_title(f'Overlay\nPred: {class_names[pred_class]} ({prob:.2f})')
    axes[idx, 2].axis('off')

plt.tight_layout()
plt.show()

### Clinical Relevance of Saliency Maps

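Grad-CAM visualizations highlight the regions that most influence the model's classification. The maps are computed on the backbone's final feature map (7×7 for a 224×224 input) and upsampled, so localization is coarse. If the model has learned disease-relevant features, the following patterns would be expected: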

1. **Hippocampal regions**: Expected high activation in medial temporal lobe structures, consistent with known AD pathology.

2. **Ventricular enlargement**: Attention to lateral ventricles aligns with disease progression markers.

3. **Cortical regions**: Activation in parietal and frontal cortices may reflect cortical thinning patterns.

4. **Validation**: Overlap with known neuroanatomical AD markers provides face validity, though clinical correlation remains essential.

These maps assist radiologists by directing attention to regions of interest, but should not replace comprehensive image review.
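
One way to make "overlap with known markers" measurable is to compute the share of heatmap activation that falls inside a region-of-interest mask. The sketch below assumes a binary ROI mask (for example, a hand-drawn medial temporal region) is available. No such masks exist in this notebook, and both `toy_cam` and `toy_roi` are hypothetical.

```python
def cam_fraction_in_roi(cam, roi_mask):
    """Return the fraction of total CAM activation that lies inside the ROI.

    cam: 2D list of non-negative floats; roi_mask: 2D list of 0/1 with the same shape.
    """
    total = sum(sum(row) for row in cam)
    if total == 0:
        return 0.0
    inside = sum(
        value
        for cam_row, mask_row in zip(cam, roi_mask)
        for value, in_roi in zip(cam_row, mask_row)
        if in_roi
    )
    return inside / total


# Hypothetical 3x3 heatmap and ROI mask (illustration only).
toy_cam = [
    [0.0, 0.1, 0.0],
    [0.2, 0.9, 0.3],
    [0.0, 0.5, 0.0],
]
toy_roi = [
    [0, 0, 0],
    [0, 1, 1],
    [0, 1, 0],
]
fraction = cam_fraction_in_roi(toy_cam, toy_roi)
print(f"Share of activation inside ROI: {fraction:.2f}")
```

A high share does not prove the model reasons anatomically; it only indicates that the heatmap and the ROI coincide for that image.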

## Ethics & Limitations

### Dataset Bias

The model is trained on a specific dataset that may not represent global population diversity. Potential biases include:

- **Demographic bias**: Age, sex, ethnicity, and socioeconomic factors may not be representative
- **Geographic bias**: Data from specific regions or healthcare systems
- **Scanner bias**: MRI acquisition protocols, field strengths, and manufacturers vary
- **Selection bias**: Inclusion/exclusion criteria may favor certain patient populations

### Scanner Variability

MRI scans vary significantly across:

- Field strength (1.5T vs 3T)
- Acquisition protocols (slice thickness, resolution, contrast)
- Manufacturer-specific image characteristics
- Preprocessing pipelines

The model's performance may degrade when applied to scans from different scanners or protocols not represented in training data.

### Generalization Limits

1. **Temporal stability**: Model performance may change as imaging technology evolves
2. **Comorbidity**: Performance may degrade with concurrent neurological conditions
3. **Early-stage detection**: Limited sensitivity for very early MCI or pre-symptomatic AD
4. **Population shifts**: Performance on populations with different disease prevalence
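
The effect of prevalence in item 4 follows from Bayes' rule: at fixed sensitivity and specificity, the positive predictive value (PPV) falls as the condition becomes rarer. The sketch below works through that arithmetic. The sensitivity and specificity of 0.90 are hypothetical values, not measurements of this model.

```python
def positive_predictive_value(sensitivity, specificity, prevalence):
    """P(disease | positive prediction) via Bayes' rule."""
    true_pos = sensitivity * prevalence
    false_pos = (1.0 - specificity) * (1.0 - prevalence)
    return true_pos / (true_pos + false_pos)


# Hypothetical operating point (illustration only).
SENSITIVITY = 0.90
SPECIFICITY = 0.90

for prevalence in (0.5, 0.1, 0.01):
    ppv = positive_predictive_value(SENSITIVITY, SPECIFICITY, prevalence)
    print(f"prevalence={prevalence:.2f}  PPV={ppv:.3f}")
```

A classifier evaluated on a roughly balanced test set can therefore produce mostly false positives when applied to a screening population with low prevalence.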

### Non-Diagnostic Nature

This model is a **decision-support tool**, not a diagnostic system. Key limitations:

- AD diagnosis requires comprehensive clinical evaluation, not imaging alone
- Model outputs are probabilities, not definitive diagnoses
- False positives/negatives can have significant clinical consequences
- Model cannot account for patient history, symptoms, or other diagnostic information

### Need for Clinician Oversight

1. **Mandatory review**: All model predictions must be reviewed by qualified radiologists or neurologists
2. **Clinical context**: Model outputs must be integrated with patient history, symptoms, and other diagnostic tests
3. **Quality assurance**: Regular monitoring of model performance in clinical deployment
4. **Continuous validation**: Ongoing evaluation against ground truth and clinical outcomes
5. **Regulatory compliance**: Deployment must comply with medical device regulations (FDA, CE marking, etc.)

### Ethical Considerations

- **Patient autonomy**: Patients should be informed about AI-assisted analysis
- **Privacy**: MRI data contains sensitive health information requiring strict privacy protection
- **Equity**: Ensure model benefits are accessible across diverse populations
- **Transparency**: Clinicians and patients should understand model limitations
- **Accountability**: Clear responsibility for clinical decisions remains with healthcare providers

## Future Work

### Model Improvements

1. **Multi-modal fusion**: Integrate structural MRI with functional MRI (fMRI), PET scans, or cerebrospinal fluid biomarkers
2. **Longitudinal modeling**: Incorporate temporal information from follow-up scans to track disease progression
3. **3D architectures**: Full volumetric analysis using 3D CNNs or vision transformers
4. **Self-supervised pretraining**: Domain-specific pretraining on large unlabeled medical imaging datasets
5. **Transformer architectures**: Vision transformers (ViT) or medical-specific transformer designs

### Clinical Integration

1. **Prospective validation**: Large-scale multi-center prospective studies
2. **Real-world deployment**: Integration into clinical PACS systems and radiology workflows
3. **Outcome prediction**: Extend beyond classification to predict disease progression rates
4. **Treatment response**: Model response to therapeutic interventions

### Technical Enhancements

1. **Uncertainty quantification**: Bayesian neural networks or ensemble methods for reliable confidence estimates
2. **Federated learning**: Train across institutions while preserving data privacy
3. **Domain adaptation**: Techniques to improve generalization across scanner types and protocols
4. **Interpretability**: Enhanced explainability methods tailored for medical imaging
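
As a minimal sketch of the ensemble route to uncertainty in item 1, the code below averages softmax outputs from several models and reports the entropy of the averaged distribution. The member probabilities are hypothetical; in practice they would come from independently trained copies of the classifier, which this notebook does not train.

```python
import math


def mean_probs(member_probs):
    """Average class probabilities across ensemble members."""
    n_members = len(member_probs)
    n_classes = len(member_probs[0])
    return [sum(p[k] for p in member_probs) / n_members for k in range(n_classes)]


def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Hypothetical softmax outputs (CN, MCI, AD) from three ensemble members.
agree = [[0.90, 0.05, 0.05], [0.85, 0.10, 0.05], [0.92, 0.04, 0.04]]
disagree = [[0.80, 0.10, 0.10], [0.10, 0.80, 0.10], [0.20, 0.20, 0.60]]

for name, members in (("agree", agree), ("disagree", disagree)):
    avg = mean_probs(members)
    print(f"{name}: mean={[round(p, 3) for p in avg]}, entropy={entropy(avg):.3f}")
```

High entropy flags scans on which the members disagree, which could be routed for closer clinician review.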

### Dataset Expansion

1. **Diverse populations**: Inclusion of underrepresented demographic groups
2. **Multi-center data**: Aggregation of data from diverse healthcare systems
3. **Rare variants**: Training on atypical AD presentations and variants
4. **Comorbidity cases**: Inclusion of patients with multiple neurological conditions

### Regulatory & Validation

1. **Regulatory approval**: Pursue FDA 510(k) or De Novo pathways for clinical use
2. **Clinical trials**: Randomized controlled trials assessing impact on diagnostic accuracy and patient outcomes
3. **Cost-effectiveness**: Health economic analyses of AI-assisted diagnosis
4. **Standardization**: Development of standardized evaluation protocols and benchmarks