# Install and Import Required Libraries


In [None]:
# python 3.10

%pip install torch torchvision torchaudio torchinfo nibabel numpy tqdm wandb monai

In [None]:
import os
import random

import wandb
import torch

import numpy as np
import pandas as pd
import nibabel as nib
import torch.nn as nn
import torch.optim as optim
import torchvision.models.video as models

from tqdm import tqdm
from nilearn import plotting
from torchinfo import torchinfo
from torch.utils.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    RandRotate90,
    RandFlip,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandAdjustContrast,
    RandScaleIntensity,
    NormalizeIntensity,
)

from IPython.display import display

## Constants


In [None]:
# Dataset root. Expected layout: <DATASET>/<split>/{AD,CN}/*.nii.gz with
# split in {"train", "val", "test"}; volumes must match MRIDataset's
# target_size (128x128x128 by default).
DATASET = "./DATA/ADNI_CROPPED_128"

# Visualise Scans


In [None]:
def _first_nifti_files(directory, num_samples):
    """Return up to ``num_samples`` ``.nii.gz`` paths from ``directory``, sorted by name."""
    names = sorted(f for f in os.listdir(directory) if f.endswith(".nii.gz"))
    return [os.path.join(directory, name) for name in names[:num_samples]]


def visualize_scans(dataset_path, split="train", num_samples=3):
    """
    Visualize the first few MRI scans from AD and CN directories.

    File names are sorted so "first" is deterministic; ``os.listdir`` returns
    entries in an arbitrary, filesystem-dependent order.

    Args:
        dataset_path (str): Path to the dataset directory.
        split (str): Dataset split to visualize ('train', 'val', 'test').
        num_samples (int): Number of samples to visualize from each class.

    Raises:
        FileNotFoundError: If ``<dataset_path>/<split>/AD`` or ``.../CN`` is missing.
    """
    # List both classes up front so a missing directory fails before any plotting.
    class_files = {
        class_name: _first_nifti_files(
            os.path.join(dataset_path, split, class_name), num_samples
        )
        for class_name in ("AD", "CN")
    }

    for class_name, files in class_files.items():
        print(f"{class_name} Scans:")
        for file in files:
            plotting.plot_anat(file, title=os.path.basename(file))
        plotting.show()


# Example usage
# visualize_scans(DATASET, split="train", num_samples=3)

# Train


## Select Compute Device (Apple Metal / MPS, falling back to CPU)


In [None]:
# Use the Apple Metal (MPS) backend when available, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) device")
else:
    device = torch.device("cpu")
    print("MPS not available, using CPU")


def empty_cache():
    """Release cached MPS allocator memory to reduce fragmentation between epochs.

    No-op unless ``device`` is MPS. ``torch.mps.empty_cache`` is missing in older
    PyTorch releases, so that case is reported rather than raised.
    """
    if device.type != "mps":
        return
    try:
        torch.mps.empty_cache()
    except (AttributeError, RuntimeError) as err:
        print(f"MPS cache management not available: {err}")

## Datasets


