# Install and Import Required Libraries


In [1]:
# python 3.10

%pip install torch torchvision torchaudio nibabel numpy tqdm wandb

DATASET = "./DATA/ADNI_96"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch

# Sanity check: is PyTorch's Apple Metal (MPS) backend usable in this kernel?
# Expected True on an Apple-silicon Mac with an MPS-enabled PyTorch build.
mps_ready = torch.backends.mps.is_available()
print(mps_ready)

True


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models.video as models
import nibabel as nib
from tqdm import tqdm
import wandb

# Pick the fastest available backend: CUDA first, then Apple Metal (MPS), then CPU.
# Previously a CUDA machine silently fell back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) device")
else:
    device = torch.device("cpu")
    print("No GPU (CUDA/MPS) available, using CPU")


# Dataset class for loading .nii.gz files
class MRIDataset(Dataset):
    """Binary AD-vs-CN dataset backed by NIfTI (``.nii.gz``) volumes.

    Expected directory layout::

        <root_dir>/<split>/AD/*.nii.gz   -> label 1
        <root_dir>/<split>/CN/*.nii.gz   -> label 0

    Args:
        root_dir: Dataset root directory.
        split: Sub-directory to read, e.g. ``"train"`` or ``"val"``.
        expected_shape: Shape every volume must have; ``__getitem__`` raises
            ``ValueError`` for any other shape.

    Attributes:
        samples: File paths, AD files first, then CN files, each sorted by name.
        labels: Integer label for each entry in ``samples``.
    """

    def __init__(self, root_dir, split="train", expected_shape=(96, 96, 96)):
        self.root_dir = root_dir
        self.split = split
        self.expected_shape = tuple(expected_shape)
        self.samples = []
        self.labels = []

        # (sub-directory, label) pairs; AD is loaded before CN.
        for class_name, label in (("AD", 1), ("CN", 0)):
            class_dir = os.path.join(root_dir, split, class_name)
            # os.listdir order is filesystem-dependent; sort so sample order
            # (and therefore validation order) is reproducible across machines.
            for file in sorted(os.listdir(class_dir)):
                if file.endswith(".nii.gz"):
                    self.samples.append(os.path.join(class_dir, file))
                    self.labels.append(label)

        print(f"Loaded {len(self.samples)} samples for {split} split")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(volume, label)``; volume is float32 with shape ``(1, *expected_shape)``.

        Raises:
            ValueError: if the stored volume does not have ``expected_shape``.
        """
        img_path = self.samples[idx]
        label = self.labels[idx]

        img_data = nib.load(img_path).get_fdata()

        # Compare whole shapes so volumes with an extra axis (e.g. 4-D) get this
        # message instead of a tuple-unpacking error.
        if img_data.shape != self.expected_shape:
            expected = "x".join(str(s) for s in self.expected_shape)
            got = "x".join(str(s) for s in img_data.shape)
            raise ValueError(
                f"Expected image size {expected} but got {got} ({img_path})"
            )

        # Add a leading channel axis: (D, H, W) -> (1, D, H, W).
        img_tensor = torch.tensor(img_data, dtype=torch.float32).unsqueeze(0)
        return img_tensor, label


# Modified 3D ResNet model
class MRIModel(nn.Module):
    """torchvision ``r3d_18`` adapted to single-channel 3D MRI volumes.

    Args:
        num_classes: Number of output logits of the classification head.
        pretrained: Load Kinetics-400 weights (downloaded by torchvision). When
            True, the single-channel stem is initialised from the pretrained RGB
            stem rather than randomly, so the pretrained features stay usable.

    Input ``(B, 1, D, H, W)``; output ``(B, num_classes)`` logits.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights = models.R3D_18_Weights.KINETICS400_V1 if pretrained else None
        self.resnet = models.r3d_18(weights=weights)

        # Replace the 3-channel stem conv with a 1-channel conv of identical geometry.
        rgb_stem = self.resnet.stem[0]
        gray_stem = nn.Conv3d(
            1,
            rgb_stem.out_channels,
            kernel_size=rgb_stem.kernel_size,
            stride=rgb_stem.stride,
            padding=rgb_stem.padding,
            bias=False,
        )
        if pretrained:
            # Summing the RGB filters over the channel axis gives exactly the
            # response the pretrained stem would have to a grey volume
            # replicated across 3 channels.
            with torch.no_grad():
                gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
        self.resnet.stem[0] = gray_stem

        # New classification head sized for num_classes.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for a ``(B, 1, D, H, W)`` batch."""
        return self.resnet(x)


