# Install and Import Required Libraries


In [None]:
# Notebook setup cell; the notebook targets Python 3.10.

# Install the third-party packages used below into the active kernel.
# NOTE(review): scipy is imported later in this notebook but is not listed here —
# confirm it is already present in the environment.
%pip install torch torchvision torchaudio nibabel numpy tqdm wandb monai

# Dataset root passed to MRIDataset, which reads <DATASET>/<split>/{AD,CN}/*.nii.gz.
DATASET = "./DATA/ADNI_96_TEST"

In [None]:
import torch

# Sanity check: report whether PyTorch's MPS (Metal) backend can be used.
print(torch.backends.mps.is_available())  # Prints True when MPS is usable

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models.video as models
import nibabel as nib
from tqdm import tqdm
import wandb
import random
from scipy.ndimage import rotate, shift
from monai.transforms import (
    Compose,
    RandRotate90,
    RandFlip,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandAdjustContrast,
    RandScaleIntensity,
)

# Prefer Apple's Metal (MPS) backend when it is available; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) device")
else:
    device = torch.device("cpu")
    print("MPS not available, using CPU")


def empty_cache():
    """Release cached MPS memory; do nothing when that is not possible.

    Intended to be called periodically during training to limit memory
    fragmentation on the MPS device.  The function is always defined, so it
    can be called on any device: it returns immediately unless the selected
    ``device`` is MPS, and it is also a no-op on PyTorch builds that do not
    provide ``torch.mps.empty_cache``.
    """
    if device.type != "mps":
        return
    # Older PyTorch versions have no torch.mps module / empty_cache function;
    # look it up defensively instead of swallowing every exception.
    release = getattr(getattr(torch, "mps", None), "empty_cache", None)
    if release is not None:
        release()


class MRIDataset(Dataset):
    """Binary AD-vs-CN dataset of NIfTI volumes stored on disk.

    Files are discovered under ``<root_dir>/<split>/AD`` (label 1) and
    ``<root_dir>/<split>/CN`` (label 0); only names ending in ``.nii.gz`` are
    used.  Within each class the files are taken in sorted name order so the
    sample order does not depend on the file system.

    Args:
        root_dir: Dataset root containing one sub-directory per split.
        split: Name of the split sub-directory (e.g. ``"train"`` or ``"val"``).
        apply_augmentation: When true, random MONAI augmentations are applied
            each time a volume is fetched.  No intensity normalization is
            performed in either mode.

    Raises:
        OSError: Propagated from ``os.listdir`` when a class directory cannot
            be listed (for example, it does not exist).
    """

    # Spatial size every volume must have.
    EXPECTED_SHAPE = (96, 96, 96)

    def __init__(self, root_dir, split="train", apply_augmentation=False):
        self.root_dir = root_dir
        self.split = split
        self.apply_augmentation = apply_augmentation
        self.samples = []
        self.labels = []

        # AD samples first (label 1), then CN samples (label 0).
        for class_name, class_label in (("AD", 1), ("CN", 0)):
            class_paths = self._list_scans(os.path.join(root_dir, split, class_name))
            self.samples.extend(class_paths)
            self.labels.extend([class_label] * len(class_paths))

        # Setup augmentation transforms using MONAI - WITHOUT normalization
        if apply_augmentation:
            self.transforms = Compose(
                [
                    RandRotate90(prob=0.5, spatial_axes=(1, 2)),
                    RandFlip(prob=0.5, spatial_axis=0),
                    RandGaussianNoise(prob=0.2, mean=0.0, std=0.1),
                    RandGaussianSmooth(prob=0.2, sigma_x=(0.5, 1.5)),
                    RandAdjustContrast(prob=0.3, gamma=(0.7, 1.3)),
                    RandScaleIntensity(prob=0.3, factors=0.2),
                ]
            )
        else:
            # Empty transform composition - no normalization
            self.transforms = Compose([])

        print(f"Loaded {len(self.samples)} samples for {split} split")
        print(f"Augmentation applied: {apply_augmentation}")

    @staticmethod
    def _list_scans(directory):
        """Return the sorted full paths of the ``.nii.gz`` files in ``directory``."""
        # os.listdir order is arbitrary; sorting keeps the sample order (and
        # therefore seeded shuffling) reproducible across runs and machines.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith(".nii.gz")
        ]

    def __len__(self):
        """Return the number of discovered volumes."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one volume and return ``(image_tensor, label)``.

        The image is read with nibabel, converted to a float32 tensor and
        given a leading channel dimension before any transforms are applied.
        ``label`` is the integer 1 for AD and 0 for CN.

        Raises:
            ValueError: If the stored volume is not exactly 96x96x96.
        """
        img_path = self.samples[idx]
        label = self.labels[idx]

        # Load image using nibabel
        img_data = nib.load(img_path).get_fdata()

        # Compare the whole shape so volumes that are not 3-D are reported
        # with the same explicit error instead of a tuple-unpacking failure.
        if img_data.shape != self.EXPECTED_SHAPE:
            found = "x".join(str(dim) for dim in img_data.shape)
            raise ValueError(f"Expected image size 96x96x96 but got {found}")

        # Convert to tensor and add the channel dimension.
        img_tensor = torch.tensor(img_data, dtype=torch.float32).unsqueeze(0)

        # Only call the pipeline when it actually contains transforms.
        if len(self.transforms.transforms) > 0:
            img_tensor = self.transforms(img_tensor)

        return img_tensor, label