In [None]:
class MRIDataset(Dataset):
    """Binary AD-vs-CN classification dataset of NIfTI (``.nii.gz``) MRI volumes.

    Expected layout::

        <root_dir>/<split>/AD/*.nii.gz   -> label 1
        <root_dir>/<split>/CN/*.nii.gz   -> label 0

    Each item is ``(image, label)``: ``image`` is a float32 tensor of shape
    ``(1, D, H, W)`` and ``label`` is an int64 scalar tensor.

    Args:
        root_dir (str): Dataset root containing one folder per split.
        split (str): One of ``"train"``, ``"val"``, ``"test"``.
        apply_augmentation (bool): Apply random MONAI augmentations before normalization.
        target_size (tuple[int, int, int]): Required volume shape; mismatches raise ``ValueError``.

    Raises:
        ValueError: If ``split`` is invalid or no ``.nii.gz`` files are found.
        FileNotFoundError: If the AD or CN folder for ``split`` is missing.
    """

    def __init__(
        self,
        root_dir,
        split="train",
        apply_augmentation=False,
        target_size=(128, 128, 128),
    ):
        self.root_dir = root_dir
        self.split = split
        self.samples = []
        self.labels = []
        self.apply_augmentation = apply_augmentation
        self.target_size = target_size

        # Validate inputs
        if split not in ["train", "val", "test"]:
            raise ValueError(
                f"Split must be one of 'train', 'val', 'test', got {split}"
            )

        ad_dir = os.path.join(root_dir, split, "AD")
        cn_dir = os.path.join(root_dir, split, "CN")

        if not os.path.exists(ad_dir):
            raise FileNotFoundError(f"AD directory not found at {ad_dir}")
        if not os.path.exists(cn_dir):
            raise FileNotFoundError(f"CN directory not found at {cn_dir}")

        self._load_samples(ad_dir, cn_dir)
        self._setup_transforms()

        print(f"Loaded {len(self.samples)} samples for {split} split")
        print(f"Augmentation applied: {apply_augmentation}")

    def _load_samples(self, ad_dir, cn_dir):
        """Collect ``.nii.gz`` paths from both class folders with their labels.

        AD files (label 1) come first, then CN files (label 0). Names are sorted
        so the sample order, and therefore any seeded shuffle over it, is
        reproducible; ``os.listdir`` order is filesystem-dependent.

        Raises:
            ValueError: If neither folder contains any ``.nii.gz`` file.
        """
        for class_dir, label in ((ad_dir, 1), (cn_dir, 0)):
            for file in sorted(os.listdir(class_dir)):
                if file.endswith(".nii.gz"):
                    self.samples.append(os.path.join(class_dir, file))
                    self.labels.append(label)

        if not self.samples:
            raise ValueError(f"No .nii.gz files found in {ad_dir} or {cn_dir}")

    def _setup_transforms(self):
        """Build the MONAI pipeline: optional random augmentation, then intensity normalization."""
        if self.apply_augmentation:
            self.transforms = Compose(
                [
                    RandRotate90(prob=0.5, spatial_axes=(1, 2)),
                    RandFlip(prob=0.5, spatial_axis=0),
                    RandGaussianNoise(prob=0.2, mean=0.0, std=0.1),
                    RandGaussianSmooth(prob=0.2, sigma_x=(0.5, 1.5)),
                    RandAdjustContrast(prob=0.3, gamma=(0.7, 1.3)),
                    RandScaleIntensity(prob=0.3, factors=0.2),
                    NormalizeIntensity(nonzero=True),
                ]
            )
        else:
            self.transforms = Compose([NormalizeIntensity(nonzero=True)])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, validate, transform and return ``(image, label)`` for sample ``idx``."""
        img_path = self.samples[idx]
        label = self.labels[idx]

        try:
            # float32 halves memory versus nibabel's float64 default and matches
            # the model's parameter dtype.
            img_data = nib.load(img_path).get_fdata(dtype=np.float32)

            # Compare whole shapes so volumes with a different rank (e.g. a
            # trailing singleton axis) hit this descriptive error instead of a
            # tuple-unpacking error.
            expected_shape = tuple(self.target_size)
            if img_data.shape != expected_shape:
                expected = "x".join(str(s) for s in expected_shape)
                actual = "x".join(str(s) for s in img_data.shape)
                raise ValueError(
                    f"Expected image size {expected} "
                    f"but got {actual} for {img_path}"
                )

            # MONAI array transforms expect a leading channel axis: (1, D, H, W).
            img_data = np.expand_dims(img_data, axis=0)
            img_data = self.transforms(img_data)

            # Normalize whatever the transforms returned (ndarray or MONAI
            # MetaTensor) into a plain, contiguous float32 torch.Tensor; the
            # model does not use MetaTensor metadata.
            if isinstance(img_data, torch.Tensor):
                img_tensor = img_data.as_subclass(torch.Tensor).to(torch.float32)
            else:
                img_tensor = torch.from_numpy(
                    np.ascontiguousarray(img_data, dtype=np.float32)
                )

            return img_tensor, torch.tensor(label, dtype=torch.long)

        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            raise

## Model


In [None]:
# Modified 3D ResNet model with layer freezing
class MRIModel(nn.Module):
    """Kinetics-400 pretrained 3D ResNet-18 adapted to single-channel volumes.

    Args:
        num_classes (int): Size of the output logit vector.
        freeze_layers (bool): If True, only ``layer4`` and ``fc`` remain trainable.

    Input: ``(B, 1, D, H, W)``; output: ``(B, num_classes)`` logits.
    """

    def __init__(self, num_classes=2, freeze_layers=True):
        super().__init__()
        # Using a video ResNet and modifying it for 3D MRI
        self.resnet = models.r3d_18(weights=models.R3D_18_Weights.KINETICS400_V1)

        # Replace the 3-channel stem conv with a 1-channel one. Initialise it from
        # the pretrained RGB kernels summed over the channel axis: a grey volume
        # then produces the same activations it would if replicated to RGB, and
        # the stem keeps its learned filters. A randomly initialised stem would
        # otherwise be frozen below and act as a fixed random projection.
        pretrained_conv = self.resnet.stem[0]
        mono_conv = nn.Conv3d(
            1,
            pretrained_conv.out_channels,
            kernel_size=pretrained_conv.kernel_size,
            stride=pretrained_conv.stride,
            padding=pretrained_conv.padding,
            bias=False,
        )
        with torch.no_grad():
            mono_conv.weight.copy_(pretrained_conv.weight.sum(dim=1, keepdim=True))
        self.resnet.stem[0] = mono_conv

        # Replace the final fully connected layer for the target classes
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

        if freeze_layers:
            self._freeze_layers()

    def _freeze_layers(self):
        """Freeze the stem and layer1-layer3, leaving only ``layer4`` and ``fc`` trainable."""
        # Match on the top-level child name rather than a substring so no other
        # parameter whose name merely contains "fc" stays trainable by accident.
        # NOTE: requires_grad=False stops gradient updates only; frozen BatchNorm
        # layers still update their running statistics while in train() mode.
        trainable_prefixes = ("layer4.", "fc.")
        for name, param in self.resnet.named_parameters():
            if not name.startswith(trainable_prefixes):
                param.requires_grad = False

    def count_trainable_params(self):
        """Count and return trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_params(self):
        """Count and return total parameters"""
        return sum(p.numel() for p in self.parameters())

    def forward(self, x):
        # Input: (B, 1, D, H, W)
        return self.resnet(x)