# Training function
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train ``model`` for one pass over ``dataloader`` and log metrics to wandb.

    Logs ``batch_loss`` / ``batch_acc`` after every batch, then epoch-level
    ``train_loss``, ``train_acc`` and per-class accuracy (CN = label 0,
    AD = label 1; 0.0 when a class is absent).

    Returns:
        (epoch_loss, epoch_acc): mean of the per-batch losses, and accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_preds = []

    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloader, desc="Training")):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Every .item() forces a host/device sync (costly on MPS/CUDA), so
        # read each scalar exactly once and reuse it.
        batch_loss = loss.item()
        predicted = outputs.detach().argmax(dim=1)
        batch_correct = (predicted == labels).sum().item()
        batch_size = labels.size(0)

        running_loss += batch_loss
        correct += batch_correct
        total += batch_size
        all_labels.extend(labels.cpu().tolist())
        all_preds.extend(predicted.cpu().tolist())

        wandb.log(
            {
                "batch_loss": batch_loss,
                "batch_acc": 100 * batch_correct / batch_size,
                "batch": epoch * len(dataloader) + batch_idx,
            }
        )

    if total == 0:
        raise ValueError("train_one_epoch: dataloader yielded no samples")

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100 * correct / total

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    class_acc = {}
    for name, cls in (("CN", 0), ("AD", 1)):
        mask = all_labels == cls
        count = int(mask.sum())
        # Report 0.0 for an absent class instead of dividing by an epsilon.
        class_acc[name] = 100 * float(np.sum(all_preds[mask] == cls)) / count if count else 0.0

    wandb.log(
        {
            "train_loss": epoch_loss,
            "train_acc": epoch_acc,
            "train_CN_acc": class_acc["CN"],
            "train_AD_acc": class_acc["AD"],
            "epoch": epoch,
        }
    )

    return epoch_loss, epoch_acc


# Validation function
def validate(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradients and log metrics to wandb.

    Logs loss, accuracy, per-class accuracy (CN = label 0, AD = label 1;
    0.0 when a class is absent), precision/recall/F1 with AD as the positive
    class, a confusion matrix, and an ROC curve for the AD class.

    Returns:
        (val_loss, val_acc): mean of the per-batch losses, and accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    all_class_probs = []  # one softmax row [P(CN), P(AD)] per sample

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()

            probs = torch.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predicted.cpu().tolist())
            all_class_probs.extend(probs.cpu().tolist())

    if total == 0:
        raise ValueError("validate: dataloader yielded no samples")

    val_loss = running_loss / len(dataloader)
    val_acc = 100 * correct / total

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_class_probs = np.array(all_class_probs)  # shape (n_samples, n_classes)

    class_acc = {}
    for name, cls in (("CN", 0), ("AD", 1)):
        mask = all_labels == cls
        count = int(mask.sum())
        # Report 0.0 for an absent class instead of dividing by an epsilon.
        class_acc[name] = 100 * float(np.sum(all_preds[mask] == cls)) / count if count else 0.0

    # Binary metrics with AD (label 1) as the positive class.
    true_positives = int(np.sum((all_preds == 1) & (all_labels == 1)))
    false_positives = int(np.sum((all_preds == 1) & (all_labels == 0)))
    false_negatives = int(np.sum((all_preds == 0) & (all_labels == 1)))

    predicted_pos = true_positives + false_positives
    actual_pos = true_positives + false_negatives
    precision = true_positives / predicted_pos if predicted_pos else 0.0
    recall = true_positives / actual_pos if actual_pos else 0.0
    f1_score = (
        2 * precision * recall / (precision + recall) if precision + recall else 0.0
    )

    confusion_matrix = wandb.plot.confusion_matrix(
        preds=all_preds.tolist(), y_true=all_labels.tolist(), class_names=["CN", "AD"]
    )

    # roc_curve needs one probability column per class, shape (n_samples, n_classes).
    # Building it as `list + ndarray` would broadcast-add the columns into a 1-D
    # array of ones, so pass the softmax matrix directly.
    roc_curve = wandb.plot.roc_curve(
        all_labels.tolist(),
        all_class_probs,
        classes_to_plot=[1],  # ROC for the AD (positive) class
        labels=["CN", "AD"],
    )

    wandb.log(
        {
            "val_loss": val_loss,
            "val_acc": val_acc,
            "val_CN_acc": class_acc["CN"],
            "val_AD_acc": class_acc["AD"],
            "val_precision": precision,
            "val_recall": recall,
            "val_f1": f1_score,
            "confusion_matrix": confusion_matrix,
            "roc_curve": roc_curve,
            "epoch": epoch,
        }
    )

    return val_loss, val_acc


