# Install and Import Required Libraries


In [None]:
# python 3.10

%pip install torch torchvision torchaudio torchinfo nibabel numpy tqdm wandb monai

In [2]:
import os
import random
import time

import wandb
import torch

import numpy as np
import pandas as pd
import nibabel as nib
import torch.nn as nn
import torch.optim as optim
import torchvision.models.video as models

from tqdm import tqdm
from nilearn import plotting
from torchinfo import torchinfo
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    RandRotate90,
    RandFlip,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandAdjustContrast,
    RandScaleIntensity,
    NormalizeIntensity,
)

from IPython.display import display

## Constants


In [3]:
# Root directory of the dataset. The code below reads `.nii.gz` files from
# `<DATASET>/<split>/AD` and `<DATASET>/<split>/CN`.
DATASET = "./DATA/ADNI_CROPPED_128"

# Visualise Scans


In [4]:
def visualize_scans(dataset_path, split="train", num_samples=3):
    """
    Visualize the first few MRI scans from the AD and CN directories.

    Files are taken in sorted filename order so the same scans are shown on
    every run; ``os.listdir`` alone returns entries in arbitrary order.

    Args:
        dataset_path (str): Path to the dataset directory.
        split (str): Dataset split to visualize ('train', 'val', 'test').
        num_samples (int): Number of samples to visualize from each class.
    """
    # Resolve both class folders up front so a missing directory fails before
    # anything is plotted.
    scans_by_class = {}
    for class_name in ("AD", "CN"):
        class_dir = os.path.join(dataset_path, split, class_name)
        scan_names = sorted(f for f in os.listdir(class_dir) if f.endswith(".nii.gz"))
        scans_by_class[class_name] = [
            os.path.join(class_dir, f) for f in scan_names[:num_samples]
        ]

    for class_name, scan_paths in scans_by_class.items():
        print(f"{class_name} Scans:")
        for scan_path in scan_paths:
            plotting.plot_anat(scan_path, title=os.path.basename(scan_path))
        plotting.show()


# Example usage
# visualize_scans(DATASET, split="train", num_samples=3)

# Train


## Check Metal


In [5]:
# Select the compute device: Apple's Metal (MPS) backend when available, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) device")
else:
    device = torch.device("cpu")
    print("MPS not available, using CPU")

if device.type == "mps":

    def empty_cache():
        """Ask PyTorch to release cached MPS memory (best effort).

        Intended to be called periodically during training. Older PyTorch
        builds may not expose ``torch.mps.empty_cache``; in that case a notice
        is printed and training continues.
        """
        try:
            torch.mps.empty_cache()
        except (AttributeError, RuntimeError):
            # Only swallow "not supported" style failures; anything else
            # (e.g. KeyboardInterrupt) must propagate.
            print("MPS cache management not available")

Using MPS (Metal) device


## Datasets


In [6]:
class MRIDataset(Dataset):
    """Binary AD-vs-CN dataset of 128x128x128 NIfTI volumes.

    Files are discovered under ``<root_dir>/<split>/AD`` (label 1) and
    ``<root_dir>/<split>/CN`` (label 0). Only ``.nii.gz`` files are used and
    they are indexed in sorted filename order, so sample indices are stable
    across runs and platforms.

    Args:
        root_dir (str): Dataset root directory.
        split (str): Sub-directory name of the split ('train', 'val', 'test').
        apply_augmentation (bool): If True, random MONAI augmentations are
            applied before intensity normalization; otherwise only the
            normalization is applied.
    """

    # Spatial size every volume must have (checked in __getitem__).
    EXPECTED_SHAPE = (128, 128, 128)

    def __init__(self, root_dir, split="train", apply_augmentation=False):
        self.root_dir = root_dir
        self.split = split
        self.samples = []
        self.labels = []
        self.apply_augmentation = apply_augmentation

        # AD samples (label 1) come first, then CN samples (label 0).
        # os.listdir order is arbitrary, so sort to keep indexing reproducible.
        for class_name, class_label in (("AD", 1), ("CN", 0)):
            class_dir = os.path.join(root_dir, split, class_name)
            for file in sorted(os.listdir(class_dir)):
                if file.endswith(".nii.gz"):
                    self.samples.append(os.path.join(class_dir, file))
                    self.labels.append(class_label)

        # MONAI transform pipeline. Intensity normalization is always the last
        # step, with or without augmentation.
        if apply_augmentation:
            self.transforms = Compose(
                [
                    RandRotate90(prob=0.5, spatial_axes=(1, 2)),
                    RandFlip(prob=0.5, spatial_axis=0),
                    RandGaussianNoise(prob=0.2, mean=0.0, std=0.1),
                    RandGaussianSmooth(prob=0.2, sigma_x=(0.5, 1.5)),
                    RandAdjustContrast(prob=0.3, gamma=(0.7, 1.3)),
                    RandScaleIntensity(prob=0.3, factors=0.2),
                    NormalizeIntensity(nonzero=True),
                ]
            )
        else:
            self.transforms = Compose([NormalizeIntensity(nonzero=True)])

        print(f"Loaded {len(self.samples)} samples for {split} split")
        print(f"Augmentation applied: {apply_augmentation}")

    def __len__(self):
        """Return the number of volumes in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, validate and transform one volume.

        Returns:
            tuple: ``(image, label)`` where ``image`` has a leading channel
            axis (1, 128, 128, 128) and ``label`` is a ``torch.long`` scalar
            tensor (1 = AD, 0 = CN).

        Raises:
            ValueError: If the stored volume is not exactly 128x128x128.
        """
        img_path = self.samples[idx]
        label = self.labels[idx]

        # Load image using nibabel
        img_data = nib.load(img_path).get_fdata()

        # Compare the whole shape tuple so volumes with the wrong number of
        # dimensions get the same descriptive error as wrongly sized ones.
        if tuple(img_data.shape) != self.EXPECTED_SHAPE:
            actual = "x".join(str(dim) for dim in img_data.shape)
            raise ValueError(f"Expected image size 128x128x128 but got {actual}")

        # Add channel dimension: (D, H, W) -> (1, D, H, W)
        img_data = np.expand_dims(img_data, axis=0)

        # MONAI transforms accept numpy arrays and may return tensors.
        img_data = self.transforms(img_data)

        # Convert to a float32 tensor if the transforms returned an array.
        if not isinstance(img_data, torch.Tensor):
            img_data = torch.tensor(img_data, dtype=torch.float32)

        label = torch.tensor(label, dtype=torch.long)

        return img_data, label