# Modified 3D ResNet model with layer freezing
class MRIModel(nn.Module):
    """Kinetics-pretrained video ResNet-18 (``r3d_18``) adapted to 1-channel volumes.

    The first stem convolution is converted to accept a single input channel
    and the final fully connected layer is replaced by a ``num_classes``-way
    classifier.

    Args:
        num_classes: Number of output logits.
        freeze_layers: When true, the stem and ``layer1`` parameters are
            excluded from gradient updates.
    """

    def __init__(self, num_classes=2, freeze_layers=True):
        super().__init__()
        # Using a video ResNet and modifying it for 3D MRI
        self.resnet = models.r3d_18(weights=models.R3D_18_Weights.KINETICS400_V1)

        # Accept single-channel input instead of 3.  The new layer is built
        # from the pretrained kernels rather than random weights: the stem is
        # frozen below, so a random first layer would never be trained.
        self.resnet.stem[0] = self._to_single_channel(self.resnet.stem[0])

        # Replace the final fully connected layer for binary classification
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

        # Freeze specific layers if requested
        if freeze_layers:
            self._freeze_layers()

    @staticmethod
    def _to_single_channel(conv):
        """Return a 1-input-channel copy of ``conv`` that reuses its weights.

        The kernels are summed over the input-channel axis, which makes the
        new layer produce the same output as ``conv`` applied to the single
        channel replicated across all of ``conv``'s input channels.  Kernel
        size, stride, padding, dilation and bias usage are copied from
        ``conv``.
        """
        single = nn.Conv3d(
            1,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
        )
        with torch.no_grad():
            single.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
            if conv.bias is not None:
                single.bias.copy_(conv.bias)
        return single

    def _freeze_layers(self):
        """Freeze early layers of the ResNet model"""
        # Freeze stem and layer1 (early feature extractors)
        # NOTE(review): only requires_grad is changed here; any normalization
        # layers inside these modules presumably still update their running
        # statistics in train mode — confirm that is intended.
        for param in self.resnet.stem.parameters():
            param.requires_grad = False

        for param in self.resnet.layer1.parameters():
            param.requires_grad = False

        print("Frozen stem and layer1 of ResNet")

    def count_trainable_params(self):
        """Count and return trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_params(self):
        """Count and return total parameters"""
        return sum(p.numel() for p in self.parameters())

    def forward(self, x):
        """Return the class logits for a batch ``x`` of shape (B, 1, D, H, W)."""
        return self.resnet(x)


# Training function
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimization pass over ``dataloader`` and log metrics to wandb.

    Per-batch loss, accuracy, global batch index and the first parameter
    group's learning rate are logged after every step; the epoch-level loss,
    overall accuracy and per-class (CN = 0, AD = 1) accuracies are logged once
    at the end.

    Returns:
        ``(epoch_loss, epoch_acc)`` where the loss is the mean of the batch
        losses and the accuracy is a percentage over all seen samples.
    """
    model.train()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    seen_labels = []
    seen_preds = []

    progress = tqdm(dataloader, desc="Training")
    for step, (inputs, labels) in enumerate(progress):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Dropping gradients to None is cheaper than zero-filling them.
        optimizer.zero_grad(set_to_none=True)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss

        # Predicted class = index of the largest logit.
        predicted = torch.max(outputs.data, 1)[1]
        batch_size = labels.size(0)
        batch_correct = (predicted == labels).sum().item()
        num_seen += batch_size
        num_correct += batch_correct

        seen_labels.extend(labels.cpu().numpy())
        seen_preds.extend(predicted.cpu().numpy())

        batch_metrics = {
            "batch_loss": batch_loss,
            "batch_acc": 100 * batch_correct / batch_size,
            "batch": epoch * len(dataloader) + step,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        wandb.log(batch_metrics)

    epoch_loss = loss_sum / len(dataloader)
    epoch_acc = 100 * num_correct / num_seen

    all_labels = np.array(seen_labels)
    all_preds = np.array(seen_preds)

    def _class_accuracy(mask):
        # The small epsilon keeps the ratio defined when a class is absent.
        hits = np.sum(all_preds[mask] == all_labels[mask])
        return 100 * hits / (np.sum(mask) + 1e-10)

    cn_acc = _class_accuracy(all_labels == 0)
    ad_acc = _class_accuracy(all_labels == 1)

    epoch_metrics = {
        "train_loss": epoch_loss,
        "train_acc": epoch_acc,
        "train_CN_acc": cn_acc,
        "train_AD_acc": ad_acc,
        "epoch": epoch,
    }
    wandb.log(epoch_metrics)

    return epoch_loss, epoch_acc


# Validation function
def validate(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradients and log to wandb.

    Logs the mean batch loss, overall and per-class (CN = 0, AD = 1)
    accuracy, precision / recall / F1 for the AD class, a confusion matrix
    and, when it can be built, a ROC curve for the AD class.

    Returns:
        ``(val_loss, val_acc)`` with the accuracy expressed as a percentage.
    """
    model.eval()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    seen_labels, seen_preds, seen_probs = [], [], []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item()

            probs = torch.nn.functional.softmax(outputs, dim=1)
            predicted = torch.max(outputs.data, 1)[1]
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()

            seen_labels.extend(labels.cpu().numpy())
            seen_preds.extend(predicted.cpu().numpy())
            # Column 1 of the softmax output is the AD probability.
            seen_probs.extend(probs[:, 1].cpu().numpy())

    val_loss = loss_sum / len(dataloader)
    val_acc = 100 * num_correct / num_seen

    all_labels = np.array(seen_labels)
    all_preds = np.array(seen_preds)
    all_probs = np.array(seen_probs)

    is_cn = all_labels == 0
    is_ad = all_labels == 1
    pred_cn = all_preds == 0
    pred_ad = all_preds == 1

    def _class_accuracy(mask):
        # The small epsilon keeps the ratio defined when a class is absent.
        hits = np.sum(all_preds[mask] == all_labels[mask])
        return 100 * hits / (np.sum(mask) + 1e-10)

    cn_acc = _class_accuracy(is_cn)
    ad_acc = _class_accuracy(is_ad)

    # Confusion-matrix cells with AD as the positive class.
    true_positives = np.sum(pred_ad & is_ad)
    false_positives = np.sum(pred_ad & is_cn)
    true_negatives = np.sum(pred_cn & is_cn)
    false_negatives = np.sum(pred_cn & is_ad)

    precision = true_positives / (true_positives + false_positives + 1e-10)
    recall = true_positives / (true_positives + false_negatives + 1e-10)
    f1_score = 2 * precision * recall / (precision + recall + 1e-10)

    confusion_plot = wandb.plot.confusion_matrix(
        preds=all_preds, y_true=all_labels, class_names=["CN", "AD"]
    )

    # The ROC plot is optional: a failure is reported and the plot skipped.
    try:
        # Two-column probability table: column 0 = CN, column 1 = AD.
        y_probas = np.zeros((len(all_labels), 2))
        y_probas[:, 0] = 1 - all_probs
        y_probas[:, 1] = all_probs

        roc_plot = wandb.plot.roc_curve(
            all_labels,
            y_probas,
            classes_to_plot=[1],
            labels=["CN", "AD"],
        )
    except Exception as e:
        print(f"Warning: ROC curve calculation failed: {e}")
        roc_plot = None

    metrics = {
        "val_loss": val_loss,
        "val_acc": val_acc,
        "val_CN_acc": cn_acc,
        "val_AD_acc": ad_acc,
        "val_precision": precision,
        "val_recall": recall,
        "val_f1": f1_score,
        "confusion_matrix": confusion_plot,
        "epoch": epoch,
    }
    if roc_plot is not None:
        metrics["roc_curve"] = roc_plot

    wandb.log(metrics)

    return val_loss, val_acc


def _save_and_log_checkpoint(
    path, artifact_name, model, optimizer, scheduler, epoch, val_acc, val_loss
):
    """Write a training checkpoint to ``path`` and log it as a wandb model artifact."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "val_acc": val_acc,
            "val_loss": val_loss,
        },
        path,
    )
    artifact = wandb.Artifact(artifact_name, type="model")
    artifact.add_file(path)
    wandb.log_artifact(artifact)