def main():
    """Train and validate ``MRIModel`` on ``DATASET``, tracking the run in wandb.

    Each time validation accuracy improves, the weights are written to
    ``best_model.pth`` and uploaded as a wandb ``model`` artifact. The wandb
    run is always closed: marked failed (exit code 1) if training raises or is
    interrupted, finished normally otherwise.
    """
    # Single source of truth: these values are logged to wandb AND used below,
    # so the logged config cannot drift from what actually ran.
    config = {
        "architecture": "3D-ResNet18",
        "dataset": "MRI-AD-CN",
        "epochs": 5,
        "batch_size": 2,  # kept small for memory constraints
        "learning_rate": 0.0001,
        "weight_decay": 0.01,
        "optimizer": "AdamW",
        "seed": 42,
        "device": str(device),
        # Must match the shape MRIDataset enforces (was wrongly logged as 128^3).
        "input_dimensions": "96x96x96",
    }

    # Seed torch so weight initialisation and DataLoader shuffling repeat run to run.
    torch.manual_seed(config["seed"])

    wandb.init(project="mri-alzheimers-classification", config=config)
    try:
        data_root = DATASET  # Update this to your dataset path

        train_dataset = MRIDataset(data_root, split="train")
        val_dataset = MRIDataset(data_root, split="val")

        wandb.config.update(
            {
                "train_samples": len(train_dataset),
                "val_samples": len(val_dataset),
                "train_AD_samples": train_dataset.labels.count(1),
                "train_CN_samples": train_dataset.labels.count(0),
                "val_AD_samples": val_dataset.labels.count(1),
                "val_CN_samples": val_dataset.labels.count(0),
            }
        )

        train_loader = DataLoader(
            train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=0
        )
        val_loader = DataLoader(
            val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0
        )

        model = MRIModel(num_classes=2).to(device)
        wandb.watch(model, log="all", log_freq=10)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
        )
        # `verbose=` is deprecated in recent PyTorch; LR drops are printed manually below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=2
        )

        num_epochs = config["epochs"]
        best_val_acc = 0.0
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, device, epoch
            )
            val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)

            # Reduce the learning rate when validation loss plateaus.
            lr_before = optimizer.param_groups[0]["lr"]
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]["lr"]
            if lr_after < lr_before:
                print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), "best_model.pth")
                print("Model saved!")

                artifact = wandb.Artifact("best_model", type="model")
                artifact.add_file("best_model.pth")
                wandb.log_artifact(artifact)
    except BaseException:
        # Also covers KeyboardInterrupt from stopping the cell: close the run as failed.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()


# Entry point. Inside a Jupyter kernel __name__ is "__main__", so executing
# this cell starts training immediately (as the captured output below shows).
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Using MPS (Metal) device


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Loaded 678 samples for train split
Loaded 85 samples for val split





Epoch 1/5


Training:   2%|▏         | 6/339 [00:55<50:24,  9.08s/it]