## Step 1: Setup and Data Download

First, install the Kaggle API and download the dataset.

In [None]:
# Install required packages
# !pip install kaggle pillow torch torchvision scikit-learn pandas numpy matplotlib seaborn

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Set up paths.
# The project root defaults to the original author's machine; set the
# OASIS_PROJECT_ROOT environment variable to run the notebook elsewhere
# without editing this cell.
_DEFAULT_PROJECT_ROOT = r"c:\Users\ottav\OneDrive - Politecnico di Milano\Desktop\ComplexData\AlzheimerComplexDataProject"
PROJECT_ROOT = Path(os.environ.get("OASIS_PROJECT_ROOT", _DEFAULT_PROJECT_ROOT))
DATA_DIR = PROJECT_ROOT / "Data" / "OASIS"
DATA_DIR.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {DATA_DIR}")

### Download Dataset from Kaggle

**Manual steps**:
1. Go to https://www.kaggle.com/datasets/ninadaithal/imagesoasis/data
2. Click "Download" and extract to the `Data/OASIS/` directory
3. The structure should be:
   ```
   Data/OASIS/
   ├── oasis_cross-sectional.csv  (metadata)
   └── [MRI image folders]/
   ```

**Or use Kaggle API**:
```python
# Make sure you have ~/.kaggle/kaggle.json configured
# !kaggle datasets download -d ninadaithal/imagesoasis -p Data/OASIS --unzip
```

## Step 2: Exploratory Data Analysis (EDA)

Load and explore the metadata to understand the dataset structure.

In [None]:
# Load the OASIS cross-sectional metadata table.
# NOTE: pd.read_csv raises FileNotFoundError until the CSV from Step 1 is in place.
csv_path = DATA_DIR / "oasis_cross-sectional.csv"
df = pd.read_csv(csv_path)

print(f"Dataset shape: {df.shape}")
print(f"\nColumn names:\n{df.columns.tolist()}")
print("\nFirst few rows:")
df.head()  # last expression of the cell, so Jupyter renders it

In [None]:
# Display dataset statistics.
print("Dataset Information:")
# DataFrame.info() writes its report itself and returns None, so it must not
# be wrapped in print() (that would emit a stray "None" line).
df.info()
print("\n" + "=" * 50)
print("Statistical Summary:")
print(df.describe())

In [None]:
# Count missing entries in every column.
print("Missing values per column:")
print(df.isna().sum())

# Distribution of key variables. value_counts() leaves out NaN, so rows with
# a missing CDR do not appear in the CDR counts below.
print("\nCDR (Clinical Dementia Rating) distribution:")
print(df['CDR'].value_counts().sort_index())

print("\nGender distribution:")
print(df['M/F'].value_counts())

In [None]:
# Visualize key distributions in a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# CDR distribution (value_counts drops missing CDR values)
df['CDR'].value_counts().sort_index().plot(kind='bar', ax=axes[0, 0], color='steelblue')
axes[0, 0].set_title('Clinical Dementia Rating (CDR) Distribution')
axes[0, 0].set_xlabel('CDR Score')
axes[0, 0].set_ylabel('Count')

# Age distribution
df['Age'].hist(bins=20, ax=axes[0, 1], color='coral', edgecolor='black')
axes[0, 1].set_title('Age Distribution')
axes[0, 1].set_xlabel('Age')
axes[0, 1].set_ylabel('Frequency')

# MMSE distribution (if available)
if 'MMSE' in df.columns:
    df['MMSE'].dropna().hist(bins=15, ax=axes[1, 0], color='green', alpha=0.7, edgecolor='black')
    axes[1, 0].set_title('MMSE Score Distribution')
    axes[1, 0].set_xlabel('MMSE Score')
    axes[1, 0].set_ylabel('Frequency')

# CDR by Gender: one group of bars per gender, one bar colour per CDR level.
# The number of CDR levels (crosstab columns) depends on the data, so use a
# colormap; a fixed two-colour list would be cycled across the levels, and
# gender-style colours (light blue / pink) would suggest the wrong encoding.
if 'M/F' in df.columns:
    pd.crosstab(df['M/F'], df['CDR']).plot(kind='bar', ax=axes[1, 1], colormap='viridis')
    axes[1, 1].set_title('CDR Distribution by Gender')
    axes[1, 1].set_xlabel('Gender')
    axes[1, 1].set_ylabel('Count')
    axes[1, 1].tick_params(axis='x', rotation=0)
    # loc='upper left' pins the legend's corner to bbox_to_anchor, placing it
    # just outside the right edge of the axes.
    axes[1, 1].legend(title='CDR', bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()

## Step 3: Create Classification Labels

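Create binary labels for dementia classification:
- **0 (Non-Demented)**: CDR = 0
- **1 (Demented)**: CDR > 0
- Subjects with a missing CDR cannot be labeled by these rules and should be excluded (in pandas, `NaN > 0` evaluates to `False`, so they would otherwise be counted as Non-Demented).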

In [None]:
# Create binary classification labels.
# Rows with a missing CDR cannot be labelled: `NaN > 0` evaluates to False, so
# they would silently be counted as Non-Demented. Drop them first; later cells
# (train/test split, datasets) then only see labelled subjects.
n_missing_cdr = int(df['CDR'].isna().sum())
df = df.dropna(subset=['CDR']).reset_index(drop=True)
print(f"Dropped {n_missing_cdr} rows with missing CDR")

df['Demented'] = (df['CDR'] > 0).astype(int)

print("Class distribution:")
print(df['Demented'].value_counts())
print(f"\nClass balance: {df['Demented'].value_counts(normalize=True)}")

# Visualize. sort_index() fixes the bar order to (0, 1) so green always marks
# Non-Demented and red Demented, whichever class is more frequent.
plt.figure(figsize=(8, 5))
df['Demented'].value_counts().sort_index().plot(kind='bar', color=['green', 'red'], alpha=0.7)
plt.title('Dementia Classification Distribution')
plt.xlabel('Class (0=Non-Demented, 1=Demented)')
plt.ylabel('Count')
plt.xticks(rotation=0)
plt.show()

## Step 4: Custom PyTorch Dataset

Create a custom dataset class to load MRI images and their labels.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import glob

class OASISDataset(Dataset):
    """
    Custom Dataset for OASIS MRI images.

    Each metadata row whose ``ID`` has a matching sub-folder inside ``img_dir``
    contributes one sample: one image from that folder. ``.gif`` files are
    preferred over ``.jpg``, then ``.png``. Within an extension the
    alphabetically first file name wins, so the choice is reproducible across
    runs and platforms (``Path.glob`` order is unspecified).
    """

    # Glob patterns tried in order of preference when picking a subject's image.
    _IMAGE_PATTERNS = ('*.gif', '*.jpg', '*.png')

    def __init__(self, dataframe, img_dir, transform=None, task='classification'):
        """
        Args:
            dataframe: pandas DataFrame with metadata. Needs an 'ID' column plus
                'Demented' (classification) or 'CDR' (regression).
            img_dir: Directory containing one sub-folder per subject ID.
            transform: Optional transform to be applied on images
            task: 'classification' or 'regression'

        Rows without a subject folder, without an image, or with a missing
        (NaN) label are skipped.
        """
        self.df = dataframe.copy()
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.task = task
        label_column = 'Demented' if self.task == 'classification' else 'CDR'

        # Find image paths
        self.image_paths = []
        self.labels = []

        for _, row in self.df.iterrows():
            # Adjust this based on actual folder structure
            subject_folder = self.img_dir / row['ID']
            if not subject_folder.exists():
                continue

            img_file = self._first_image(subject_folder)
            if img_file is None:
                continue

            label = row[label_column]
            # A NaN target (e.g. a missing CDR in regression mode) would turn
            # the loss into NaN, so such rows are excluded.
            if pd.isna(label):
                continue

            self.image_paths.append(img_file)
            self.labels.append(label)

        print(f"Found {len(self.image_paths)} images")

    @classmethod
    def _first_image(cls, folder):
        """Return the preferred image file in *folder*, or None if there is none."""
        for pattern in cls._IMAGE_PATTERNS:
            matches = sorted(folder.glob(pattern))
            if matches:
                return matches[0]
        return None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The image is converted to RGB and passed through ``transform`` if one
        was given. The label is an int64 tensor for classification and a
        float32 tensor for regression.
        """
        img_path = self.image_paths[idx]
        # PIL opens files lazily; the context manager closes the file handle
        # once convert() has produced an in-memory RGB copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        dtype = torch.long if self.task == 'classification' else torch.float32
        label = torch.tensor(self.labels[idx], dtype=dtype)
        return image, label

## Step 5: Data Preprocessing and Augmentation

Define transforms for training and validation.

In [None]:
from torchvision import transforms

# Image preprocessing.
# Both pipelines resize to 224x224 and normalise with the standard ImageNet
# channel statistics expected by the pretrained backbone; only the training
# pipeline adds random augmentation (flip, small rotation, colour jitter).
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Deterministic pipeline for validation: no augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

print("Transforms defined successfully!")

## Step 6: Train-Test Split

Split the data into training and validation sets.

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the metadata rows for validation. Stratifying on the binary
# label keeps the class ratio equal in both splits; the fixed random_state
# makes the split reproducible.
# NOTE(review): the split is over metadata rows; OASISDataset later skips rows
# without an image, so the final dataset sizes may be smaller than printed here.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['Demented'],
    random_state=42,
)

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")

print("\nTraining set class distribution:")
print(train_df['Demented'].value_counts())
print("\nValidation set class distribution:")
print(val_df['Demented'].value_counts())

In [None]:
# Create datasets.
# img_dir is expected to contain one sub-folder per subject ID (see
# OASISDataset); adjust it to the actual extracted folder structure.
train_dataset = OASISDataset(
    train_df,
    img_dir=DATA_DIR,
    transform=train_transform,
    task='classification',
)

val_dataset = OASISDataset(
    val_df,
    img_dir=DATA_DIR,
    transform=val_transform,
    task='classification',
)

# Fail early with an actionable message. Otherwise a shuffled DataLoader over
# an empty dataset raises an opaque sampler error, and an empty validation
# set leaves the per-epoch loss/accuracy undefined.
if len(train_dataset) == 0 or len(val_dataset) == 0:
    raise RuntimeError(
        f"No usable images found under {DATA_DIR} "
        f"(train: {len(train_dataset)}, val: {len(val_dataset)}). "
        "Check that img_dir points at the folder holding one sub-folder per subject ID."
    )

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,  # Set to 0 for Windows compatibility
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

print(f"Train loader batches: {len(train_loader)}")
print(f"Val loader batches: {len(val_loader)}")

## Step 7: Model Architecture

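Use an ImageNet-pretrained ResNet50 and adapt it to our task: freeze the pretrained backbone and replace its final fully connected layer with a new, trainable classification (or regression) head.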

In [None]:
import torch.nn as nn
from torchvision import models

def _build_resnet50(out_features, pretrained):
    """Return a ResNet50 with a frozen backbone and a fresh trainable head.

    Shared by the classification and regression factories below, which differ
    only in the width of the final layer.

    Args:
        out_features: Number of outputs of the final Linear layer.
        pretrained: If True, load torchvision's ImageNet weights (the same
            IMAGENET1K_V1 weights the legacy ``pretrained=True`` flag loads).
    """
    # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``;
    # use the new API when it exists and fall back to the old flag otherwise.
    weights_enum = getattr(models, 'ResNet50_Weights', None)
    if weights_enum is None:
        model = models.resnet50(pretrained=pretrained)
    else:
        model = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    # Freeze the whole backbone (every existing parameter, not only the early
    # layers). The head assigned below is created afterwards, so its
    # parameters keep requires_grad=True and are the only ones trained.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final fully connected layer with a small MLP head.
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, out_features),
    )
    return model


def create_classification_model(num_classes=2, pretrained=True):
    """
    Create ResNet50 model for classification.

    Args:
        num_classes: Number of output scores (2 for Non-Demented vs Demented).
        pretrained: Load ImageNet-pretrained backbone weights.

    Returns:
        nn.Module producing raw (unnormalised) class scores of shape
        (batch, num_classes), suitable for nn.CrossEntropyLoss.
    """
    return _build_resnet50(num_classes, pretrained)


def create_regression_model(pretrained=True):
    """
    Create ResNet50 model for regression (CDR prediction).

    Args:
        pretrained: Load ImageNet-pretrained backbone weights.

    Returns:
        nn.Module with a single output per sample, shape (batch, 1).
    """
    return _build_resnet50(1, pretrained)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Build the two-class classifier and move its parameters to the chosen device
# (nn.Module.to returns the module itself).
model = create_classification_model(num_classes=2, pretrained=True).to(device)

print("\nModel architecture:")
print(model)

## Step 8: Training Setup

Define loss function, optimizer, and training parameters.

In [None]:
# Classification setup: cross-entropy on the class scores, Adam over the model
# parameters, and a scheduler that halves the learning rate once the monitored
# loss stops improving (patience of 3 epochs).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

# Training parameters
num_epochs = 20
best_val_loss = float('inf')  # any finite validation loss improves on this

print("Training setup complete!")

## Step 9: Training and Validation Functions

In [None]:
def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over *dataloader* and return ``(mean loss, accuracy in %)``.

    Trains (zero_grad, forward, backward, optimizer step) when *optimizer* is
    given; otherwise evaluates in eval mode with gradient tracking disabled.

    Raises:
        ValueError: if *dataloader* yields no samples (the averages would be 0/0).
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is exactly what model.eval() does
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            # Assumes criterion averages over the batch (reduction='mean', the
            # default); re-weight by batch size so the epoch mean is per-sample.
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute loss/accuracy")

    return running_loss / total, 100.0 * correct / total


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train *model* for one epoch; return ``(mean loss, accuracy in %)``."""
    return _run_epoch(model, dataloader, criterion, device, optimizer=optimizer)


def validate(model, dataloader, criterion, device):
    """Evaluate *model* without gradients; return ``(mean loss, accuracy in %)``."""
    return _run_epoch(model, dataloader, criterion, device)


print("Training functions defined!")

## Step 10: Training Loop

In [None]:
# Per-epoch metrics, collected for the curves plotted in Step 11.
history = {name: [] for name in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}

print("Starting training...")
print("=" * 60)

for epoch in range(num_epochs):
    # One optimisation pass over the training data, then a gradient-free pass
    # over the validation data.
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(
        f"Epoch [{epoch+1}/{num_epochs}]\n"
        f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%\n"
        f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%\n"
        + "-" * 60
    )

    # Checkpoint the weights whenever the validation loss reaches a new
    # minimum; Step 12 reloads this file for evaluation.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), DATA_DIR / 'best_model.pth')
        print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")

print("\nTraining complete!")

## Step 11: Visualize Training Results

In [None]:
# Plot training curves against 1-based epoch numbers so the x-axis matches
# the "Epoch [k/N]" lines printed by the training loop (plotting a bare list
# would number the epochs from 0).
epoch_numbers = range(1, len(history['train_loss']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax1.plot(epoch_numbers, history['train_loss'], label='Train Loss', marker='o')
ax1.plot(epoch_numbers, history['val_loss'], label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(epoch_numbers, history['train_acc'], label='Train Acc', marker='o')
ax2.plot(epoch_numbers, history['val_acc'], label='Val Acc', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(DATA_DIR / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## Step 12: Model Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, roc_curve

class_names = ['Non-Demented', 'Demented']

# Load the best checkpoint saved by the training loop. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU can
# still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(DATA_DIR / 'best_model.pth', map_location=device))
model.eval()

# Collect predictions
all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        _, preds = outputs.max(1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_probs.extend(probs[:, 1].cpu().numpy())  # Probability of class 1 (Demented)

# Classification report. labels=[0, 1] keeps the report and the confusion
# matrix aligned with both class names even if one class is absent from the
# validation labels and predictions (otherwise target_names would mismatch).
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=[0, 1], target_names=class_names))

# Confusion matrix (always 2x2 thanks to labels=[0, 1])
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.savefig(DATA_DIR / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

# ROC curve. ROC/AUC is undefined when the validation labels contain a single
# class (roc_auc_score raises ValueError), so skip it and record NaN instead.
if len(set(all_labels)) < 2:
    auc_score = float('nan')
    print("ROC curve skipped: validation labels contain a single class.")
else:
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    auc_score = roc_auc_score(all_labels, all_probs)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, linewidth=2, label=f'AUC = {auc_score:.3f}')
    plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve - Dementia Classification')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(DATA_DIR / 'roc_curve.png', dpi=150, bbox_inches='tight')
    plt.show()

## Next Steps & Variations

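### For Regression Task (Predicting CDR Score):
1. Change `task='regression'` in dataset creation
2. Use `create_regression_model()`
3. Use `nn.MSELoss()` as criterion, and squeeze the model output from shape `(N, 1)` to `(N,)` so it matches the label shape
4. Remove the accuracy computation from `train_epoch`/`validate` (it uses `outputs.max(1)`, which only makes sense for class scores), as well as the classification-specific evaluation in Step 12
5. Evaluate with MAE, RMSE, R² metrics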

### Model Improvements:
- Try different architectures (VGG16, EfficientNet, Vision Transformer)
- Use 3D CNNs for volumetric MRI data
- Implement attention mechanisms
- Ensemble multiple models

### Data Improvements:
- Use multiple MRI slices per subject
- Add clinical features (age, gender, education) as additional inputs
- Apply medical image preprocessing (skull stripping, intensity normalization)
- Use data augmentation specific to medical images

### Advanced Techniques:
- Cross-validation for robust evaluation
- Class weighting for imbalanced data
- Transfer learning from medical image datasets
- Explainability with GradCAM or SHAP

In [None]:
# Save final results summary.
# 'Final Val Accuracy' is the last epoch's accuracy, while the AUC above comes
# from the best-val-loss checkpoint. Also report the epoch of that checkpoint
# and its accuracy, so the summary includes a metric from the same model as the AUC.
val_losses = history['val_loss']
# The training loop saves only on a strict improvement, so the checkpoint is
# the first epoch whose loss equals best_val_loss; None if nothing was saved.
best_epoch = val_losses.index(best_val_loss) + 1 if best_val_loss in val_losses else None
best_epoch_acc = history['val_acc'][best_epoch - 1] if best_epoch is not None else float('nan')

results_summary = {
    'Best Val Loss': best_val_loss,
    'Best Epoch': best_epoch,
    'Best Epoch Val Accuracy': best_epoch_acc,
    'Final Val Accuracy': history['val_acc'][-1],
    'AUC Score': auc_score,
    'Training Epochs': num_epochs,
}

print("\n" + "=" * 60)
print("FINAL RESULTS SUMMARY")
print("=" * 60)
for key, value in results_summary.items():
    print(f"{key}: {value:.4f}" if isinstance(value, float) else f"{key}: {value}")
print("=" * 60)