# Alzheimer's Disease Prediction Model

A deep learning model for classifying Alzheimer's disease severity from MRI images.

**Classes:**
- NonDemented
- VeryMildDemented
- MildDemented
- ModerateDemented

**Models:** a custom six-layer CNN and a ResNet18 that starts from ImageNet-pretrained weights

**Features:**
- GPU acceleration when CUDA is available
- Data augmentation
- Hyperparameter tuning via grid search
- Evaluation with accuracy, per-class precision/recall/F1, and a confusion matrix


## 1. Setup and Installation

First, let's install the required packages and set up the environment.

In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install scikit-learn matplotlib seaborn
!pip install opencv-python-headless
!pip install nibabel

# Import libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
import random
import warnings
from pathlib import Path
import json
from datetime import datetime

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models

# Sklearn imports
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from sklearn.model_selection import ParameterGrid

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Suppress warnings
warnings.filterwarnings('ignore')

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')

## 2. Data Upload

Upload your Alzheimer's dataset to Google Colab. The dataset should have the following structure:
```
Alzheimer_s Dataset/
├── train/
│   ├── MildDemented/
│   ├── ModerateDemented/
│   ├── NonDemented/
│   └── VeryMildDemented/
└── test/
    ├── MildDemented/
    ├── ModerateDemented/
    ├── NonDemented/
    └── VeryMildDemented/
```
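
The loading code later in this notebook skips any split or class folder that does not exist, so a misspelled folder name would go unnoticed. A minimal sketch of a layout check is shown below; `find_missing_folders` is a hypothetical helper that is not part of the notebook, and the demonstration runs on a temporary directory rather than on real data.

```python
import os
import tempfile

EXPECTED_SPLITS = ['train', 'test']
EXPECTED_CLASSES = ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']

def find_missing_folders(data_dir):
    """Return the expected split/class folders that are absent under data_dir."""
    missing = []
    for split in EXPECTED_SPLITS:
        for class_name in EXPECTED_CLASSES:
            folder = os.path.join(data_dir, split, class_name)
            if not os.path.isdir(folder):
                missing.append(os.path.join(split, class_name))
    return missing

# Demonstration on a throwaway directory that lacks one class folder per split.
with tempfile.TemporaryDirectory() as root:
    for split in EXPECTED_SPLITS:
        for class_name in EXPECTED_CLASSES[:3]:
            os.makedirs(os.path.join(root, split, class_name))
    missing_folders = find_missing_folders(root)

print(missing_folders)
```

Calling the same function on the real dataset path should return an empty list when the layout matches the tree above.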

In [None]:
# Option 1: Upload via Google Colab files
from google.colab import files
import zipfile

# Uncomment the following lines to upload your dataset as a zip file
# uploaded = files.upload()
# for filename in uploaded.keys():
#     with zipfile.ZipFile(filename, 'r') as zip_ref:
#         zip_ref.extractall('/content')  # matches the Option 1 DATA_DIR below

# Option 2: Mount Google Drive (recommended)
from google.colab import drive
drive.mount('/content/drive')

# Set the path to your dataset (adjust as needed)
# If you uploaded via option 1:
# DATA_DIR = '/content/Alzheimer_s Dataset'

# If you're using Google Drive:
DATA_DIR = '/content/drive/MyDrive/Alzheimer_s Dataset'

print(f'Data directory: {DATA_DIR}')
print(f'Directory exists: {os.path.exists(DATA_DIR)}')

## 3. Data Preprocessing

Custom dataset class and data preprocessing functions.

In [None]:
class AlzheimerDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None, class_names=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.class_names = class_names or ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            
            if self.transform:
                image = self.transform(image)
            
            return image, label
        except Exception as e:
            print(f'Error loading image {image_path}: {e}')
            # Return a black image in case of error
            if self.transform:
                image = self.transform(Image.new('RGB', (224, 224), (0, 0, 0)))
            else:
                image = torch.zeros(3, 224, 224)
            return image, label

def get_transforms(augment=True):
    """Get image transforms for training and validation"""
    if augment:
        # Training transforms with augmentation
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    # Validation/test transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    return train_transform, val_transform
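
`transforms.Normalize` applies `(x - mean) / std` to each channel of the `[0, 1]` tensor produced by `ToTensor`. The sketch below reproduces that arithmetic for a single pixel in plain Python, along with the inverse used when images are displayed. The helper names are hypothetical and exist only for illustration.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (x - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]

def denormalize_pixel(rgb):
    """Invert the normalization: x * std + mean."""
    return [v * s + m for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]

black = [0.0, 0.0, 0.0]
normalized_black = normalize_pixel(black)
restored = denormalize_pixel(normalized_black)

print([round(v, 3) for v in normalized_black])
print([round(v, 3) for v in restored])
```

A black background pixel therefore becomes a value near -2 in every channel, not 0, once the ImageNet statistics are applied.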

In [None]:
def load_data_from_folders(data_dir, class_names=None):
    """Load image paths and labels from folder structure"""
    if class_names is None:
        class_names = ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']
    
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}
    
    all_image_paths = []
    all_labels = []
    
    for split in ['train', 'test']:
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            continue
            
        for class_name in class_names:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                continue
                
            # Get all image files
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
            image_files = []
            for ext in image_extensions:
                image_files.extend([f for f in os.listdir(class_dir) if f.lower().endswith(ext)])
            
            for image_file in image_files:
                image_path = os.path.join(class_dir, image_file)
                all_image_paths.append(image_path)
                all_labels.append(class_to_idx[class_name])
    
    return all_image_paths, all_labels, class_names, class_to_idx

def split_train_val_test(data_dir, val_split=0.2, batch_size=32, augment=True, num_workers=2):
    """Create train, validation, and test dataloaders"""
    class_names = ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}
    
    # Load training data
    train_dir = os.path.join(data_dir, 'train')
    train_paths, train_labels = [], []
    
    for class_name in class_names:
        class_dir = os.path.join(train_dir, class_name)
        if os.path.exists(class_dir):
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
            image_files = []
            for ext in image_extensions:
                image_files.extend([f for f in os.listdir(class_dir) if f.lower().endswith(ext)])
            
            for image_file in image_files:
                image_path = os.path.join(class_dir, image_file)
                train_paths.append(image_path)
                train_labels.append(class_to_idx[class_name])
    
    # Load test data
    test_dir = os.path.join(data_dir, 'test')
    test_paths, test_labels = [], []
    
    for class_name in class_names:
        class_dir = os.path.join(test_dir, class_name)
        if os.path.exists(class_dir):
            image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']
            image_files = []
            for ext in image_extensions:
                image_files.extend([f for f in os.listdir(class_dir) if f.lower().endswith(ext)])
            
            for image_file in image_files:
                image_path = os.path.join(class_dir, image_file)
                test_paths.append(image_path)
                test_labels.append(class_to_idx[class_name])
    
    # Split training data into train and validation
    train_size = int((1 - val_split) * len(train_paths))
    val_size = len(train_paths) - train_size
    
    # Create indices for splitting
    indices = list(range(len(train_paths)))
    random.shuffle(indices)
    
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]
    
    # Split the data
    train_paths_split = [train_paths[i] for i in train_indices]
    train_labels_split = [train_labels[i] for i in train_indices]
    val_paths = [train_paths[i] for i in val_indices]
    val_labels = [train_labels[i] for i in val_indices]
    
    # Get transforms
    train_transform, val_transform = get_transforms(augment)
    
    # Create datasets
    train_dataset = AlzheimerDataset(train_paths_split, train_labels_split, train_transform, class_names)
    val_dataset = AlzheimerDataset(val_paths, val_labels, val_transform, class_names)
    test_dataset = AlzheimerDataset(test_paths, test_labels, val_transform, class_names)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    
    # Print statistics
    print(f'Dataset Statistics:')
    print(f'Training samples: {len(train_dataset)}')
    print(f'Validation samples: {len(val_dataset)}')
    print(f'Test samples: {len(test_dataset)}')
    print(f'Number of classes: {len(class_names)}')
    print(f'Class names: {class_names}')
    
    return train_loader, val_loader, test_loader, class_names, class_to_idx
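
The split above shuffles all training indices together, so a rare class can end up under-represented in the validation set. If that matters for your data, a stratified split is one alternative. The sketch below is not the notebook's method: `stratified_split_indices` is a hypothetical helper and the label counts are made up.

```python
import random
from collections import defaultdict

def stratified_split_indices(labels, val_split=0.2, seed=42):
    """Split indices so each class sends roughly val_split of its samples to validation."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_indices, val_indices = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Keep at least one validation sample per class when the class has more than one item.
        n_val = max(1, round(len(indices) * val_split)) if len(indices) > 1 else 0
        val_indices.extend(indices[:n_val])
        train_indices.extend(indices[n_val:])
    rng.shuffle(train_indices)
    rng.shuffle(val_indices)
    return train_indices, val_indices

# Toy labels with a deliberately rare class 3 (made-up counts).
toy_labels = [0] * 50 + [1] * 30 + [2] * 15 + [3] * 5
train_idx, val_idx = stratified_split_indices(toy_labels, val_split=0.2)
val_class_counts = {c: sum(1 for i in val_idx if toy_labels[i] == c) for c in range(4)}
print(val_class_counts)
```

The returned index lists could replace `train_indices` and `val_indices` in `split_train_val_test` without changing the rest of the function.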

## 4. Model Architecture

Define CNN and ResNet models for Alzheimer's classification.

In [None]:
class CNNModel(nn.Module):
    def __init__(self, num_classes=4, dropout_rate=0.5):
        super(CNNModel, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(128)
        
        # Pooling and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout_rate)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Fully connected layers
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        # First conv block
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = self.dropout(x)
        
        # Second conv block
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = self.dropout(x)
        
        # Third conv block
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.pool(x)
        x = self.dropout(x)
        
        # Global average pooling
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x
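
To see why `fc1` takes 128 input features, it helps to trace the feature-map size through the three blocks. The sketch below does this with plain integer arithmetic, assuming the 224x224 input produced by the transforms; `conv_out` and `pool_out` are hypothetical helper names.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output side length of a square convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output side length of a max-pool without padding."""
    return (size - kernel) // stride + 1

side = 224
trace = []
for channels in (32, 64, 128):
    side = conv_out(conv_out(side))  # two 3x3 convs with padding=1 keep the size
    side = pool_out(side)            # the 2x2 max-pool halves it
    trace.append((channels, side, side))

# Global average pooling collapses H x W to 1 x 1, leaving one value per channel.
flattened_features = trace[-1][0]

print(trace)               # [(32, 112, 112), (64, 56, 56), (128, 28, 28)]
print(flattened_features)  # 128
```

Because of the adaptive average pooling, the classifier head does not depend on the input resolution, only on the channel count of the last block.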

In [None]:
class ResNetModel(nn.Module):
    def __init__(self, num_classes=4, dropout_rate=0.5, pretrained=True):
        super(ResNetModel, self).__init__()
        
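        # Load ResNet18, optionally with ImageNet weights (the `weights` argument
        # replaces the deprecated `pretrained` flag in torchvision >= 0.13)
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet18(weights=weights)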
        
        # Replace the final layer
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, x):
        return self.resnet(x)

def get_model(model_type='cnn', num_classes=4, dropout_rate=0.5):
    """Factory function to get model"""
    if model_type.lower() == 'cnn':
        return CNNModel(num_classes=num_classes, dropout_rate=dropout_rate)
    elif model_type.lower() == 'resnet':
        return ResNetModel(num_classes=num_classes, dropout_rate=dropout_rate)
    else:
        raise ValueError(f'Unknown model type: {model_type}. Choose from ["cnn", "resnet"]')

## 5. Training Functions

Functions for training and evaluating the model.

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train the model for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

def evaluate_model(model, data_loader, criterion, device):
    """Evaluate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            
            running_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    epoch_loss = running_loss / len(data_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc, all_predictions, all_targets

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, 
                num_epochs, device, save_path=None):
    """Train the model for multiple epochs"""
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    best_val_acc = 0.0
    best_model_state = None
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)
        
        # Training phase
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # Validation phase
        val_loss, val_acc, _, _ = evaluate_model(model, val_loader, criterion, device)
        
        # Update learning rate
        if scheduler:
            scheduler.step(val_loss)
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Clone the tensors: state_dict().copy() is shallow, so later epochs would overwrite the snapshot
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if save_path:
                torch.save(model.state_dict(), save_path)
                print(f'New best model saved with validation accuracy: {val_acc:.2f}%')
    
    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    return history
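
`train_model` always runs for the full `num_epochs`. If you would rather stop once validation loss plateaus, an early-stopping counter is one common option. The sketch below is an optional addition, not something the notebook implements: `EarlyStopping` is a hypothetical class and the loss values are invented.

```python
class EarlyStopping:
    """Signal a stop once val_loss has not improved for `patience` epochs in a row."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses for illustration only.
fake_val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.64]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(fake_val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at)  # 6: three epochs in a row without beating 0.65
```

Inside the epoch loop, the call would sit right after the validation phase, with a `break` when `step` returns `True`.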

## 6. Visualization Functions

Functions for plotting training history and evaluation metrics.

In [None]:
def plot_training_history(history, save_path=None):
    """Plot training and validation loss and accuracy"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot loss
    ax1.plot(history['train_loss'], label='Training Loss', marker='o')
    ax1.plot(history['val_loss'], label='Validation Loss', marker='s')
    ax1.set_title('Model Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Plot accuracy
    ax2.plot(history['train_acc'], label='Training Accuracy', marker='o')
    ax2.plot(history['val_acc'], label='Validation Accuracy', marker='s')
    ax2.set_title('Model Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()

def plot_confusion_matrix(y_true, y_pred, class_names, save_path=None):
    """Plot confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()

def visualize_batch(data_loader, class_names, num_samples=8):
    """Visualize a batch of images with their labels"""
    # Get a batch of data
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    
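    # Select samples to display (the grid below has 2 x 4 = 8 slots)
    num_samples = min(num_samples, len(images), 8)
    
    # Create subplot and hide any slots that will stay empty
    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    axes = axes.ravel()
    for ax in axes[num_samples:]:
        ax.axis('off')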
    
    for i in range(num_samples):
        # Denormalize image
        img = images[i]
        img = img * torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        img = torch.clamp(img, 0, 1)
        
        # Convert to numpy and transpose
        img_np = img.permute(1, 2, 0).numpy()
        
        # Plot
        axes[i].imshow(img_np)
        axes[i].set_title(f'Class: {class_names[labels[i]]}')
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()
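
The confusion matrix above shows raw counts, which can make small classes hard to read. Dividing each row by its sum turns the diagonal into per-class recall. The sketch below shows the arithmetic in plain Python with made-up counts; `row_normalize` is a hypothetical helper.

```python
def row_normalize(cm):
    """Divide each row by its sum so entry [i][j] is the share of true class i predicted as j."""
    normalized = []
    for row in cm:
        total = sum(row)
        normalized.append([v / total if total else 0.0 for v in row])
    return normalized

# Made-up 3-class counts (rows: true class, columns: predicted class).
toy_cm = [
    [90, 8, 2],
    [10, 35, 5],
    [1, 1, 8],
]
toy_cm_normalized = row_normalize(toy_cm)
per_class_recall = [toy_cm_normalized[i][i] for i in range(len(toy_cm))]

print([round(r, 2) for r in per_class_recall])  # [0.9, 0.7, 0.8]
```

scikit-learn's `confusion_matrix` also accepts `normalize='true'` for the same result; the heatmap format would then need to be `'.2f'` instead of `'d'`.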

## 7. Hyperparameter Tuning

Grid search for optimal hyperparameters.

In [None]:
def hyperparameter_tuning(train_loader, val_loader, num_classes, device, num_epochs=10):
    """Perform hyperparameter tuning using grid search"""
    
    # Define parameter grid
    param_grid = {
        'model_type': ['cnn', 'resnet'],
        'optimizer': ['adam', 'sgd'],
        'learning_rate': [0.001, 0.0001],
        'use_scheduler': [True, False]
    }
    
    best_params = None
    best_val_acc = 0.0
    results = []
    
    # Grid search
    for params in ParameterGrid(param_grid):
        print(f'\nTesting parameters: {params}')
        print('-' * 50)
        
        # Create model
        model = get_model(model_type=params['model_type'], num_classes=num_classes)
        model = model.to(device)
        
        # Define loss function
        criterion = nn.CrossEntropyLoss()
        
        # Define optimizer
        if params['optimizer'] == 'adam':
            optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])
        else:
            optimizer = optim.SGD(model.parameters(), lr=params['learning_rate'], momentum=0.9)
        
        # Define scheduler
        scheduler = None
        if params['use_scheduler']:
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
        
        # Train model
        history = train_model(model, train_loader, val_loader, criterion, optimizer, 
                              scheduler, num_epochs, device)
        
        # Get best validation accuracy
        val_acc = max(history['val_acc'])
        
        # Store results
        result = params.copy()
        result['val_acc'] = val_acc
        results.append(result)
        
        print(f'Validation Accuracy: {val_acc:.2f}%')
        
        # Update best parameters
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_params = params.copy()
    
    print(f'\nBest parameters: {best_params}')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
    
    # Convert results to DataFrame for easier analysis
    results_df = pd.DataFrame(results)
    print(f'\nAll results:')
    print(results_df.sort_values('val_acc', ascending=False))
    
    return best_params, best_val_acc, results_df
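
Before running the search, it is worth estimating its cost. The sketch below enumerates the same grid with `itertools.product` and counts the total number of training epochs. The iteration order may differ from `ParameterGrid`, but the set of combinations is the same.

```python
from itertools import product

param_grid = {
    'model_type': ['cnn', 'resnet'],
    'optimizer': ['adam', 'sgd'],
    'learning_rate': [0.001, 0.0001],
    'use_scheduler': [True, False],
}

keys = list(param_grid)
combinations = [dict(zip(keys, values)) for values in product(*(param_grid[k] for k in keys))]

num_epochs = 10  # default of hyperparameter_tuning
total_training_epochs = len(combinations) * num_epochs

print(len(combinations), total_training_epochs)  # 16 160
```

Trimming one option from any parameter halves the number of runs, which is usually the quickest way to shorten the search.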

## 8. Main Execution

Load data, train model, and evaluate performance.

In [None]:
# Configuration
BATCH_SIZE = 32
NUM_EPOCHS = 20
LEARNING_RATE = 0.001
MODEL_TYPE = 'resnet'  # 'cnn' or 'resnet'
USE_AUGMENTATION = True
DROPOUT_RATE = 0.5
VAL_SPLIT = 0.2
NUM_WORKERS = 2

# Create output directory
output_dir = '/content/results'
os.makedirs(output_dir, exist_ok=True)

print(f'Configuration:')
print(f'Data directory: {DATA_DIR}')
print(f'Batch size: {BATCH_SIZE}')
print(f'Number of epochs: {NUM_EPOCHS}')
print(f'Learning rate: {LEARNING_RATE}')
print(f'Model type: {MODEL_TYPE}')
print(f'Use augmentation: {USE_AUGMENTATION}')
print(f'Device: {device}')


In [None]:
# Load and prepare data
print('Loading data...')
train_loader, val_loader, test_loader, class_names, class_to_idx = split_train_val_test(
    DATA_DIR, 
    val_split=VAL_SPLIT,
    batch_size=BATCH_SIZE,
    augment=USE_AUGMENTATION,
    num_workers=NUM_WORKERS
)

print(f'\nClass mapping: {class_to_idx}')
print(f'Number of classes: {len(class_names)}')

In [None]:
# Visualize some training data
print('Visualizing training data samples...')
visualize_batch(train_loader, class_names, num_samples=8)

In [None]:
# Create model
print(f'Creating {MODEL_TYPE} model...')
model = get_model(model_type=MODEL_TYPE, num_classes=len(class_names), dropout_rate=DROPOUT_RATE)
model = model.to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

print(f'Model created successfully!')
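
The loss above treats every class equally. If your class folders are imbalanced, inverse-frequency class weights are one way to compensate. The sketch below is an optional variation, not part of the notebook: `inverse_frequency_weights` is a hypothetical helper and the label counts are made up.

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """Weight each class by n_samples / (num_classes * class_count)."""
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]

# Toy labels with a rare class 3 (made-up counts).
toy_labels = [0] * 50 + [1] * 30 + [2] * 15 + [3] * 5
class_weights = inverse_frequency_weights(toy_labels, num_classes=4)

print([round(w, 2) for w in class_weights])  # [0.5, 0.83, 1.67, 5.0]
```

The resulting list could be passed to the loss as `nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32).to(device))`, computed from the training labels only.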

## 9. Optional: Hyperparameter Tuning

Uncomment the following cell to perform hyperparameter tuning. The grid contains 16 configurations, each trained for 10 epochs, so this takes much longer than a single training run.

In [None]:
# Uncomment to run hyperparameter tuning
# print('Starting hyperparameter tuning...')
# best_params, best_val_acc, results_df = hyperparameter_tuning(
#     train_loader, val_loader, len(class_names), device, num_epochs=10
# )
# 
# # Update model with best parameters
# MODEL_TYPE = best_params['model_type']
# LEARNING_RATE = best_params['learning_rate']
# 
# # Recreate model with best parameters
# model = get_model(model_type=MODEL_TYPE, num_classes=len(class_names), dropout_rate=DROPOUT_RATE)
# model = model.to(device)
# 
# if best_params['optimizer'] == 'adam':
#     optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# else:
#     optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
# 
# if best_params['use_scheduler']:
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
# else:
#     scheduler = None
# 
# print(f'Using best parameters: {best_params}')

In [None]:
# Train the model
print('Starting training...')
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    device=device,
    save_path=os.path.join(output_dir, 'best_model.pth')
)

print('Training completed!')

In [None]:
# Plot training history
print('Plotting training history...')
plot_training_history(history, save_path=os.path.join(output_dir, 'training_history.png'))

## 10. Model Evaluation

Evaluate the trained model on the test set.

In [None]:
# Evaluate on test set
print('Evaluating on test set...')
test_loss, test_acc, test_predictions, test_targets = evaluate_model(
    model, test_loader, criterion, device
)

print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_acc:.2f}%')

# Calculate detailed metrics
precision, recall, f1_score, support = precision_recall_fscore_support(
    test_targets, test_predictions, average=None
)

# Print per-class metrics
print('\nPer-class metrics:')
for i, class_name in enumerate(class_names):
    print(f'{class_name}:')
    print(f'  Precision: {precision[i]:.4f}')
    print(f'  Recall: {recall[i]:.4f}')
    print(f'  F1-score: {f1_score[i]:.4f}')
    print(f'  Support: {support[i]}')
    print()

# Overall metrics (unweighted macro averages: every class counts equally, regardless of support)
overall_precision = np.mean(precision)
overall_recall = np.mean(recall)
overall_f1 = np.mean(f1_score)

print(f'Overall Metrics:')
print(f'Precision: {overall_precision:.4f}')
print(f'Recall: {overall_recall:.4f}')
print(f'F1-score: {overall_f1:.4f}')

# Classification report
print('\nClassification Report:')
print(classification_report(test_targets, test_predictions, target_names=class_names))
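
The "overall" values computed with `np.mean` are macro averages and should match the `macro avg` row of the classification report. The `weighted avg` row instead weights each class by its support. The sketch below contrasts the two on invented per-class F1 scores; `macro_and_weighted` is a hypothetical helper.

```python
def macro_and_weighted(per_class_scores, support):
    """Return (unweighted mean, support-weighted mean) of per-class scores."""
    macro = sum(per_class_scores) / len(per_class_scores)
    weighted = sum(s * n for s, n in zip(per_class_scores, support)) / sum(support)
    return macro, weighted

# Made-up per-class F1 scores and class sizes for illustration.
toy_f1 = [0.95, 0.80, 0.60, 0.20]
toy_support = [500, 300, 150, 50]
macro_f1, weighted_f1 = macro_and_weighted(toy_f1, toy_support)

print(round(macro_f1, 4), round(weighted_f1, 4))  # 0.6375 0.815
```

A large gap between the two numbers indicates that the small classes are performing worse than the large ones.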

In [None]:
# Plot confusion matrix
print('Plotting confusion matrix...')
plot_confusion_matrix(
    test_targets, test_predictions, class_names, 
    save_path=os.path.join(output_dir, 'confusion_matrix.png')
)

## 11. Save Results

Save model, metrics, and configuration.

In [None]:
# Save metrics and configuration
results = {
    'model_type': MODEL_TYPE,
    'num_epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'dropout_rate': DROPOUT_RATE,
    'use_augmentation': USE_AUGMENTATION,
    'test_accuracy': test_acc,
    'test_loss': test_loss,
    'overall_precision': overall_precision,
    'overall_recall': overall_recall,
    'overall_f1': overall_f1,
    'class_names': class_names,
    'class_to_idx': class_to_idx,
    'per_class_metrics': {
        class_names[i]: {
            'precision': precision[i],
            'recall': recall[i],
            'f1_score': f1_score[i],
            'support': int(support[i])
        } for i in range(len(class_names))
    },
    'training_history': history,
    'timestamp': datetime.now().isoformat()
}

# Save results as JSON
with open(os.path.join(output_dir, 'results.json'), 'w') as f:
    json.dump(results, f, indent=2)

print(f'Results saved to {output_dir}')
print(f'Files saved:')
print(f'- best_model.pth: Trained model weights')
print(f'- results.json: Detailed metrics and configuration')
print(f'- training_history.png: Training curves')
print(f'- confusion_matrix.png: Confusion matrix')

# List all files in output directory
print(f'\nOutput directory contents:')
for file in os.listdir(output_dir):
    print(f'- {file}')
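
Once `results.json` exists, it can be read back without rerunning the notebook. The sketch below looks up the epoch with the highest validation accuracy. `best_epoch` is a hypothetical helper, and the demonstration uses a stand-in file with invented numbers that mirrors the `training_history` key written above.

```python
import json
import os
import tempfile

def best_epoch(results_path):
    """Return (1-based epoch, accuracy) of the highest validation accuracy in a results file."""
    with open(results_path) as f:
        saved = json.load(f)
    val_acc = saved['training_history']['val_acc']
    best_index = max(range(len(val_acc)), key=val_acc.__getitem__)
    return best_index + 1, val_acc[best_index]

# Stand-in file with made-up numbers, using the same keys as the saved results.
toy_results = {'training_history': {'val_acc': [61.0, 72.5, 70.1, 74.3, 73.9]}}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'results.json')
    with open(path, 'w') as f:
        json.dump(toy_results, f)
    epoch, acc = best_epoch(path)

print(epoch, acc)  # 4 74.3
```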

## 12. Inference Function

Function to make predictions on new images.

In [None]:
def predict_image(model, image_path, class_names, device, transform=None):
    """Make prediction on a single image"""
    model.eval()
    
    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    image_tensor = transform(image).unsqueeze(0).to(device)
    
    # Make prediction
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = F.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    
    # Get prediction results
    predicted_class = class_names[predicted.item()]
    confidence = probabilities[0][predicted.item()].item()
    
    # Get all class probabilities
    class_probabilities = {
        class_names[i]: probabilities[0][i].item() 
        for i in range(len(class_names))
    }
    
    return predicted_class, confidence, class_probabilities

# Example usage (uncomment to test with a specific image)
# image_path = '/path/to/your/test/image.jpg'
# predicted_class, confidence, class_probs = predict_image(model, image_path, class_names, device)
# print(f'Predicted class: {predicted_class}')
# print(f'Confidence: {confidence:.4f}')
# print(f'All class probabilities: {class_probs}')
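
The confidence returned by `predict_image` is the softmax of the model's raw outputs (logits). The sketch below reproduces that step in plain Python with invented logits, to show how the probabilities and the ranking are derived; the `softmax` function here is a stand-in for `F.softmax`.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

class_names = ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']
toy_logits = [2.0, 1.0, 0.1, -1.0]  # made-up model outputs
probs = softmax(toy_logits)

# Rank classes from most to least likely.
ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
for name, p in ranked:
    print(f'{name}: {p:.3f}')
```

Softmax preserves the ordering of the logits, which is why taking `torch.max` on the raw outputs selects the same class as taking it on the probabilities.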

## 13. Download Results

Download the trained model and results.

In [None]:
# Create a zip file with all results
import zipfile

zip_path = '/content/alzheimer_model_results.zip'

with zipfile.ZipFile(zip_path, 'w') as zipf:
    for root, dirs, files in os.walk(output_dir):
        for file in files:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, output_dir)
            zipf.write(file_path, arcname)

print(f'Results packaged in: {zip_path}')

# Download the zip file
from google.colab import files
files.download(zip_path)

print('Download started! Check your browser downloads.')

## Conclusion

This notebook provides a complete pipeline for Alzheimer's disease classification using deep learning:

1. **Data Loading**: Supports the standard train/test folder structure
2. **Preprocessing**: Image augmentation and normalization
3. **Models**: CNN and ResNet architectures
4. **Training**: GPU-accelerated training with per-epoch validation and best-model checkpointing
5. **Evaluation**: Comprehensive metrics and visualizations
6. **Hyperparameter Tuning**: Grid search for optimal parameters
7. **Inference**: Easy prediction on new images

**Key Features:**
- 4-class classification (NonDemented, VeryMildDemented, MildDemented, ModerateDemented)
- GPU acceleration with Google Colab
- Data augmentation for better generalization
- Comprehensive evaluation metrics
- Model saving and downloading

**To use this notebook:**
1. Upload your dataset to Google Drive or directly to Colab
2. Set the correct DATA_DIR path
3. Run all cells sequentially
4. Download the trained model and results

The trained model can be used for inference on new MRI images to predict Alzheimer's disease severity.