## Model Summary


In [None]:
def display_model_summary(model, input_size=(1, 1, 128, 128, 128), detailed=True):
    """
    Display a comprehensive summary of the model architecture and parameters.

    Works for any ``nn.Module``: parameter totals are computed directly rather
    than through model-specific helper methods.

    Args:
        model: The PyTorch model to analyze
        input_size: The input tensor size (batch_size, channels, depth, height, width)
        detailed: Whether to show per-leaf-layer information

    Returns:
        The torchinfo summary object, for further use.
    """
    summary = torchinfo.summary(
        model,
        input_size=input_size,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        verbose=0,
    )

    print("MODEL ARCHITECTURE SUMMARY:")
    print("=" * 80)
    print(summary)

    # Parameters grouped by layer type. recurse=False attributes each parameter
    # to the module that directly owns it, so containers are not double-counted.
    layer_counts = {}
    for _, module in model.named_modules():
        stats = layer_counts.setdefault(
            module.__class__.__name__,
            {"count": 0, "params": 0, "trainable_params": 0},
        )
        own_params = list(module.parameters(recurse=False))
        stats["count"] += 1
        stats["params"] += sum(p.numel() for p in own_params)
        stats["trainable_params"] += sum(
            p.numel() for p in own_params if p.requires_grad
        )

    type_df = pd.DataFrame.from_dict(layer_counts, orient="index")
    if not type_df.empty:
        print("\nPARAMETERS BY LAYER TYPE:")
        print("=" * 80)
        display(type_df.sort_values("params", ascending=False))

    if detailed:
        layers_info = []
        for name, module in model.named_modules():
            if next(module.children(), None) is not None:
                continue  # only leaf modules
            params = sum(p.numel() for p in module.parameters())
            trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
            layers_info.append(
                {
                    "Layer": name,
                    "Type": module.__class__.__name__,
                    "Parameters": params,
                    "Trainable": trainable,
                    "Frozen": params - trainable,
                }
            )

        df = pd.DataFrame(layers_info)
        if not df.empty:
            print("\nDETAILED LAYER INFORMATION:")
            print("=" * 80)
            display(df)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    # Guard against parameter-free models to avoid ZeroDivisionError.
    trainable_pct = 100 * trainable_params / total_params if total_params else 0.0
    frozen_pct = 100 * frozen_params / total_params if total_params else 0.0

    print("\nPARAMETER STATISTICS:")
    print("=" * 80)
    print(f"Total parameters:    {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,} ({trainable_pct:.2f}%)")
    print(f"Frozen parameters:    {frozen_params:,} ({frozen_pct:.2f}%)")

    print("\nMODEL ARCHITECTURE DETAILS:")
    print("=" * 80)
    print(model)

    return summary