def main():
    """Train and evaluate the AD-vs-CN classifier, tracking the run in wandb.

    Steps: start a wandb run, build the train/val datasets and loaders,
    create the model, train with a class-weighted cross-entropy loss, AdamW
    and cosine warm restarts, checkpoint the best model by validation
    accuracy and by validation loss, stop early when neither metric has
    improved for ``patience`` consecutive epochs, then reload the
    best-accuracy checkpoint for a final validation pass.
    """
    # Hyperparameters are defined once and reused for the wandb config so the
    # logged values cannot drift from the values actually used.
    data_root = DATASET  # Update this to your dataset path
    batch_size = 2  # Reduced batch size for memory constraints
    num_epochs = 10
    learning_rate = 0.0001
    freeze_layers = True
    use_augmentation = True
    patience = 5

    # Initialize wandb
    wandb.init(
        project="mri-alzheimers-classification",
        config={
            "architecture": "3D-ResNet18-FrozenLayers",
            "dataset": "MRI-AD-CN",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "optimizer": "AdamW",
            "device": str(device),
            "input_dimensions": "96x96x96",
            "freeze_layers": freeze_layers,
            "data_augmentation": use_augmentation,
        },
    )

    # Create datasets; augmentation is only ever applied to the training split.
    train_dataset = MRIDataset(
        data_root, split="train", apply_augmentation=use_augmentation
    )
    val_dataset = MRIDataset(data_root, split="val", apply_augmentation=False)

    num_ad = train_dataset.labels.count(1)
    num_cn = train_dataset.labels.count(0)

    # Log dataset stats
    wandb.config.update(
        {
            "train_samples": len(train_dataset),
            "val_samples": len(val_dataset),
            "train_AD_samples": num_ad,
            "train_CN_samples": num_cn,
            "val_AD_samples": val_dataset.labels.count(1),
            "val_CN_samples": val_dataset.labels.count(0),
        }
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    # Initialize the model with layer freezing
    model = MRIModel(num_classes=2, freeze_layers=freeze_layers)
    model = model.to(device)

    # Log parameter statistics
    trainable_params = model.count_trainable_params()
    total_params = model.count_total_params()
    frozen_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(
        f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})"
    )
    print(f"Frozen parameters: {frozen_params:,} ({frozen_params/total_params:.2%})")

    wandb.config.update(
        {
            "total_params": total_params,
            "trainable_params": trainable_params,
            "frozen_params": frozen_params,
            "frozen_percentage": frozen_params / total_params,
        }
    )

    wandb.watch(model, log="all", log_freq=10)

    # Class-weighted loss (inverse class frequency) to handle imbalance; a
    # class with no samples falls back to weight 1.0.
    total = num_ad + num_cn
    weight_cn = total / (2 * num_cn) if num_cn > 0 else 1.0
    weight_ad = total / (2 * num_ad) if num_ad > 0 else 1.0

    class_weights = torch.tensor([weight_cn, weight_ad], device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Two parameter groups: the new classifier head gets a 10x learning rate,
    # every other trainable parameter uses the base rate.
    fc_params = list(model.resnet.fc.parameters())
    other_params = [
        p
        for p in model.parameters()
        if p.requires_grad and not any(p is fc_param for fc_param in fc_params)
    ]
    param_groups = [
        {"params": other_params, "lr": learning_rate},
        {"params": fc_params, "lr": learning_rate * 10},
    ]

    optimizer = optim.AdamW(param_groups, lr=learning_rate, weight_decay=0.01)

    # Learning rate scheduler with cosine annealing and warm restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=learning_rate / 100
    )

    # Early-stopping / best-model bookkeeping.  best_val_acc starts at -inf so
    # the first epoch always writes "best_model_acc.pth", which is reloaded
    # after training even if accuracy never rises above 0.
    epochs_without_improvement = 0
    best_val_loss = float("inf")
    best_val_acc = float("-inf")
    epochs_run = 0

    for epoch in range(num_epochs):
        epochs_run = epoch + 1
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)

        # Update learning rate based on scheduler
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f"Current learning rate: {current_lr:.6f}")

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        improved = False

        # Save the best model by validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _save_and_log_checkpoint(
                "best_model_acc.pth",
                "best_model_acc",
                model,
                optimizer,
                scheduler,
                epoch,
                val_acc,
                val_loss,
            )
            print("Model saved (best accuracy)!")
            improved = True

        # Save best model by validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_and_log_checkpoint(
                "best_model_loss.pth",
                "best_model_loss",
                model,
                optimizer,
                scheduler,
                epoch,
                val_acc,
                val_loss,
            )
            print("Model saved (best loss)!")
            improved = True

        # An epoch only counts as stale when neither metric improved, and any
        # improvement resets the counter.
        if improved:
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(
                f"Early stopping at epoch {epoch+1}: "
                f"no improvement for {patience} epochs."
            )
            break

        if device.type == "mps":
            empty_cache()

    # After training, load best model for final evaluation.  map_location
    # keeps the tensors on the device the model currently lives on.
    checkpoint = torch.load("best_model_acc.pth", map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    print(
        f"Loaded best model from epoch {checkpoint['epoch']+1} with accuracy {checkpoint['val_acc']:.2f}%"
    )

    # Final evaluation on validation set (logged under epoch index num_epochs)
    final_val_loss, final_val_acc = validate(
        model, val_loader, criterion, device, num_epochs
    )
    print(f"Final validation accuracy: {final_val_acc:.2f}%")

    # Log final model summary
    wandb.run.summary["best_val_acc"] = best_val_acc
    wandb.run.summary["best_val_loss"] = best_val_loss
    wandb.run.summary["final_val_acc"] = final_val_acc
    wandb.run.summary["final_val_loss"] = final_val_loss
    wandb.run.summary["total_epochs"] = epochs_run

    # Close wandb run
    wandb.finish()


if __name__ == "__main__":
    # Set random seeds for reproducibility
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # MONAI's random transforms keep their own random state, which the seeds
    # above do not reach.  set_determinism registers the seed with MONAI so
    # that transform pipelines created afterwards (inside main) are seeded too.
    from monai.utils import set_determinism

    set_determinism(seed=seed)

    main()