## Model


In [None]:
# Modified 3D ResNet model with layer freezing
class MRIModel(nn.Module):
    """torchvision ``r3d_18`` (Kinetics-400 weights) adapted to 1-channel volumes.

    Args:
        num_classes (int): Number of output logits of the replaced ``fc`` head.
        freeze_layers (bool): If True, every parameter outside ``layer4`` and
            ``fc`` has ``requires_grad`` set to False.
    """

    def __init__(self, num_classes=2, freeze_layers=True):
        super().__init__()
        # Using a video ResNet and modifying it for 3D MRI
        self.resnet = models.r3d_18(weights=models.R3D_18_Weights.KINETICS400_V1)

        # Replace the first layer to accept single-channel input instead of 3.
        pretrained_conv = self.resnet.stem[0]
        single_channel_conv = nn.Conv3d(
            1,
            64,
            kernel_size=(3, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=False,
        )
        # Initialise the new kernels from the pretrained RGB kernels instead of
        # leaving them random. The stem is frozen below when freeze_layers is
        # True, so a random stem would feed fixed noise features into the
        # pretrained layers. Summing over the input-channel axis is equivalent
        # to presenting the same grayscale volume on all three channels.
        with torch.no_grad():
            single_channel_conv.weight.copy_(
                pretrained_conv.weight.sum(dim=1, keepdim=True)
            )
        self.resnet.stem[0] = single_channel_conv

        # Replace the final fully connected layer for binary classification
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

        # Freeze specific layers if requested
        if freeze_layers:
            self._freeze_layers()

    def _freeze_layers(self):
        """Freeze most layers of the ResNet model, leaving only layer4 and fc unfrozen"""
        # Freeze stem and layers 1-3
        # TODO look at model in more detail and see where to freeze
        # NOTE(review): requires_grad=False does not stop BatchNorm layers from
        # updating their running statistics in train mode; confirm whether the
        # frozen blocks should also be kept in eval mode.
        for name, param in self.resnet.named_parameters():
            if "layer4" not in name and "fc" not in name:
                param.requires_grad = False

    def count_trainable_params(self):
        """Count and return trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_params(self):
        """Count and return total parameters"""
        return sum(p.numel() for p in self.parameters())

    def forward(self, x):
        """Run the wrapped ResNet on ``x`` of shape (B, 1, D, H, W) and return its logits."""
        return self.resnet(x)

## Model Summary


In [8]:
def display_model_summary(model, input_size=(1, 1, 128, 128, 128), detailed=True):
    """
    Display a comprehensive summary of the model architecture and parameters.

    Works with any ``nn.Module``; parameter totals are computed directly from
    ``model.parameters()`` rather than relying on model-specific helpers.

    Args:
        model: The PyTorch model to analyze
        input_size: The input tensor size (batch_size, channels, depth, height, width)
        detailed: Whether to show detailed layer information

    Returns:
        The object returned by ``torchinfo.summary`` for further use.
    """
    # Get basic model summary using torchinfo
    summary = torchinfo.summary(
        model,
        input_size=input_size,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        verbose=0,
    )

    print("MODEL ARCHITECTURE SUMMARY:")
    print("=" * 80)
    print(summary)

    # Per-layer table: one row for every leaf module (a module with no children)
    if detailed:
        layers_info = []
        for name, module in model.named_modules():
            if any(True for _ in module.children()):
                continue
            params = sum(p.numel() for p in module.parameters())
            trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
            layers_info.append(
                {
                    "Layer": name,
                    "Type": module.__class__.__name__,
                    "Parameters": params,
                    "Trainable": trainable,
                    "Frozen": params - trainable,
                }
            )

        df = pd.DataFrame(layers_info)
        if not df.empty:
            print("\nDETAILED LAYER INFORMATION:")
            print("=" * 80)
            display(df)

    # Show frozen vs trainable stats
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    # Guard the percentages so a parameterless model does not divide by zero.
    trainable_pct = trainable_params / total_params * 100 if total_params else 0.0
    frozen_pct = frozen_params / total_params * 100 if total_params else 0.0

    print("\nPARAMETER STATISTICS:")
    print("=" * 80)
    print(f"Total parameters:    {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,} ({trainable_pct:.2f}%)")
    print(f"Frozen parameters:    {frozen_params:,} ({frozen_pct:.2f}%)")

    # Display model architecture as text
    print("\nMODEL ARCHITECTURE DETAILS:")
    print("=" * 80)
    print(model)

    # Return summary for potential further use
    return summary


# Example usage:
# model = MRIModel(num_classes=2, freeze_layers=True)
# display_model_summary(model)

## Train


In [10]:
def _print_batch_timing(batch_idx, stamps, prev_batch_end):
    """Print the wall-clock breakdown of a single training batch.

    Args:
        batch_idx: Index of the batch within the epoch (shown in the header).
        stamps: Mapping from stage name to the timestamp taken when that stage
            finished, for this batch only.
        prev_batch_end: Timestamp at which the previous batch ended, or None
            when there is no previous batch.
    """
    # Standard timings
    data_loading_time = stamps["data_loaded"] - stamps["batch_start"]
    transfer_time = stamps["transfer_complete"] - stamps["data_loaded"]
    zero_grad_time = stamps["optimizer_zeroed"] - stamps["transfer_complete"]
    forward_time = stamps["forward_complete"] - stamps["optimizer_zeroed"]
    loss_time = stamps["loss_complete"] - stamps["forward_complete"]
    backward_time = stamps["backward_complete"] - stamps["loss_complete"]
    optim_time = stamps["optimizer_complete"] - stamps["backward_complete"]

    # Detailed metrics timing breakdown
    running_loss_update_time = (
        stamps["running_loss_updated"] - stamps["optimizer_complete"]
    )
    accuracy_calc_time = (
        stamps["accuracy_calc_complete"] - stamps["accuracy_calc_start"]
    )
    store_tensors_time = (
        stamps["store_tensors_complete"] - stamps["store_tensors_start"]
    )
    metrics_time = stamps["metrics_complete"] - stamps["optimizer_complete"]
    end_time = stamps["batch_end"] - stamps["metrics_complete"]
    total_time = stamps["batch_end"] - stamps["batch_start"]

    if prev_batch_end is None:
        dataloader_overhead = 0
    else:
        dataloader_overhead = stamps["data_loaded"] - prev_batch_end

    print(f"\nDetailed timing for Batch {batch_idx} (seconds):")
    print(f"  Data loading:       {data_loading_time:.4f}")
    print(f"  Transfer to device: {transfer_time:.4f}")
    print(f"  Zero gradients:     {zero_grad_time:.4f}")
    print(f"  Forward pass:       {forward_time:.4f}")
    print(f"  Loss calculation:   {loss_time:.4f}")
    print(f"  Backward pass:      {backward_time:.4f}")
    print(f"  Optimizer step:     {optim_time:.4f}")

    print("\n  METRICS BREAKDOWN:")
    print(f"    Running loss update:    {running_loss_update_time:.4f}")
    print(f"    Accuracy calculation:   {accuracy_calc_time:.4f}")
    print(f"    Store tensors:          {store_tensors_time:.4f}")
    print(f"    Total metrics time:     {metrics_time:.4f}")

    print(f"\n  End batch overhead: {end_time:.4f}")
    print(f"  Dataloader overhead:{dataloader_overhead:.4f}")
    print(f"  Total batch time:   {total_time:.4f}")

    # Sum all measured components
    measured_time = (
        data_loading_time
        + transfer_time
        + zero_grad_time
        + forward_time
        + loss_time
        + backward_time
        + optim_time
        + metrics_time
        + end_time
    )

    print(f"  Sum of measured ops: {measured_time:.4f}")
    missing = total_time - measured_time
    # A batch can measure as 0 s on coarse clocks; avoid dividing by zero.
    missing_pct = missing / total_time * 100 if total_time > 0 else 0.0
    print(f"  Missing time:       {missing:.4f} ({missing_pct:.1f}%)")


def train_one_epoch(
    model, dataloader, criterion, optimizer, device, epoch, verbose_timing=True
):
    """Train ``model`` for one pass over ``dataloader`` and log metrics to wandb.

    Loss and correct-prediction counts are accumulated as tensors on ``device``
    and converted to Python numbers once, after the loop.

    Args:
        model: Module to train (put into train mode here).
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch index, logged alongside the metrics.
        verbose_timing: If True (default), print a per-stage timing breakdown
            for every batch after the first.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` - mean of the per-batch losses and
        accuracy in percent.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    running_loss_tensor = torch.tensor(0.0, device=device)
    correct_tensor = torch.tensor(0, device=device)
    total = 0
    num_batches = 0

    # Predictions/labels stay on the device until the end of the epoch
    batch_predictions = []
    batch_labels = []

    # Host-side wall-clock timestamps.
    # NOTE(review): on asynchronous backends these presumably measure when work
    # was queued rather than when it finished - confirm before relying on them.
    now = time.perf_counter
    next_batch_start = now()
    prev_batch_end = None

    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloader, desc="Training")):

        # Report the input size once per epoch
        if batch_idx == 0:
            print(
                f"Input shape: {inputs.shape}, Memory: {inputs.element_size() * inputs.nelement() / 1024 / 1024:.2f} MB"
            )
            print(
                f"Labels shape: {labels.shape}, Memory: {labels.element_size() * labels.nelement() / 1024 / 1024:.2f} MB"
            )

        # Timestamps for this batch only (no per-epoch lists to grow)
        stamps = {"batch_start": next_batch_start, "data_loaded": now()}

        # Transfer to device
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        stamps["transfer_complete"] = now()

        # Zero gradients
        optimizer.zero_grad(set_to_none=True)
        stamps["optimizer_zeroed"] = now()

        # Forward pass
        outputs = model(inputs)
        stamps["forward_complete"] = now()

        # Loss calculation
        loss = criterion(outputs, labels)
        stamps["loss_complete"] = now()

        # Backward pass
        loss.backward()
        stamps["backward_complete"] = now()

        # Optimizer step
        optimizer.step()
        stamps["optimizer_complete"] = now()

        # Accumulate the detached loss: adding the graph-attached tensor would
        # link every batch's autograd graph into the running total.
        running_loss_tensor += loss.detach()
        stamps["running_loss_updated"] = now()

        # Accuracy, kept as a tensor to avoid a per-batch .item() sync
        stamps["accuracy_calc_start"] = now()
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct_tensor += (predicted == labels).sum()
        stamps["accuracy_calc_complete"] = now()

        # Keep predictions and labels for the end-of-epoch class metrics
        stamps["store_tensors_start"] = now()
        batch_predictions.append(predicted.detach())
        batch_labels.append(labels.detach())
        stamps["store_tensors_complete"] = now()
        stamps["metrics_complete"] = now()

        stamps["batch_end"] = now()
        num_batches += 1

        # The wait for the next batch starts here, before any printing below
        next_batch_start = now()

        if verbose_timing and batch_idx > 0:
            _print_batch_timing(batch_idx, stamps, prev_batch_end)

        prev_batch_end = stamps["batch_end"]

    if num_batches == 0:
        raise ValueError("train_one_epoch received a dataloader with no batches")

    # Process all predictions and labels at the end of the epoch
    print("\nProcessing stored tensors at end of epoch...")
    start_time = now()

    all_predictions_tensor = torch.cat(batch_predictions)
    all_labels_tensor = torch.cat(batch_labels)

    # Class-wise counts, computed on the device first
    class_0_mask = all_labels_tensor == 0
    class_1_mask = all_labels_tensor == 1

    class_0_correct = ((all_predictions_tensor == 0) & class_0_mask).sum()
    class_1_correct = ((all_predictions_tensor == 1) & class_1_mask).sum()
    class_0_total = class_0_mask.sum()
    class_1_total = class_1_mask.sum()

    # Class-wise accuracy; max(..., 1) avoids dividing by zero for an absent class
    class_0_acc = 100 * class_0_correct.item() / max(class_0_total.item(), 1)
    class_1_acc = 100 * class_1_correct.item() / max(class_1_total.item(), 1)

    end_time = now()
    print(f"End-of-epoch tensor processing time: {end_time - start_time:.4f} seconds")

    # Convert accumulated tensors to scalars ONCE at the end
    epoch_loss = running_loss_tensor.item() / num_batches
    epoch_acc = 100 * correct_tensor.item() / total

    # Log epoch-level metrics
    wandb.log(
        {
            "train_loss": epoch_loss,
            "train_acc": epoch_acc,
            "train_CN_acc": class_0_acc,
            "train_AD_acc": class_1_acc,
            "epoch": epoch,
        }
    )

    return epoch_loss, epoch_acc

## Validate


In [11]:
def validate(model, dataloader, criterion, device, epoch, prefix="val"):
    """Evaluate ``model`` on ``dataloader`` and log the metrics to wandb.

    Label 1 (AD) is treated as the positive class for precision, recall, F1
    and the ROC curve.

    Args:
        model: Module to evaluate (put into eval mode here).
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to.
        epoch: Epoch index, logged alongside the metrics.
        prefix: Prefix of the logged metric names. The default ``"val"`` keeps
            the original keys (``val_loss``, ``confusion_matrix``, ...); any
            other value, e.g. ``"test"``, prefixes every metric and plot key
            so that other splits do not overwrite the validation charts.

    Returns:
        tuple: ``(loss, accuracy)`` - mean of the per-batch losses and accuracy
        in percent.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0

    # Per-batch arrays, concatenated once after the loop
    label_chunks = []
    pred_chunks = []
    prob_chunks = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            num_batches += 1

            probs = torch.nn.functional.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(predicted.cpu().numpy())
            prob_chunks.append(probs[:, 1].cpu().numpy())  # Probability of AD class

    if num_batches == 0:
        raise ValueError("validate received a dataloader with no batches")

    all_labels = np.concatenate(label_chunks)
    all_preds = np.concatenate(pred_chunks)
    all_probs = np.concatenate(prob_chunks)

    total = len(all_labels)
    correct = int(np.sum(all_preds == all_labels))

    val_loss = running_loss / num_batches
    val_acc = 100 * correct / total

    # Class-wise accuracy (the epsilon keeps an absent class from dividing by zero)
    class_0_mask = all_labels == 0
    class_1_mask = all_labels == 1

    class_0_acc = (
        100
        * np.sum(all_preds[class_0_mask] == all_labels[class_0_mask])
        / (np.sum(class_0_mask) + 1e-10)
    )
    class_1_acc = (
        100
        * np.sum(all_preds[class_1_mask] == all_labels[class_1_mask])
        / (np.sum(class_1_mask) + 1e-10)
    )

    # Confusion counts with AD (label 1) as the positive class
    true_positives = np.sum((all_preds == 1) & (all_labels == 1))
    false_positives = np.sum((all_preds == 1) & (all_labels == 0))
    false_negatives = np.sum((all_preds == 0) & (all_labels == 1))

    precision = true_positives / (true_positives + false_positives + 1e-10)
    recall = true_positives / (true_positives + false_negatives + 1e-10)
    f1_score = 2 * precision * recall / (precision + recall + 1e-10)

    # Generate confusion matrix for visualization
    confusion_matrix = wandb.plot.confusion_matrix(
        preds=all_preds, y_true=all_labels, class_names=["CN", "AD"]
    )

    try:
        # Two-column probability matrix: column 0 = CN, column 1 = AD
        y_probas = np.zeros((len(all_labels), 2))
        y_probas[:, 0] = 1 - all_probs
        y_probas[:, 1] = all_probs

        roc_curve = wandb.plot.roc_curve(
            all_labels,
            y_probas,
            classes_to_plot=[1],  # Plot ROC for AD class (positive class)
            labels=["CN", "AD"],
        )
    except Exception as e:
        # The ROC plot is optional; report the failure and log everything else.
        print(f"Warning: ROC curve calculation failed: {e}")
        roc_curve = None

    # Plot keys stay unprefixed for the default "val" prefix (original names).
    plot_prefix = "" if prefix == "val" else f"{prefix}_"

    log_dict = {
        f"{prefix}_loss": val_loss,
        f"{prefix}_acc": val_acc,
        f"{prefix}_CN_acc": class_0_acc,
        f"{prefix}_AD_acc": class_1_acc,
        f"{prefix}_precision": precision,
        f"{prefix}_recall": recall,
        f"{prefix}_f1": f1_score,
        f"{plot_prefix}confusion_matrix": confusion_matrix,
        "epoch": epoch,
    }

    if roc_curve is not None:
        log_dict[f"{plot_prefix}roc_curve"] = roc_curve

    wandb.log(log_dict)

    return val_loss, val_acc

## Main


In [12]:
def _save_training_state(path, epoch, model, optimizer, scheduler, **extra):
    """Save model/optimizer/scheduler state plus any ``extra`` fields to ``path``."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            **extra,
        },
        path,
    )


def _run_training(
    data_root, batch_size, num_epochs, learning_rate, freeze_layers, use_augmentation
):
    """Build data/model/optimizer, train with early stopping, then evaluate on the test split.

    Expects an active wandb run; metrics, config updates and model artifacts
    are logged to it.
    """
    patience = 5
    checkpoint_path = "checkpoints/checkpoint.pth"
    best_acc_path = "best_model_acc.pth"
    best_loss_path = "best_model_loss.pth"

    # Create datasets; augmentation is used for the training split only
    train_dataset = MRIDataset(
        data_root, split="train", apply_augmentation=use_augmentation
    )
    val_dataset = MRIDataset(data_root, split="val", apply_augmentation=False)

    # Log dataset stats
    num_ad = train_dataset.labels.count(1)
    num_cn = train_dataset.labels.count(0)
    wandb.config.update(
        {
            "train_samples": len(train_dataset),
            "val_samples": len(val_dataset),
            "train_AD_samples": num_ad,
            "train_CN_samples": num_cn,
            "val_AD_samples": val_dataset.labels.count(1),
            "val_CN_samples": val_dataset.labels.count(0),
        }
    )

    def make_loader(dataset, shuffle):
        # Pinned host memory only helps CUDA transfers; requesting it on other
        # devices just produces a warning.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=(device.type == "cuda"),
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)

    # Initialize the model with layer freezing
    model = MRIModel(num_classes=2, freeze_layers=freeze_layers)
    model = model.to(device)

    # Log parameter statistics
    trainable_params = model.count_trainable_params()
    total_params = model.count_total_params()
    frozen_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(
        f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})"
    )
    print(f"Frozen parameters: {frozen_params:,} ({frozen_params/total_params:.2%})")

    wandb.config.update(
        {
            "total_params": total_params,
            "trainable_params": trainable_params,
            "frozen_params": frozen_params,
            "frozen_percentage": frozen_params / total_params,
        }
    )

    wandb.watch(model, log="all", log_freq=10)

    # Loss with inverse-frequency class weights to handle class imbalance
    total = num_ad + num_cn
    weight_cn = total / (2 * num_cn) if num_cn > 0 else 1.0
    weight_ad = total / (2 * num_ad) if num_ad > 0 else 1.0

    class_weights = torch.tensor([weight_cn, weight_ad], device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Two parameter groups: the new fc head trains at 10x the learning rate of
    # the remaining trainable (pre-trained) layers.
    fc_params = list(model.resnet.fc.parameters())
    fc_param_ids = {id(p) for p in fc_params}
    other_params = [
        p
        for p in model.parameters()
        if p.requires_grad and id(p) not in fc_param_ids
    ]

    param_groups = [
        {"params": other_params, "lr": learning_rate},
        {"params": fc_params, "lr": learning_rate * 10},
    ]

    optimizer = optim.AdamW(param_groups, lr=learning_rate, weight_decay=0.01)

    # Cosine annealing with warm restarts every 5 epochs
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=learning_rate / 100
    )

    # Resume from the rolling checkpoint if one exists
    start_epoch = 0
    best_val_acc = 0.0
    best_val_loss = float("inf")
    early_stopping_counter = 0

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on one device load on another
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        # .get() keeps older checkpoints without these fields loadable
        best_val_acc = checkpoint.get("best_val_acc", 0.0)
        best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        early_stopping_counter = checkpoint.get("early_stopping_counter", 0)
        print(f"Loaded checkpoint from epoch {start_epoch}")

    completed_epochs = 0

    # Training loop with early stopping
    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
        completed_epochs = epoch + 1

        # Update learning rate based on scheduler
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f"Current learning rate: {current_lr:.6f}")

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        improved = False

        # Save the best model by validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _save_training_state(
                best_acc_path,
                epoch,
                model,
                optimizer,
                scheduler,
                val_acc=val_acc,
                val_loss=val_loss,
            )
            print("Model saved (best accuracy)!")

            artifact = wandb.Artifact("best_model_acc", type="model")
            artifact.add_file(best_acc_path)
            wandb.log_artifact(artifact)
            improved = True

        # Save best model by validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_training_state(
                best_loss_path,
                epoch,
                model,
                optimizer,
                scheduler,
                val_acc=val_acc,
                val_loss=val_loss,
            )
            print("Model saved (best loss)!")

            artifact = wandb.Artifact("best_model_loss", type="model")
            artifact.add_file(best_loss_path)
            wandb.log_artifact(artifact)
            improved = True

        # An epoch counts towards patience only if NEITHER metric improved.
        # (Previously the counter was reset on accuracy but incremented on
        # loss, so an epoch with better accuracy could still be counted.)
        early_stopping_counter = 0 if improved else early_stopping_counter + 1

        if early_stopping_counter >= patience:
            print(
                f"Early stopping at epoch {epoch+1}: "
                f"no improvement for {patience} consecutive epochs."
            )
            break

        # Rolling checkpoint used to resume an interrupted run
        _save_training_state(
            checkpoint_path,
            epoch,
            model,
            optimizer,
            scheduler,
            best_val_acc=best_val_acc,
            best_val_loss=best_val_loss,
            early_stopping_counter=early_stopping_counter,
        )

        if device.type == "mps":
            empty_cache()

    # Create test dataset and dataloader (no augmentation)
    test_dataset = MRIDataset(data_root, split="test", apply_augmentation=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    # Load the best-accuracy weights for the final evaluation when available
    if os.path.exists(best_acc_path):
        checkpoint = torch.load(best_acc_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(
            f"Loaded best model from epoch {checkpoint['epoch']+1} with accuracy {checkpoint['val_acc']:.2f}%"
        )
    else:
        print(
            f"No {best_acc_path} found; evaluating the current model weights instead."
        )

    # Final evaluation on test set.
    # NOTE(review): validate() logs these under the validation metric names at
    # step `num_epochs`; confirm whether they should be logged separately.
    final_test_loss, final_test_acc = validate(
        model, test_loader, criterion, device, num_epochs
    )
    print(f"Final test accuracy: {final_test_acc:.2f}%")

    # Run summary. The final numbers come from the TEST split, so they are
    # stored under final_test_* (previously mislabelled final_val_*).
    wandb.run.summary["best_val_acc"] = best_val_acc
    wandb.run.summary["best_val_loss"] = best_val_loss
    wandb.run.summary["final_test_acc"] = final_test_acc
    wandb.run.summary["final_test_loss"] = final_test_loss
    wandb.run.summary["total_epochs"] = completed_epochs


def main():
    """Run one tracked training experiment end to end.

    Hyperparameters are defined once here and shared between the wandb config
    and the training code. The wandb run is always closed, and is marked as
    failed if training raises.
    """
    # Parameters
    data_root = DATASET  # Update this to your dataset path
    batch_size = 2  # Reduced batch size for memory constraints
    num_epochs = 20  # Reduced epochs for testing
    learning_rate = 0.0001
    freeze_layers = True
    use_augmentation = True

    # Initialize wandb
    wandb.init(
        project="mri-alzheimers-classification",
        config={
            "architecture": "3D-ResNet18-FrozenLayers",
            "dataset": "MRI-AD-CN",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "optimizer": "AdamW",
            "device": str(device),
            "input_dimensions": "128x128x128",
            "freeze_layers": freeze_layers,
            "data_augmentation": use_augmentation,
        },
    )

    exit_code = 1  # assume failure until training returns normally
    try:
        _run_training(
            data_root,
            batch_size,
            num_epochs,
            learning_rate,
            freeze_layers,
            use_augmentation,
        )
        exit_code = 0
    finally:
        # Close wandb run even when training raised
        wandb.finish(exit_code=exit_code)


if __name__ == "__main__":
    # Seed every RNG used here (Python, NumPy, PyTorch) with the same value so
    # that repeated runs start from the same random state.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Ask cuDNN for deterministic algorithms and disable its autotuner,
        # which may otherwise pick different kernels from run to run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the output directory if it is missing. exist_ok=True does the
    # existence check and creation in one call, avoiding the check-then-create
    # race of a separate os.path.exists() test.
    # NOTE(review): presumably main() writes its checkpoints here — confirm
    # against checkpoint_path.
    os.makedirs("checkpoints", exist_ok=True)

    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrhys-alexander[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loaded 678 samples for train split
Augmentation applied: True
Loaded 85 samples for val split
Augmentation applied: False
Total parameters: 33,148,482
Trainable parameters: 24,909,826 (75.15%)
Frozen parameters: 8,238,656 (24.85%)

Epoch 1/20


Training:   0%|          | 0/339 [00:00<?, ?it/s]

Input shape: torch.Size([2, 1, 128, 128, 128]), Memory: 16.00 MB
Labels shape: torch.Size([2]), Memory: 0.00 MB


Training:   1%|          | 2/339 [00:06<19:51,  3.54s/it]


Detailed timing for Batch 1 (seconds):
  Data loading:       0.3919
  Transfer to device: 0.0005
  Zero gradients:     0.0005
  Forward pass:       4.2107
  Loss calculation:   0.0069
  Backward pass:      0.5108
  Optimizer step:     0.1065

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0036
    Store tensors:          0.0002
    Total metrics time:     0.0040

  End batch overhead: 0.0000
  Dataloader overhead:0.3920
  Total batch time:   5.2318
  Sum of measured ops: 5.2318
  Missing time:       0.0000 (0.0%)


Training:   1%|          | 3/339 [00:12<25:55,  4.63s/it]


Detailed timing for Batch 2 (seconds):
  Data loading:       0.4148
  Transfer to device: 0.0005
  Zero gradients:     0.0008
  Forward pass:       4.8929
  Loss calculation:   0.0018
  Backward pass:      0.5149
  Optimizer step:     0.1067

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0008
    Store tensors:          0.0000
    Total metrics time:     0.0010

  End batch overhead: 0.0000
  Dataloader overhead:0.4151
  Total batch time:   5.9332
  Sum of measured ops: 5.9332
  Missing time:       0.0000 (0.0%)


Training:   1%|          | 4/339 [00:18<28:43,  5.15s/it]


Detailed timing for Batch 3 (seconds):
  Data loading:       0.2832
  Transfer to device: 0.0515
  Zero gradients:     0.0005
  Forward pass:       4.9720
  Loss calculation:   0.0009
  Backward pass:      0.5207
  Optimizer step:     0.1059

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0008
    Store tensors:          0.0000
    Total metrics time:     0.0009

  End batch overhead: 0.0000
  Dataloader overhead:0.2834
  Total batch time:   5.9357
  Sum of measured ops: 5.9357
  Missing time:       0.0000 (0.0%)


Training:   1%|▏         | 5/339 [00:24<30:14,  5.43s/it]


Detailed timing for Batch 4 (seconds):
  Data loading:       0.2901
  Transfer to device: 0.0449
  Zero gradients:     0.0006
  Forward pass:       4.9768
  Loss calculation:   0.0074
  Backward pass:      0.5073
  Optimizer step:     0.1068

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0036
    Store tensors:          0.0002
    Total metrics time:     0.0039

  End batch overhead: 0.0000
  Dataloader overhead:0.2903
  Total batch time:   5.9379
  Sum of measured ops: 5.9379
  Missing time:       0.0000 (0.0%)


Training:   2%|▏         | 6/339 [00:30<31:11,  5.62s/it]


Detailed timing for Batch 5 (seconds):
  Data loading:       0.5852
  Transfer to device: 0.0004
  Zero gradients:     0.0007
  Forward pass:       4.7181
  Loss calculation:   0.0008
  Backward pass:      0.5326
  Optimizer step:     0.1462

  METRICS BREAKDOWN:
    Running loss update:    0.0007
    Accuracy calculation:   0.0045
    Store tensors:          0.0002
    Total metrics time:     0.0053

  End batch overhead: 0.0000
  Dataloader overhead:0.5854
  Total batch time:   5.9893
  Sum of measured ops: 5.9893
  Missing time:       0.0000 (0.0%)


Training:   2%|▏         | 7/339 [00:36<31:44,  5.74s/it]


Detailed timing for Batch 6 (seconds):
  Data loading:       0.4449
  Transfer to device: 0.0004
  Zero gradients:     0.0007
  Forward pass:       4.8948
  Loss calculation:   0.0058
  Backward pass:      0.5094
  Optimizer step:     0.1116

  METRICS BREAKDOWN:
    Running loss update:    0.0003
    Accuracy calculation:   0.0077
    Store tensors:          0.0000
    Total metrics time:     0.0081

  End batch overhead: 0.0000
  Dataloader overhead:0.4451
  Total batch time:   5.9757
  Sum of measured ops: 5.9757
  Missing time:       0.0000 (0.0%)


Training:   2%|▏         | 8/339 [00:42<31:58,  5.80s/it]


Detailed timing for Batch 7 (seconds):
  Data loading:       0.2959
  Transfer to device: 0.0296
  Zero gradients:     0.0007
  Forward pass:       4.9576
  Loss calculation:   0.0010
  Backward pass:      0.5253
  Optimizer step:     0.1057

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0007
    Store tensors:          0.0000
    Total metrics time:     0.0009

  End batch overhead: 0.0000
  Dataloader overhead:0.2961
  Total batch time:   5.9167
  Sum of measured ops: 5.9167
  Missing time:       0.0000 (0.0%)


Training:   3%|▎         | 9/339 [00:47<32:02,  5.83s/it]


Detailed timing for Batch 8 (seconds):
  Data loading:       0.2632
  Transfer to device: 0.0719
  Zero gradients:     0.0004
  Forward pass:       4.9409
  Loss calculation:   0.0007
  Backward pass:      0.5178
  Optimizer step:     0.1054

  METRICS BREAKDOWN:
    Running loss update:    0.0001
    Accuracy calculation:   0.0005
    Store tensors:          0.0000
    Total metrics time:     0.0006

  End batch overhead: 0.0000
  Dataloader overhead:0.2634
  Total batch time:   5.9009
  Sum of measured ops: 5.9009
  Missing time:       0.0000 (0.0%)


Training:   3%|▎         | 10/339 [01:00<42:44,  7.79s/it]


Detailed timing for Batch 9 (seconds):
  Data loading:       0.4001
  Transfer to device: 0.0004
  Zero gradients:     0.0005
  Forward pass:       11.0342
  Loss calculation:   0.0025
  Backward pass:      0.7094
  Optimizer step:     0.0506

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0008
    Store tensors:          0.0000
    Total metrics time:     0.0010

  End batch overhead: 0.0000
  Dataloader overhead:0.4001
  Total batch time:   12.1986
  Sum of measured ops: 12.1986
  Missing time:       0.0000 (0.0%)


Training:   3%|▎         | 11/339 [01:00<30:10,  5.52s/it]


Detailed timing for Batch 10 (seconds):
  Data loading:       0.3020
  Transfer to device: 0.0003
  Zero gradients:     0.0002
  Forward pass:       0.0201
  Loss calculation:   0.0003
  Backward pass:      0.0029
  Optimizer step:     0.0400

  METRICS BREAKDOWN:
    Running loss update:    0.0001
    Accuracy calculation:   0.0006
    Store tensors:          0.0000
    Total metrics time:     0.0007

  End batch overhead: 0.0000
  Dataloader overhead:0.3021
  Total batch time:   0.3665
  Sum of measured ops: 0.3665
  Missing time:       0.0000 (0.0%)


Training:   4%|▎         | 12/339 [01:06<30:40,  5.63s/it]


Detailed timing for Batch 11 (seconds):
  Data loading:       0.3212
  Transfer to device: 0.0004
  Zero gradients:     0.0002
  Forward pass:       4.9324
  Loss calculation:   0.0004
  Backward pass:      0.5190
  Optimizer step:     0.1060

  METRICS BREAKDOWN:
    Running loss update:    0.0001
    Accuracy calculation:   0.0004
    Store tensors:          0.0000
    Total metrics time:     0.0005

  End batch overhead: 0.0000
  Dataloader overhead:0.3212
  Total batch time:   5.8800
  Sum of measured ops: 5.8800
  Missing time:       0.0000 (0.0%)


Training:   4%|▍         | 13/339 [01:12<31:07,  5.73s/it]


Detailed timing for Batch 12 (seconds):
  Data loading:       0.3893
  Transfer to device: 0.0006
  Zero gradients:     0.0010
  Forward pass:       4.9379
  Loss calculation:   0.0008
  Backward pass:      0.5194
  Optimizer step:     0.1081

  METRICS BREAKDOWN:
    Running loss update:    0.0002
    Accuracy calculation:   0.0006
    Store tensors:          0.0000
    Total metrics time:     0.0008

  End batch overhead: 0.0000
  Dataloader overhead:0.3893
  Total batch time:   5.9578
  Sum of measured ops: 5.9578
  Missing time:       0.0000 (0.0%)


Training:   4%|▍         | 14/339 [01:18<31:26,  5.80s/it]


Detailed timing for Batch 13 (seconds):
  Data loading:       0.3741
  Transfer to device: 0.0005
  Zero gradients:     0.0006
  Forward pass:       4.9463
  Loss calculation:   0.0061
  Backward pass:      0.5224
  Optimizer step:     0.1139

  METRICS BREAKDOWN:
    Running loss update:    0.0006
    Accuracy calculation:   0.0030
    Store tensors:          0.0000
    Total metrics time:     0.0036

  End batch overhead: 0.0000
  Dataloader overhead:0.3742
  Total batch time:   5.9675
  Sum of measured ops: 5.9675
  Missing time:       0.0000 (0.0%)


Training:   4%|▍         | 15/339 [01:24<31:34,  5.85s/it]


Detailed timing for Batch 14 (seconds):
  Data loading:       0.5548
  Transfer to device: 0.0006
  Zero gradients:     0.0011
  Forward pass:       4.7627
  Loss calculation:   0.0011
  Backward pass:      0.5238
  Optimizer step:     0.1094

  METRICS BREAKDOWN:
    Running loss update:    0.0001
    Accuracy calculation:   0.0005
    Store tensors:          0.0000
    Total metrics time:     0.0007

  End batch overhead: 0.0000
  Dataloader overhead:0.5550
  Total batch time:   5.9543
  Sum of measured ops: 5.9543
  Missing time:       0.0000 (0.0%)


Training:   4%|▍         | 15/339 [01:28<31:48,  5.89s/it]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x320483cd0>> (for post_run_cell), with arguments args (<ExecutionResult object at 1073ed840, execution_count=12 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 1073ed420, raw_cell="def main():
    # Initialize wandb
    wandb.init(.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/rhysalexander/Desktop/FYP/code/train.ipynb#X32sZmlsZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe