# Paper 2: Utilities and Configuration

## Classification of Alzheimer's Disease using MRI Data Based on Deep Learning Techniques

This notebook contains:
- Configuration parameters for all 5 models
- Helper functions for data processing
- Evaluation metrics functions
- Visualization utilities
- Data augmentation transforms and dataset classes

**Dataset:** OASIS-2 Raw MRI Data  
**Models:** CNNs-without-Aug, CNNs-with-Aug, CNNs-LSTM-with-Aug, CNNs-SVM-with-Aug, VGG16-SVM-with-Aug

## 1. Import Required Libraries

In [2]:
# Standard libraries
import os
import sys
import time
import json
import warnings
warnings.filterwarnings('ignore')

# Data processing
import numpy as np
import pandas as pd
from pathlib import Path

# Image processing
import nibabel as nib  # For reading NIfTI files
from PIL import Image
import cv2

# Deep Learning - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, models as torchvision_models

# Machine Learning
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_curve, auc
)
from sklearn.svm import SVC

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Fix every random number generator to one seed so repeated runs match
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed_all(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda' if cuda_available else 'cpu')

# Report the runtime environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Device: {device}")

PyTorch version: 2.9.1+cu130
CUDA available: True
CUDA version: 13.0
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
Device: cuda


## 2. Configuration Parameters

In [3]:
# Global configuration matching Paper 2 specifications.
#
# Single dictionary holding paths, dataset settings, per-model hyperparameters,
# augmentation settings and plotting defaults.
# Within this notebook only the path entries, 'models' (its length is printed)
# and 'augmentation_params' are read; every other entry is presumably consumed
# by the downstream notebooks - verify there before renaming or removing keys.
CONFIG = {
    # Data paths (all relative to the notebook's working directory)
    'base_path': Path('../'),  # Paper2 directory
    'raw_data_path': Path('../Raw_Data/'),
    'demographics_file': Path('../Raw_Data/OASIS_demographic.xlsx'),
    'processed_data_path': Path('../processed_data/'),
    'saved_models_path': Path('../saved_models/'),
    'results_path': Path('../results/'),
    
    # Dataset parameters (from Paper 2)
    'target_total_images': 6400,
    'train_split': 0.80,  # 5120 images
    'test_split': 0.20,   # 1280 images
    'random_state': 42,   # same value as the seed used in the import cell
    
    # Binary classification
    'num_classes': 2,
    'class_names': ['Non-Demented', 'Demented'],  # index 0 / index 1
    'cdr_threshold': 0.5,  # CDR >= 0.5 = Demented
    
    # Slice extraction from 3D volumes
    'slices_per_volume': 4,  # Extract 4-5 representative slices
    'slice_axis': 'axial',   # Axial slices (most common in AD studies)
    
    # Model-specific parameters (from Paper 2 Table 3)
    # NOTE(review): 'optimizer' and 'loss' are plain strings and nothing in
    # this notebook interprets them - confirm how the model notebooks map them
    # to PyTorch objects.
    'models': {
        'cnn_without_aug': {
            'name': 'CNNs-without-Aug',
            'input_shape': (224, 224, 3),
            'epochs': 100,
            'batch_size': 30,
            'learning_rate': 0.0001,
            'optimizer': 'adam',
            'loss': 'binary_crossentropy',
            'augmentation': False,
            'total_params': 2129250,
            'target_accuracy': 0.9922
        },
        'cnn_with_aug': {
            'name': 'CNNs-with-Aug',
            'input_shape': (128, 128, 3),
            'epochs': 100,
            'batch_size': 65,
            'learning_rate': 0.0001,
            'optimizer': 'adam',
            'loss': 'binary_crossentropy',
            'augmentation': True,
            'total_params': 6454626,
            'target_accuracy': 0.9961
        },
        'cnn_lstm_with_aug': {
            'name': 'CNNs-LSTM-with-Aug',
            'input_shape': (1, 128, 128, 3),  # Time-distributed
            'epochs': 25,
            'batch_size': 16,
            'learning_rate': 0.0001,
            'optimizer': 'adam',
            'loss': 'binary_crossentropy',
            'augmentation': True,
            'lstm_units': 100,
            'total_params': 11580858,
            'target_accuracy': 0.9992  # BEST MODEL
        },
        'cnn_svm_with_aug': {
            'name': 'CNNs-SVM-with-Aug',
            'input_shape': (224, 224, 3),
            'epochs': 20,
            'batch_size': 32,
            'learning_rate': 0.0001,
            'optimizer': 'adam',
            'loss': 'squared_hinge',  # For SVM compatibility
            'augmentation': True,
            'total_params': 206882,
            'target_accuracy': 0.9914
        },
        # No training hyperparameters (epochs, batch size, ...) are listed for
        # this entry, unlike the four CNN entries above.
        'vgg16_svm_with_aug': {
            'name': 'VGG16-SVM-with-Aug',
            'input_shape': (224, 224, 3),
            'svm_kernel': 'linear',
            'augmentation': True,
            'total_features': 14714688,
            'target_accuracy': 0.9867
        }
    },
    
    # Data augmentation parameters (from Paper 2)
    # Read by create_data_augmentation() in this notebook, except 'fill_mode',
    # which that function does not use.
    'augmentation_params': {
        'rotation_range': 90,        # 0-90 degrees
        'horizontal_flip': True,
        'vertical_flip': True,
        'zoom_range': 0.2,           # Random magnification
        'width_shift_range': 0.1,    # Random shifting
        'height_shift_range': 0.1,
        'fill_mode': 'nearest'
    },
    
    # Visualization parameters
    # NOTE(review): recent matplotlib releases renamed the bundled seaborn
    # styles - confirm 'seaborn' is still accepted wherever 'plot_style' is used.
    'plot_style': 'seaborn',
    'figure_size': (12, 8),
    'dpi': 100
}

# Make sure every output directory exists (no-op when it is already there)
for path_key in ('processed_data_path', 'saved_models_path', 'results_path'):
    CONFIG[path_key].mkdir(parents=True, exist_ok=True)

# Status report: resolved locations and how many model configurations exist
print(
    "Configuration loaded successfully!",
    f"\nBase path: {CONFIG['base_path'].resolve()}",
    f"Raw data path: {CONFIG['raw_data_path'].resolve()}",
    f"Number of models: {len(CONFIG['models'])}",
    sep="\n",
)

Configuration loaded successfully!

Base path: C:\Users\rishi\CV_Assignment\Paper2
Raw data path: C:\Users\rishi\CV_Assignment\Paper2\Raw_Data
Number of models: 5


## 3. Helper Functions for Data Processing

In [4]:
def load_nifti_volume(header_file_path):
    """
    Read an MRI volume with nibabel and return its voxel array.

    Loading is best-effort: any exception raised while opening or decoding
    the file is reported on stdout and turned into a ``None`` result, so a
    single unreadable scan does not abort a batch loop.

    Args:
        header_file_path (str or Path): Location of the image header (.hdr) file

    Returns:
        numpy.ndarray or None: Array returned by ``get_fdata()``, or None when
        the volume could not be loaded
    """
    try:
        image = nib.load(str(header_file_path))
        voxels = image.get_fdata()
    except Exception as e:
        print(f"Error loading {header_file_path}: {e}")
        return None
    else:
        return voxels


def extract_representative_slices(volume, num_slices=4, axis=2):
    """
    Extract evenly spaced 2D slices from the central part of a volume.

    Slice positions are spread between the 25% and 75% marks of the chosen
    axis (both ends included), skipping the outer quarters of the volume.
    When the axis is short, neighbouring positions can round to the same
    index, so duplicate slices are possible.

    Args:
        volume (numpy.ndarray or None): Volume to slice, e.g. (H, W, D)
        num_slices (int): Number of slices to extract
        axis (int): Axis along which to slice (2=axial, 1=coronal, 0=sagittal).
            Any other valid axis of ``volume`` (including negative values) is
            sliced along that same axis.

    Returns:
        list: ``num_slices`` arrays (views into ``volume``), each with the
        chosen axis removed. Empty list when ``volume`` is None or has no
        elements along ``axis``.

    Raises:
        IndexError: If ``axis`` is not a valid axis of ``volume``.
    """
    if volume is None:
        return []

    # Raises IndexError for an axis the volume does not have
    depth = volume.shape[axis]
    if depth == 0:
        # Nothing to slice along this axis
        return []

    # Turn a negative axis into its positive position so the indexer below
    # can be built by counting leading dimensions.
    if axis < 0:
        axis += volume.ndim

    # Central 50% of the axis: skip the first and last quarter
    start_idx = depth // 4
    end_idx = 3 * depth // 4
    slice_indices = np.linspace(start_idx, end_idx, num_slices, dtype=int)

    # Index exactly the axis the depth was measured on. The earlier version
    # always sliced axis 2 for anything other than 0 or 1.
    leading = (slice(None),) * axis
    return [volume[leading + (idx,)] for idx in slice_indices]


def preprocess_slice(slice_2d, target_size=(224, 224), normalize=True, to_rgb=True):
    """
    Preprocess a 2D MRI slice according to Paper 2 methodology:
    1. Rescale intensities to 8-bit and resize to the target size
    2. Convert to RGB (3-channel) format by repeating the gray channel
    3. Normalize pixel values to [0, 1]

    Args:
        slice_2d (numpy.ndarray): 2D grayscale slice. Singleton dimensions
            (e.g. a trailing axis of length 1) are squeezed away first.
        target_size (tuple): Target (height, width)
        normalize (bool): Whether to return float32 values in [0, 1]
            instead of uint8 values in [0, 255]
        to_rgb (bool): Whether to convert to 3-channel RGB

    Returns:
        numpy.ndarray or None: Preprocessed slice, or None if the slice has
        fewer than two dimensions or is smaller than 2x2
    """
    # Handle NaN and infinite values
    slice_2d = np.nan_to_num(slice_2d, nan=0.0, posinf=0.0, neginf=0.0)

    # Ensure slice is 2D
    if len(slice_2d.shape) > 2:
        # Drop singleton axes; if it is still not 2D, take the first plane
        slice_2d = np.squeeze(slice_2d)
        if len(slice_2d.shape) > 2:
            slice_2d = slice_2d[:, :, 0] if slice_2d.shape[2] == 1 else slice_2d[0, :, :]

    # Reject anything that is not at least a 2x2 image. The ndim guard covers
    # scalar and 1-D input, which would otherwise raise IndexError on shape[1].
    if slice_2d.ndim < 2 or slice_2d.shape[0] < 2 or slice_2d.shape[1] < 2:
        return None

    # Min-max scale to [0, 255] for uint8; a constant slice becomes all zeros
    slice_min = slice_2d.min()
    slice_max = slice_2d.max()
    if slice_max > slice_min:
        slice_normalized = ((slice_2d - slice_min) / (slice_max - slice_min) * 255).astype(np.uint8)
    else:
        slice_normalized = np.zeros_like(slice_2d, dtype=np.uint8)

    # cv2.resize takes its size as (width, height), while target_size is
    # documented as (height, width), so swap the two. Square targets are
    # unaffected.
    target_height, target_width = target_size
    slice_resized = cv2.resize(
        slice_normalized, (target_width, target_height), interpolation=cv2.INTER_LINEAR
    )

    # Convert to RGB (3-channel) if needed
    if to_rgb and len(slice_resized.shape) == 2:
        slice_rgb = np.stack([slice_resized] * 3, axis=-1)
    else:
        slice_rgb = slice_resized

    # Normalize to [0, 1] if required
    if normalize:
        return slice_rgb.astype(np.float32) / 255.0
    return slice_rgb


def load_demographics(demographics_path):
    """
    Read the OASIS demographics spreadsheet into a DataFrame.

    The sheet is returned exactly as pandas parses it; no columns are
    selected, renamed or filtered here.

    Args:
        demographics_path (str or Path): Path to OASIS_demographic.xlsx

    Returns:
        pandas.DataFrame: Contents of the spreadsheet as read by
        ``pd.read_excel``
    """
    return pd.read_excel(demographics_path)


def get_binary_label(cdr_score, threshold=0.5):
    """
    Map a Clinical Dementia Rating (CDR) score to a binary class label.

    With the default threshold the CDR scale maps as follows:
    - 0   -> 0 (Non-Demented)
    - 0.5 -> 1 (Demented, very mild)
    - 1.0 -> 1 (Demented, mild)
    - 2.0 -> 1 (Demented, moderate)
    - 3.0 -> 1 (Demented, severe)

    Args:
        cdr_score (float): CDR score; may be missing (NaN/None)
        threshold (float): Scores at or above this value count as Demented

    Returns:
        int or None: 0 (Non-Demented) or 1 (Demented); None when the score
        is missing
    """
    if pd.isna(cdr_score):
        label = None
    elif cdr_score >= threshold:
        label = 1
    else:
        label = 0
    return label


# Cell-completion marker: the data processing helpers above are now defined
print("Data processing functions loaded successfully!")

Data processing functions loaded successfully!


## 4. Evaluation Metrics Functions

In [5]:
def calculate_all_metrics(y_true, y_pred, y_pred_proba=None):
    """
    Calculate all evaluation metrics as specified in Paper 2:
    - Accuracy
    - Precision
    - Recall
    - F1-score
    - Specificity

    Labels are expected to be binary: 0 = negative class, 1 = positive class.

    Args:
        y_true (array-like): True labels
        y_pred (array-like): Predicted labels
        y_pred_proba (array-like, optional): Scores for the positive class,
            passed to ``roc_curve``; enables the ROC/AUC entries

    Returns:
        dict: 'accuracy', 'precision', 'recall', 'f1_score', 'specificity',
        'confusion_matrix' (2x2 array, rows = true, columns = predicted) and
        the counts 'tp', 'tn', 'fp', 'fn'. When ``y_pred_proba`` is given,
        also 'auc', 'fpr' and 'tpr'.
    """
    # Pin the label set so the matrix is always 2x2. Without it, a batch
    # containing only one class gives a 1x1 matrix and the four-way unpack
    # below raises ValueError.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Calculate metrics; zero_division=0 returns 0 instead of warning when a
    # denominator is empty
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred, average='binary', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='binary', zero_division=0)
    # Specificity = TN / (TN + FP); defined as 0 when there are no negatives
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'specificity': specificity,
        'confusion_matrix': cm,
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn
    }

    # Add the ROC curve points and the area under it if scores were provided
    if y_pred_proba is not None:
        fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        metrics['auc'] = auc(fpr, tpr)
        metrics['fpr'] = fpr
        metrics['tpr'] = tpr

    return metrics