# Example usage:
# model = MRIModel(num_classes=2, freeze_layers=True)
# display_model_summary(model)

## Train


In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """
    Train ``model`` for one epoch and log epoch-level metrics to W&B.

    Args:
        model: PyTorch model to train
        dataloader: Training data loader yielding ``(inputs, labels)``
        criterion: Loss function returning a per-batch mean loss
        optimizer: Optimizer
        device: Computation device (CPU/GPU/MPS)
        epoch: Zero-based epoch number (displayed as ``epoch + 1``)

    Returns:
        tuple: ``(epoch_loss, epoch_accuracy)``. ``epoch_loss`` is the
        batch-size-weighted mean of the per-batch losses; ``epoch_accuracy``
        is a percentage.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()

    # Accumulate on-device so there is no host sync (.item()) every batch.
    running_loss_tensor = torch.zeros((), device=device)
    correct_tensor = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    for inputs, labels in tqdm(dataloader, desc=f"Training Epoch {epoch+1}"):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        batch_size = labels.size(0)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # detach(): adding the raw loss would chain every batch's autograd graph
        # onto the accumulator. The criterion returns a batch mean, so weight it
        # by batch size; dividing by ``total`` below then gives a per-sample mean
        # on the same scale as the validation loss.
        running_loss_tensor += loss.detach() * batch_size
        correct_tensor += (outputs.argmax(dim=1) == labels).sum()
        total += batch_size

    if total == 0:
        raise ValueError("Training dataloader yielded no samples")

    epoch_loss = running_loss_tensor.item() / total
    epoch_acc = 100 * correct_tensor.item() / total

    wandb.log(
        {
            "train_loss": epoch_loss,
            "train_acc": epoch_acc,
            "epoch": epoch,
        }
    )

    return epoch_loss, epoch_acc

## Validate


In [None]:
def calculate_classification_metrics(all_labels, all_preds, all_probs):
    """Compute per-class accuracy, precision, recall, F1 and confusion counts.

    Class 1 (AD) is treated as the positive class. A tiny epsilon in every
    denominator makes empty classes or no positive predictions yield 0 rather
    than dividing by zero.

    Args:
        all_labels: Sequence of true labels (0 or 1).
        all_preds: Sequence of predicted labels (0 or 1).
        all_probs: Sequence of positive-class probabilities (converted but unused).

    Returns:
        dict: ``class_0_acc``/``class_1_acc`` (percentages), ``precision``,
        ``recall``, ``f1_score`` and a ``confusion_data`` sub-dict with
        TP/FP/TN/FN counts.
    """
    eps = 1e-10
    labels = np.array(all_labels)
    preds = np.array(all_preds)
    _ = np.array(all_probs)  # converted for input validation parity; not used

    def class_accuracy(cls):
        in_class = labels == cls
        hits = np.sum(preds[in_class] == labels[in_class])
        return 100 * hits / (np.sum(in_class) + eps)

    is_pred_pos = preds == 1
    is_pred_neg = preds == 0
    is_true_pos = labels == 1
    is_true_neg = labels == 0

    tp = np.sum(is_pred_pos & is_true_pos)
    fp = np.sum(is_pred_pos & is_true_neg)
    tn = np.sum(is_pred_neg & is_true_neg)
    fn = np.sum(is_pred_neg & is_true_pos)

    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)

    return {
        "class_0_acc": class_accuracy(0),
        "class_1_acc": class_accuracy(1),
        "precision": prec,
        "recall": rec,
        "f1_score": 2 * prec * rec / (prec + rec + eps),
        "confusion_data": {
            "true_positives": tp,
            "false_positives": fp,
            "true_negatives": tn,
            "false_negatives": fn,
        },
    }


def create_wandb_visualizations(all_labels, all_preds, all_probs):
    """Build W&B chart objects for a set of binary (CN=0, AD=1) predictions.

    Returns:
        dict: Always contains ``"confusion_matrix"``; also contains
        ``"roc_curve"`` unless building the ROC chart raised, in which case a
        warning is printed instead.
    """
    charts = {
        "confusion_matrix": wandb.plot.confusion_matrix(
            preds=all_preds, y_true=all_labels, class_names=["CN", "AD"]
        )
    }

    try:
        # Two-column probability matrix: column 0 = CN, column 1 = AD.
        ad_probs = np.array(all_probs)
        y_probas = np.zeros((len(all_labels), 2))
        y_probas[:, 1] = ad_probs
        y_probas[:, 0] = 1 - ad_probs

        # Only the positive (AD) class curve is plotted.
        charts["roc_curve"] = wandb.plot.roc_curve(
            all_labels,
            y_probas,
            classes_to_plot=[1],
            labels=["CN", "AD"],
        )
    except Exception as e:
        print(f"Warning: ROC curve calculation failed: {e}")

    return charts


def validate(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` and log metrics and charts to W&B.

    Metrics are always logged under ``val_*`` keys, even when the loader is a
    test split.

    Args:
        model: Model to evaluate (switched to eval mode).
        dataloader: Loader yielding ``(inputs, labels)``.
        criterion: Loss function returning a per-batch mean loss.
        device: Device to run inference on.
        epoch: Zero-based epoch number (displayed as ``epoch + 1``).

    Returns:
        tuple: ``(loss, accuracy)``. The loss is the batch-size-weighted mean
        of the per-batch losses, so a short final batch is not over-weighted;
        the accuracy is a percentage.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"Validation Epoch {epoch+1}"):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            batch_size = labels.size(0)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * batch_size

            probs = torch.softmax(outputs, dim=1)
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # P(AD)

    if total == 0:
        raise ValueError("Validation dataloader yielded no samples")

    val_loss = running_loss / total
    val_acc = 100 * correct / total

    metrics = calculate_classification_metrics(all_labels, all_preds, all_probs)
    viz = create_wandb_visualizations(all_labels, all_preds, all_probs)

    log_dict = {
        "val_loss": val_loss,
        "val_acc": val_acc,
        "val_CN_acc": metrics["class_0_acc"],
        "val_AD_acc": metrics["class_1_acc"],
        "val_precision": metrics["precision"],
        "val_recall": metrics["recall"],
        "val_f1": metrics["f1_score"],
        "confusion_matrix": viz["confusion_matrix"],
        "epoch": epoch,
    }
    if "roc_curve" in viz:
        log_dict["roc_curve"] = viz["roc_curve"]

    wandb.log(log_dict)

    return val_loss, val_acc

## Main


In [None]:
def _make_loader(dataset, batch_size, shuffle):
    """Build a single-process DataLoader; memory pinning only helps CUDA transfers."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=device.type == "cuda",
    )


