# Install and Import Required Libraries


In [5]:
# Install the notebook's dependencies into the active kernel's environment.
# NOTE(review): versions are unpinned and torchaudio is not used by any visible cell;
# consider pinning versions (and dropping torchaudio) for reproducible installs.
%pip install torch torchvision torchaudio nibabel numpy tqdm

# Dataset root read by main(): MRIDataset expects <DATASET>/<split>/AD/*.nii.gz and
# <DATASET>/<split>/CN/*.nii.gz for the "train" and "val" splits.
DATASET = "./DATA/ADNI_SPLIT"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch

# Sanity check for the Metal (MPS) backend that the training cell tries to use.
# Prints True when MPS is usable on this machine, False otherwise.
mps_available = torch.backends.mps.is_available()
print(mps_available)

True


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models.video as models
import nibabel as nib
from tqdm import tqdm

# Select the compute device: Apple's Metal (MPS) backend when available, otherwise the CPU.
_mps_available = torch.backends.mps.is_available()
device = torch.device("mps" if _mps_available else "cpu")
print("Using MPS (Metal) device" if _mps_available else "MPS not available, using CPU")


# Dataset class for loading .nii.gz files
class MRIDataset(Dataset):
    """Binary AD-vs-CN dataset of NIfTI volumes laid out as ``<root_dir>/<split>/{AD,CN}/*.nii.gz``.

    Each item is ``(volume, label)``:
    - ``volume`` is a float32 tensor of shape ``(1, target_size, target_size, target_size)``.
    - ``label`` is 1 for AD and 0 for CN.

    Args:
        root_dir: Dataset root containing one sub-directory per split.
        split: Name of the split sub-directory (e.g. ``"train"`` or ``"val"``).
        target_size: Edge length of the cubic volume returned by ``__getitem__``.
            Axes longer than this are center-cropped; shorter axes are zero-padded
            symmetrically. Defaults to 128.

    Raises:
        FileNotFoundError: If ``<root_dir>/<split>/AD`` or ``<root_dir>/<split>/CN``
            does not exist (raised by ``os.listdir``).
    """

    # (sub-directory, label) pairs; AD is the positive class and is listed first.
    _CLASSES = (("AD", 1), ("CN", 0))

    def __init__(self, root_dir, split="train", target_size=128):
        self.root_dir = root_dir
        self.split = split
        self.target_size = target_size
        self.samples = []
        self.labels = []

        for class_name, label in self._CLASSES:
            class_dir = os.path.join(root_dir, split, class_name)
            # os.listdir returns entries in arbitrary order; sort so the sample order
            # (and therefore seeded shuffling) is reproducible across machines.
            for file in sorted(os.listdir(class_dir)):
                if file.endswith(".nii.gz"):
                    self.samples.append(os.path.join(class_dir, file))
                    self.labels.append(label)

        print(f"Loaded {len(self.samples)} samples for {split} split")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, scale and crop/pad volume ``idx``; return ``(tensor, label)``.

        Raises:
            ValueError: If the file does not contain a 3D volume (after dropping
                trailing singleton axes).
        """
        img_path = self.samples[idx]
        label = self.labels[idx]

        img_data = nib.load(img_path).get_fdata()

        # Some NIfTI files store a 3D volume with trailing singleton axes, e.g.
        # (X, Y, Z, 1); drop them so the volume can be treated as 3D.
        while img_data.ndim > 3 and img_data.shape[-1] == 1:
            img_data = img_data[..., 0]
        if img_data.ndim != 3:
            raise ValueError(
                f"Expected a 3D volume in {img_path}, got shape {img_data.shape}"
            )

        # Divide by the maximum when it exceeds 1. This yields [0, 1] only for
        # non-negative intensities; negative values stay below 0.
        max_value = img_data.max()
        if max_value > 1.0:
            img_data = img_data / max_value

        img_data = self._center_crop_or_pad(img_data, self.target_size)

        # Add the leading channel dimension expected by Conv3d: (1, D, H, W).
        img_tensor = torch.tensor(img_data, dtype=torch.float32).unsqueeze(0)
        return img_tensor, label

    @staticmethod
    def _center_crop_or_pad(volume, size):
        """Return a ``(size, size, size)`` copy of a 3D array, centered on each axis.

        Axes of length >= ``size`` keep their central ``size`` elements (for 128 this
        is ``[n//2 - 64, n//2 + 64)``). Shorter axes are placed in the middle of the
        output and zero-padded on both sides.
        """
        out = np.zeros((size, size, size), dtype=volume.dtype)
        src_slices = []
        dst_slices = []
        for n in volume.shape:
            if n >= size:
                start = n // 2 - size // 2
                src_slices.append(slice(start, start + size))
                dst_slices.append(slice(0, size))
            else:
                offset = (size - n) // 2
                src_slices.append(slice(0, n))
                dst_slices.append(slice(offset, offset + n))
        out[tuple(dst_slices)] = volume[tuple(src_slices)]
        return out


# Modified 3D ResNet model
class MRIModel(nn.Module):
    """torchvision R3D-18 video ResNet adapted to single-channel 3D volumes.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        pretrained: If True (default), start from the Kinetics-400 weights
            (``R3D_18_Weights.KINETICS400_V1``). If False, use random initialization,
            e.g. for offline use.

    The forward pass takes a batch shaped ``(B, 1, D, H, W)`` and returns
    ``(B, num_classes)`` logits.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights = models.R3D_18_Weights.KINETICS400_V1 if pretrained else None
        self.resnet = models.r3d_18(weights=weights)

        # Replace the 3-channel (RGB) stem convolution with a 1-channel one that keeps
        # the original kernel size, stride and padding.
        # NOTE(review): the stem stride leaves the depth axis at full resolution, which
        # is likely a large share of the compute for 128^3 inputs; a (2, 2, 2) stride
        # may suit isotropic MRI better. Confirm before changing.
        rgb_conv = self.resnet.stem[0]
        gray_conv = nn.Conv3d(
            1,
            rgb_conv.out_channels,
            kernel_size=rgb_conv.kernel_size,
            stride=rgb_conv.stride,
            padding=rgb_conv.padding,
            bias=False,
        )
        if pretrained:
            # Sum the pretrained RGB filters over the input-channel axis. The new stem
            # then responds to a 1-channel volume exactly as the old stem responded to
            # the same volume replicated into 3 channels, instead of starting from
            # random weights.
            with torch.no_grad():
                gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        self.resnet.stem[0] = gray_conv

        # Replace the classification head with one sized for num_classes.
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits for ``x`` shaped ``(B, 1, D, H, W)``."""
        return self.resnet(x)


# Training function
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: Classifier that returns logits of shape ``(B, num_classes)``.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function; assumed to return the batch *mean*, as
            ``nn.CrossEntropyLoss()`` does by default.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-sample mean loss and the accuracy in percent.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(dataloader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight the batch-mean loss by batch size, so a smaller final batch does not
        # skew the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: dataloader yielded no samples")

    epoch_loss = loss_sum / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


# Validation function
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Classifier that returns logits of shape ``(B, num_classes)``. It is
            switched to eval mode and left there.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function; assumed to return the batch *mean*, as
            ``nn.CrossEntropyLoss()`` does by default.
        device: Device that each batch is moved to.

    Returns:
        ``(val_loss, val_acc)``: the per-sample mean loss and the accuracy in percent.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size, so a smaller final batch does
            # not skew the average.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate: dataloader yielded no samples")

    val_loss = loss_sum / total
    val_acc = 100 * correct / total
    return val_loss, val_acc


def main(
    data_root=None,
    batch_size=2,
    num_epochs=5,
    learning_rate=0.0001,
    seed=42,
    checkpoint_path="best_model.pth",
):
    """Train MRIModel on the AD/CN splits and save the best checkpoint.

    Every argument defaults to the value the notebook originally hard-coded.

    Args:
        data_root: Dataset root; defaults to the notebook-level ``DATASET``.
        batch_size: Samples per batch (kept small for memory constraints).
        num_epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        seed: Seed for randomly initialized weights and training-set shuffling.
        checkpoint_path: File that receives the best model's ``state_dict``.

    Returns:
        The best validation accuracy in percent, or ``-inf`` if ``num_epochs`` is 0.
    """
    if data_root is None:
        data_root = DATASET

    # Seed the global RNG (randomly initialized layers) and a dedicated shuffle
    # generator, so reruns see the same initialization and batch order.
    torch.manual_seed(seed)
    shuffle_generator = torch.Generator().manual_seed(seed)

    train_dataset = MRIDataset(data_root, split="train")
    val_dataset = MRIDataset(data_root, split="val")

    # NOTE(review): num_workers=0 is kept from the original. Worker processes in a
    # notebook may fail to import classes defined in the notebook; confirm before raising it.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        generator=shuffle_generator,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    model = MRIModel(num_classes=2).to(device)

    # torch.compile is intentionally not used (the original notes it works poorly with MPS).
    print("Using standard model without compilation for MPS compatibility")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if validation accuracy is 0%.
    best_val_acc = float("-inf")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Keep only the checkpoint with the highest validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print("Model saved!")

    return best_val_acc


# Entry point. Inside a Jupyter kernel the notebook's __name__ is "__main__", so
# executing this cell starts training immediately.
if __name__ == "__main__":
    main()

Using MPS (Metal) device
Loaded 680 samples for train split
Loaded 85 samples for val split
Using standard model without compilation for MPS compatibility

Epoch 1/5


Training:   1%|          | 3/340 [01:48<3:22:31, 36.06s/it]


KeyboardInterrupt: 