def print_metrics(metrics, model_name="Model"):
    """
    Write a formatted evaluation report to stdout.

    Percent metrics are shown with two decimals, AUC (only when present in
    ``metrics``) with four, followed by the confusion-matrix counts.

    Args:
        metrics (dict): Dictionary of metrics from calculate_all_metrics
        model_name (str): Name of the model shown in the report header
    """
    rule = '=' * 60
    percent_rows = (
        ("Accuracy:    ", 'accuracy'),
        ("Precision:   ", 'precision'),
        ("Recall:      ", 'recall'),
        ("F1-Score:    ", 'f1_score'),
        ("Specificity: ", 'specificity'),
    )

    # Header
    print(f"\n{rule}")
    print(f"{model_name} - Evaluation Metrics")
    print(f"{rule}")

    # Fractions rendered as percentages
    for row_label, metric_key in percent_rows:
        print(f"{row_label}{metrics[metric_key]*100:.2f}%")

    # AUC is optional: only present when probabilities were supplied
    if 'auc' in metrics:
        print(f"AUC:         {metrics['auc']:.4f}")

    # Raw confusion-matrix counts
    print(f"\nConfusion Matrix:")
    print(f"  TN={metrics['tn']}, FP={metrics['fp']}")
    print(f"  FN={metrics['fn']}, TP={metrics['tp']}")
    print(f"{rule}\n")