def _save_training_state(path, model, optimizer, scheduler, epoch, **extra):
    """Save model/optimizer/scheduler state dicts, ``epoch`` and ``extra`` fields to ``path``."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            **extra,
        },
        path,
    )


def _log_model_artifact(name, path):
    """Upload the file at ``path`` to W&B as a model artifact called ``name``."""
    artifact = wandb.Artifact(name, type="model")
    artifact.add_file(path)
    wandb.log_artifact(artifact)


def main():
    """Train, checkpoint and evaluate the AD-vs-CN classifier, logging to W&B.

    - Resumes from ``checkpoints/checkpoint.pth`` when it exists.
    - Keeps the best models by validation accuracy (``best_model_acc.pth``)
      and by validation loss (``best_model_loss.pth``).
    - Stops early after ``patience`` consecutive epochs in which neither
      metric improves.
    - Evaluates the best-accuracy model on the test split.
    """
    # Hyperparameters: single source of truth for training and the W&B config.
    data_root = DATASET  # Update this to your dataset path
    batch_size = 2  # Small batch size for memory constraints
    num_epochs = 20
    learning_rate = 0.0001
    freeze_layers = True
    use_augmentation = True
    patience = 5
    checkpoint_path = "checkpoints/checkpoint.pth"

    wandb.init(
        project="mri-alzheimers-classification",
        config={
            "architecture": "3D-ResNet18-FrozenLayers",
            "dataset": "MRI-AD-CN",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "optimizer": "AdamW",
            "device": str(device),
            "input_dimensions": "128x128x128",
            "freeze_layers": freeze_layers,
            "data_augmentation": use_augmentation,
            "patience": patience,
        },
    )

    # Augment only the training split.
    train_dataset = MRIDataset(
        data_root, split="train", apply_augmentation=use_augmentation
    )
    val_dataset = MRIDataset(data_root, split="val", apply_augmentation=False)

    num_ad = train_dataset.labels.count(1)
    num_cn = train_dataset.labels.count(0)

    wandb.config.update(
        {
            "train_samples": len(train_dataset),
            "val_samples": len(val_dataset),
            "train_AD_samples": num_ad,
            "train_CN_samples": num_cn,
            "val_AD_samples": val_dataset.labels.count(1),
            "val_CN_samples": val_dataset.labels.count(0),
        }
    )

    train_loader = _make_loader(train_dataset, batch_size, shuffle=True)
    val_loader = _make_loader(val_dataset, batch_size, shuffle=False)

    model = MRIModel(num_classes=2, freeze_layers=freeze_layers).to(device)

    trainable_params = model.count_trainable_params()
    total_params = model.count_total_params()
    frozen_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(
        f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})"
    )
    print(f"Frozen parameters: {frozen_params:,} ({frozen_params/total_params:.2%})")

    wandb.config.update(
        {
            "total_params": total_params,
            "trainable_params": trainable_params,
            "frozen_params": frozen_params,
            "frozen_percentage": frozen_params / total_params,
        }
    )

    wandb.watch(model, log="all", log_freq=10)

    # Inverse-frequency class weights to counter class imbalance.
    total = num_ad + num_cn
    weight_cn = total / (2 * num_cn) if num_cn > 0 else 1.0
    weight_ad = total / (2 * num_ad) if num_ad > 0 else 1.0
    # float32 explicitly: MPS does not support float64 tensors.
    class_weights = torch.tensor(
        [weight_cn, weight_ad], dtype=torch.float32, device=device
    )
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # The freshly initialised classifier head gets a 10x higher learning rate
    # than the unfrozen pretrained layers.
    fc_params = list(model.resnet.fc.parameters())
    fc_param_ids = {id(p) for p in fc_params}
    other_params = [
        p for p in model.parameters() if p.requires_grad and id(p) not in fc_param_ids
    ]
    param_groups = [
        {"params": other_params, "lr": learning_rate},
        {"params": fc_params, "lr": learning_rate * 10},
    ]
    optimizer = optim.AdamW(param_groups, lr=learning_rate, weight_decay=0.01)

    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=learning_rate / 100
    )

    # Resume from the last per-epoch checkpoint if one exists.
    start_epoch = 0
    best_val_acc = 0.0
    best_val_loss = float("inf")
    early_stopping_counter = 0
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on one device (e.g. MPS) load on another.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        best_val_acc = checkpoint.get("best_val_acc", 0.0)
        best_val_loss = checkpoint.get("best_val_loss", float("inf"))
        early_stopping_counter = checkpoint.get("early_stopping_counter", 0)
        print(f"Loaded checkpoint from epoch {start_epoch}")

    last_epoch = start_epoch - 1
    for epoch in range(start_epoch, num_epochs):
        # Checked at the top so a resumed run that had already exhausted its
        # patience does not train further.
        if early_stopping_counter >= patience:
            print(
                f"Early stopping: no improvement in validation accuracy or loss "
                f"for {early_stopping_counter} consecutive epochs."
            )
            break

        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f"Current learning rate: {current_lr:.6f}")

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        improved = False

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _save_training_state(
                "best_model_acc.pth", model, optimizer, scheduler, epoch,
                val_acc=val_acc, val_loss=val_loss,
            )
            print("Model saved (best accuracy)!")
            _log_model_artifact("best_model_acc", "best_model_acc.pth")
            improved = True

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_training_state(
                "best_model_loss.pth", model, optimizer, scheduler, epoch,
                val_acc=val_acc, val_loss=val_loss,
            )
            print("Model saved (best loss)!")
            _log_model_artifact("best_model_loss", "best_model_loss.pth")
            improved = True

        # Patience counts consecutive epochs where neither metric improved.
        early_stopping_counter = 0 if improved else early_stopping_counter + 1

        # Persist resumable state every epoch, including the early-stopping
        # counter, so an interrupted run resumes with the same patience budget.
        _save_training_state(
            checkpoint_path, model, optimizer, scheduler, epoch,
            best_val_acc=best_val_acc,
            best_val_loss=best_val_loss,
            early_stopping_counter=early_stopping_counter,
        )
        last_epoch = epoch

        if device.type == "mps":
            empty_cache()

    test_dataset = MRIDataset(data_root, split="test", apply_augmentation=False)
    test_loader = _make_loader(test_dataset, batch_size, shuffle=False)

    best_model_path = "best_model_acc.pth"
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        print(
            f"Loaded best model from epoch {checkpoint['epoch']+1} with accuracy {checkpoint['val_acc']:.2f}%"
        )
    else:
        print(f"{best_model_path} not found; evaluating the current model weights.")

    # NOTE: validate() logs under val_* keys, so this test-split pass is also
    # recorded under those keys (at epoch=num_epochs).
    final_test_loss, final_test_acc = validate(
        model, test_loader, criterion, device, num_epochs
    )
    print(f"Final test accuracy: {final_test_acc:.2f}%")

    wandb.run.summary["best_val_acc"] = best_val_acc
    wandb.run.summary["best_val_loss"] = best_val_loss
    wandb.run.summary["final_test_acc"] = final_test_acc
    wandb.run.summary["final_test_loss"] = final_test_loss
    # Legacy key names (they hold test-split results) kept for existing dashboards.
    wandb.run.summary["final_val_acc"] = final_test_acc
    wandb.run.summary["final_val_loss"] = final_test_loss
    wandb.run.summary["total_epochs"] = last_epoch + 1

    wandb.finish()


if __name__ == "__main__":
    # Seed every RNG the pipeline draws from so runs are reproducible.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # MONAI random transforms keep their own RandomState instead of using
    # np.random's global state. set_determinism sets MONAI's global seed, which
    # transforms constructed afterwards (the datasets built inside main())
    # pick up, making augmentation repeatable.
    from monai.utils import set_determinism

    set_determinism(seed=seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.makedirs("checkpoints", exist_ok=True)

    main()