# Cell-completion marker: the evaluation metric helpers above are now defined
print("Evaluation metrics functions loaded successfully!")

Evaluation metrics functions loaded successfully!


## 5. Visualization Functions

In [6]:
def plot_confusion_matrix(cm, class_names, title="Confusion Matrix", save_path=None):
    """
    Draw an annotated heatmap of a confusion matrix and display it.

    Cell values are rendered with integer formatting, so ``cm`` is expected
    to hold counts.

    Args:
        cm (numpy.ndarray): Confusion matrix (rows = true, columns = predicted)
        class_names (list): Tick labels used on both axes
        title (str): Plot title
        save_path (str, optional): If given, the figure is also written there
            at 300 dpi before being shown
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    heatmap_options = dict(
        annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        cbar=True, square=True,
    )
    sns.heatmap(cm, ax=ax, **heatmap_options)

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_xlabel('Predicted Label', fontsize=12)
    fig.tight_layout()

    # Save before show(): some backends discard the figure once it is shown
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()


def plot_training_history(history, model_name="Model", save_path=None):
    """
    Show accuracy and loss curves side by side for one training run.

    Each curve is optional: a series is drawn only when its key is present
    in ``history``.

    Args:
        history (dict): Per-epoch series under the keys 'train_acc',
            'val_acc', 'train_loss' and 'val_loss'
        model_name (str): Name of the model, used in both panel titles
        save_path (str, optional): If given, the figure is also written there
            at 300 dpi before being shown
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, train key, validation key, quantity name, legend position)
    panels = (
        (axes[0], 'train_acc', 'val_acc', 'Accuracy', 'lower right'),
        (axes[1], 'train_loss', 'val_loss', 'Loss', 'upper right'),
    )

    for ax, train_key, val_key, quantity, legend_loc in panels:
        for series_key, series_label in ((train_key, 'Train'), (val_key, 'Validation')):
            if series_key in history:
                ax.plot(history[series_key], label=series_label, linewidth=2)
        ax.set_title(f'{model_name} - Model {quantity}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(quantity, fontsize=12)
        ax.legend(loc=legend_loc)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()


def plot_roc_curve(metrics, model_name="Model", save_path=None):
    """
    Display the ROC curve stored in a metrics dictionary.

    If either 'fpr' or 'tpr' is missing, a message is printed and nothing is
    drawn. A missing 'auc' entry is shown as 0 in the legend.

    Args:
        metrics (dict): Metrics dictionary with 'fpr', 'tpr', 'auc'
        model_name (str): Name of the model, used in the title
        save_path (str, optional): If given, the figure is also written there
            at 300 dpi before being shown
    """
    has_curve = 'fpr' in metrics and 'tpr' in metrics
    if not has_curve:
        print("FPR and TPR not available in metrics")
        return

    auc_value = metrics.get('auc', 0)

    plt.figure(figsize=(8, 6))
    # Model curve plus the diagonal of a random classifier for reference
    plt.plot(metrics['fpr'], metrics['tpr'], linewidth=2,
             label=f"AUC = {auc_value:.4f}")
    plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random Classifier')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title(f'{model_name} - ROC Curve', fontsize=14, fontweight='bold')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()


def plot_sample_slices(images, labels, class_names, num_samples=10, title="Sample MRI Slices"):
    """
    Plot a random selection of MRI slices with their class names.

    Samples are drawn without replacement using NumPy's global random state.
    The grid has five columns and as many rows as needed; with the default
    of 10 samples this is the 2x5 layout. If fewer images are available than
    requested, all of them are shown.

    Args:
        images (numpy.ndarray or torch.Tensor): Array of images, each either
            (C, H, W), (H, W, C) or (H, W)
        labels (numpy.ndarray or torch.Tensor): Array of integer class labels
        class_names (list): List of class names, indexed by label
        num_samples (int): Maximum number of samples to plot
        title (str): Plot title
    """
    # Convert tensors to numpy if needed
    if isinstance(images, torch.Tensor):
        images = images.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    # Never ask for more samples than exist: choice(..., replace=False)
    # raises ValueError in that case.
    num_to_plot = min(num_samples, len(images))
    if num_to_plot <= 0:
        print("No images available to plot")
        return

    indices = np.random.choice(len(images), num_to_plot, replace=False)

    # Size the grid from the sample count instead of assuming exactly 10
    n_cols = 5
    n_rows = -(-num_to_plot // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
    axes = axes.ravel()

    # Hide every panel's frame and ticks up front so unused panels stay blank
    for ax in axes:
        ax.axis('off')

    for ax, idx in zip(axes, indices):
        img = images[idx]
        label = labels[idx]

        # Handle different image formats
        if len(img.shape) == 3 and img.shape[0] == 3:
            # PyTorch format (C, H, W) - convert to (H, W, C)
            display_img = np.transpose(img, (1, 2, 0))
        elif len(img.shape) == 3 and img.shape[-1] == 3:
            # RGB image (H, W, C)
            display_img = img
        elif len(img.shape) == 3:
            # Other channel counts: show the first channel only, guessing
            # the channel axis as the smaller of the first and last
            display_img = img[0] if img.shape[0] < img.shape[-1] else img[:, :, 0]
        else:
            display_img = img

        ax.imshow(display_img, cmap='gray')
        ax.set_title(f"{class_names[int(label)]}")

    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Cell-completion marker: the plotting helpers above are now defined
print("Visualization functions loaded successfully!")

Visualization functions loaded successfully!


## 6. Data Augmentation Setup

In [7]:
def create_data_augmentation():
    """
    Build the training and evaluation transform pipelines.

    The training pipeline is parameterised by ``CONFIG['augmentation_params']``:
    rotation between 0 and ``rotation_range`` degrees, optional horizontal and
    vertical flips (probability 0.5 when enabled, 0 otherwise), and a random
    affine step that shifts by up to the width/height shift fractions and
    scales within ``1 +/- zoom_range``. The ``fill_mode`` setting is not used.

    Returns:
        tuple: ``(train_transforms, test_transforms)``, both
        ``torchvision.transforms.Compose``. The test pipeline only converts
        the input to a tensor.
    """
    params = CONFIG['augmentation_params']

    # Flips are switched off by setting their probability to zero
    hflip_probability = 0.5 if params['horizontal_flip'] else 0
    vflip_probability = 0.5 if params['vertical_flip'] else 0

    max_shift = (params['width_shift_range'], params['height_shift_range'])
    zoom = params['zoom_range']
    zoom_bounds = (1-zoom, 1+zoom)

    # The pipeline starts from a NumPy image, hence the PIL conversion first
    augmentation_steps = [
        transforms.ToPILImage(),
        transforms.RandomRotation(degrees=(0, params['rotation_range'])),
        transforms.RandomHorizontalFlip(p=hflip_probability),
        transforms.RandomVerticalFlip(p=vflip_probability),
        transforms.RandomAffine(degrees=0, translate=max_shift, scale=zoom_bounds),
        transforms.ToTensor(),
    ]
    train_transforms = transforms.Compose(augmentation_steps)

    # Validation/test data is only converted to a tensor, never augmented
    test_transforms = transforms.Compose([transforms.ToTensor()])

    return train_transforms, test_transforms


# Custom Dataset class for augmentation
class AugmentedDataset(Dataset):
    """
    In-memory image dataset with optional augmentation.

    Each item is converted to a channel-last uint8 image before the transform
    runs. Without a transform the item is returned as a float tensor in
    [0, 1] with shape (C, H, W).

    Args:
        data: Indexable collection of images (NumPy arrays or tensors), each
            (C, H, W) with C in {1, 3}, (H, W, C) or (H, W). Non-uint8 images
            are assumed to hold values in [0, 1].
        labels: Indexable collection of integer class labels
        transform: Optional callable applied to the uint8 image
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_uint8(image):
        """Scale a [0, 1] image to uint8, rounding to nearest and clipping."""
        # Plain truncation can land one level low (254.99999 -> 254), and
        # values slightly outside [0, 1] would wrap around in uint8.
        scaled = np.rint(np.asarray(image, dtype=np.float64) * 255.0)
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]

        # Tensor-backed data is handled through its NumPy view
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()

        # Convert from PyTorch format (C, H, W) to PIL format (H, W, C)
        if len(image.shape) == 3 and image.shape[0] in [1, 3]:
            image = np.transpose(image, (1, 2, 0))

        # Transforms expect 8-bit images
        if image.dtype != np.uint8:
            image = self._to_uint8(image)

        # Remove single channel dimension if grayscale
        if image.shape[-1] == 1:
            image = image.squeeze(-1)

        if self.transform:
            image = self.transform(image)
        elif len(image.shape) == 3:
            # (H, W, C) -> (C, H, W), scaled back to [0, 1]
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        else:
            # Grayscale (H, W) -> (1, H, W), scaled back to [0, 1]
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        return image, torch.tensor(label, dtype=torch.long)


# Memory-mapped dataset for large files
class MemoryMappedDataset(Dataset):
    """
    Memory-efficient dataset using memory-mapped numpy arrays.

    Labels are loaded fully; images are opened with ``mmap_mode='r'`` and
    copied one item at a time on access.

    Without a transform, items are returned as float tensors, channel-first
    when the stored image is channel-last with 1 or 3 channels. With a
    transform, the transform receives a uint8 image in (H, W, C) or (H, W)
    layout and its result is returned as is.
    """

    def __init__(self, X_path, y_path, transform=None, normalize=True):
        """
        Args:
            X_path: Path to .npy file with images
            y_path: Path to .npy file with labels
            transform: Optional transform to apply
            normalize: Whether uint8 images are scaled to [0, 1] when no
                transform is used
        """
        # Load labels (small, can fit in memory)
        self.labels = np.load(y_path)

        # Memory-map the image data (doesn't load into RAM)
        self.images = np.load(X_path, mmap_mode='r')

        self.transform = transform
        self.normalize = normalize

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _to_uint8(image):
        """Scale a [0, 1] image to uint8, rounding to nearest and clipping."""
        scaled = np.rint(np.asarray(image, dtype=np.float64) * 255.0)
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        # Copy only this one image out of the memory-mapped file
        image = np.array(self.images[idx])
        label = self.labels[idx]

        channels_last = image.ndim == 3 and image.shape[-1] in (1, 3)

        if self.transform:
            # The transform wants a channel-last uint8 image. Stored uint8
            # data is passed through untouched: normalising to float and
            # scaling back by truncation could shift values down by one.
            if image.ndim == 3 and not channels_last:
                # Treated as (C, H, W) -> (H, W, C)
                image = np.transpose(image, (1, 2, 0))
            if image.dtype != np.uint8:
                # Non-uint8 data is assumed to be in [0, 1]
                image = self._to_uint8(image)
            image = self.transform(image)
        else:
            # Normalize if needed
            if self.normalize and image.dtype == np.uint8:
                image = image.astype('float32') / 255.0
            # Convert to PyTorch format: (H, W, C) -> (C, H, W)
            if channels_last:
                image = np.transpose(image, (2, 0, 1))
            image = torch.from_numpy(image).float()

        return image, torch.tensor(label, dtype=torch.long)


# Cell-completion marker: the augmentation pipeline and dataset classes are now defined
print("Data augmentation setup loaded successfully!")

Data augmentation setup loaded successfully!


## 7. Summary

In [8]:
# Closing banner: what this notebook set up and which notebooks come next
banner_rule = "=" * 80
summary_lines = (
    "\n" + banner_rule,
    "PAPER 2 UTILITIES AND CONFIGURATION - SUMMARY",
    banner_rule,
    "\n✓ All libraries imported successfully",
    "✓ Configuration parameters loaded",
    "✓ Data processing functions defined",
    "✓ Evaluation metrics functions defined",
    "✓ Visualization functions defined",
    "✓ Data augmentation configured",
    "\nReady to proceed with:",
    "  - Notebook 01: Data Preparation",
    "  - Notebooks 02-06: Model Implementations",
    "  - Notebook 07: Results Comparison",
    "\n" + banner_rule,
)
for summary_line in summary_lines:
    print(summary_line)


PAPER 2 UTILITIES AND CONFIGURATION - SUMMARY

✓ All libraries imported successfully
✓ Configuration parameters loaded
✓ Data processing functions defined
✓ Evaluation metrics functions defined
✓ Visualization functions defined
✓ Data augmentation configured

Ready to proceed with:
  - Notebook 01: Data Preparation
  - Notebooks 02-06: Model Implementations
  - Notebook 07: Results Comparison

