# In [None]:
# =============================================
# ResNet Training Script (Fairness-Aware Eczema Diagnosis)
# Author: Domante Rabasauskaite
# Date: 09 April 2025
# =============================================

import os
import glob
import cv2
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch import nn, optim
from torchvision import transforms, datasets, models
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from torch.nn.utils import prune
from torch.quantization import quantize_dynamic
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
from tqdm import tqdm
import optuna
import matplotlib.pyplot as plt

# Tag embedded in the file name of every artifact this script writes
timestamp = "05_05_2025"

# ---------------------------------------------
# 1. HYPERPARAMETERS & DATA PATHS
# ---------------------------------------------

# --- Stage 1: DermNet pretraining ---
lr = 1e-5                       # Learning rate handed to the optimizer
weight_decay = 1e-2             # Weight-decay coefficient handed to AdamW
dropout_rate = 0.2              # Dropout probability in the classifier head
num_epochs_pretrain = 5         # Upper bound on pretraining epochs
patience_pretrain = 2           # Early-stopping patience during pretraining

# --- Stage 2: PASSION fine-tuning ---
num_epochs_finetune = 15        # Upper bound on fine-tuning epochs
patience_finetune = 3           # Early-stopping patience during fine-tuning
freeze_epochs_passion = 5       # Initial epochs with the backbone frozen
fairness_lambda_passion = 0.1   # Weight of the fairness penalty in the loss

# --- Dataset locations (edit to match the local machine) ---
passion_csv_path = r"C:\Users\PASSION_MICCAI_2024\label.csv"
passion_image_folder = r"C:\Users\PASSION_MICCAI_2024\images"

dermnet_train_root = r"C:\Users\DermNet\train"
dermnet_unseen_root = r"C:\Users\DermNet\test"  # Optional held-out folder

# --- Compute device: CUDA when present, CPU otherwise ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---------------------------------------------
# 2. PASSION CSV SPLITTING FOR CROSS-VALIDATION
# ---------------------------------------------

print("Loading PASSION CSV...")
df_full = pd.read_csv(passion_csv_path)
# Trim stray whitespace from the header so the column lookups below are reliable
df_full.columns = df_full.columns.str.strip()
print(f"PASSION CSV loaded: {len(df_full)} rows.")

# Stratify on the condition label while grouping rows by subject_id
print("Splitting PASSION dataset using StratifiedGroupKFold...")
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
# Only the first of the five folds is consumed in this run
train_idx, val_idx = next(
    sgkf.split(df_full, df_full["conditions_PASSION"], groups=df_full["subject_id"])
)
train_df, val_df = df_full.iloc[train_idx], df_full.iloc[val_idx]

# Write both partitions to disk so the exact split can be reloaded later
train_csv, val_csv = "train_split.csv", "val_split.csv"
train_df.to_csv(train_csv, index=False)
val_df.to_csv(val_csv, index=False)
print(f"PASSION split complete: {len(train_df)} training rows, {len(val_df)} validation rows.")

# ---------------------------------------------
# 3. IMAGE TRANSFORMATIONS (WITH DADA AUGMENTATION)
# ---------------------------------------------

print("Defining image transformations...")

class DADATransform:
    """
    Stand-in for Differentiable Automatic Data Augmentation (DADA).

    Wraps torchvision's RandAugment so it can be placed in a
    ``transforms.Compose`` pipeline like any other transform.

    Args:
        num_ops (int): Number of augmentation operations forwarded to RandAugment.
        magnitude (int): Magnitude setting forwarded to RandAugment.
    """
    def __init__(self, num_ops=2, magnitude=9):
        rand_augment = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude)
        self.augment = rand_augment

    def __call__(self, img):
        # Delegate straight to the wrapped RandAugment instance
        augmented = self.augment(img)
        return augmented

# Augmented pipeline applied to training images
train_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),               # Bring every image to 224x224
    DADATransform(num_ops=2, magnitude=9),            # RandAugment-based DADA proxy
    transforms.RandomHorizontalFlip(),                # Mirror at random
    transforms.RandomRotation(degrees=15),            # Rotate within +/-15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Colour/lighting jitter
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),  # Crop 80-100% of the area, resize back
    transforms.ToTensor(),                            # PIL image -> tensor
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet statistics
])

# Deterministic pipeline for validation/test: resize, tensor conversion, normalization
valid_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

print("Image transformations defined!")


# ---------------------------------------------
# 4. DATASET CLASSES (BINARY RELABELING)
# ---------------------------------------------

class DermNetDatasetBinary(Dataset):
    """
    DermNet image-folder dataset collapsed to a binary eczema task.

    Each class folder found by ``ImageFolder`` is mapped to:
        - 1 when its name (trimmed, lower-cased) is "eczema photos" or
          "atopic dermatitis photos"
        - 0 otherwise.

    Every sample also carries a constant Fitzpatrick index of 0.

    Args:
        root (str): Directory containing one sub-folder per DermNet class.
        transform (callable, optional): Transform handed to ``ImageFolder``.
    """
    def __init__(self, root, transform=None):
        print(f"Initializing DermNetDatasetBinary from {root} ...")
        self.transform = transform
        self.dataset = datasets.ImageFolder(root=root, transform=transform)

        eczema_names = {"eczema photos", "atopic dermatitis photos"}
        # ImageFolder class index -> binary target (eczema = 1, everything else = 0)
        self.binary_label_map = {
            class_index: int(class_name.strip().lower() in eczema_names)
            for class_index, class_name in enumerate(self.dataset.classes)
        }

        print(f"Found {len(self.dataset)} images across {len(self.dataset.classes)} classes.")
        print(f"Binary mapping: {self.binary_label_map}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, class_index = self.dataset[idx]
        target = torch.tensor(self.binary_label_map[class_index], dtype=torch.long)
        # No Fitzpatrick annotation here; emit a constant placeholder index
        placeholder_fitz = torch.tensor(0, dtype=torch.long)
        return image, target, placeholder_fitz


class PassionDatasetBinary(Dataset):
    """
    Custom PyTorch Dataset for PASSION images with binary relabeling and Fitzpatrick skin types.

    Labels:
        - 1 if condition is 'eczema' (case/whitespace-insensitive), else 0
        - Fitzpatrick skin type is returned for fairness-aware training

    Images are discovered on disk with the pattern ``<subject_id>_*.jpg``. Each
    subject is globbed exactly once, so a subject that appears on several CSV
    rows does not have its images duplicated in the dataset.

    Args:
        csv_file (str): Path to CSV with subject metadata. Must contain the
            columns ``subject_id``, ``conditions_PASSION`` and ``fitzpatrick``.
        image_folder (str): Path to PASSION image directory.
        transform (callable, optional): Image preprocessing transformations.
        mode (str): 'train' or 'validation', used for logging/debugging.
    """
    def __init__(self, csv_file, image_folder, transform=None, mode="train"):
        print(f"Initializing PassionDatasetBinary for {mode}...")
        self.data = pd.read_csv(csv_file)
        # Tolerate stray whitespace around header names
        self.data.columns = self.data.columns.str.strip()
        self.image_folder = image_folder
        self.transform = transform
        self.mode = mode

        # Binary label mapping: eczema = 1, other = 0.
        # str() keeps missing (NaN) conditions from crashing .strip(); they map to 0.
        self.data["binary_label"] = self.data["conditions_PASSION"].apply(
            lambda x: 1 if str(x).strip().lower() == "eczema" else 0
        )

        # One metadata record per subject, built once so __getitem__ is a dict
        # lookup instead of a full DataFrame scan. The first CSV row of a subject
        # wins, matching the previous `.values[0]` behaviour.
        first_rows = self.data.drop_duplicates(subset="subject_id", keep="first")
        self._subject_meta = {
            subject_id: (label, fitz)
            for subject_id, label, fitz in zip(
                first_rows["subject_id"],
                first_rows["binary_label"],
                first_rows["fitzpatrick"],
            )
        }

        # Associate images with subject IDs (sorted so the index order is reproducible)
        self.image_files = []
        for subject_id in self._subject_meta:
            subject_imgs = sorted(glob.glob(os.path.join(image_folder, f"{subject_id}_*.jpg")))
            self.image_files.extend((img, subject_id) for img in subject_imgs)

        print(f"PASSION Binary Dataset loaded: {len(self.image_files)} images for {mode}.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Return ``(image, label, fitzpatrick)`` for the image at position ``idx``.

        Raises:
            ValueError: If the image's subject has no metadata record.
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path, subject_id = self.image_files[idx]

        # Retrieve metadata cached at construction time
        try:
            raw_label, raw_fitz = self._subject_meta[subject_id]
        except KeyError:
            raise ValueError(f"Subject {subject_id} not found in CSV.") from None

        label = int(raw_label)
        fitz = int(raw_fitz)

        # Load with OpenCV (BGR) and convert to an RGB PIL image for the transforms
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long), torch.tensor(fitz, dtype=torch.long)

# ---------------------------------------------
# 5. DATALOADER CREATION
# ---------------------------------------------

# --- DermNet Pretraining Dataloaders ---
print("Creating DermNet dataset for pretraining (binary re-labeling)...")
dermnet_full = DermNetDatasetBinary(root=dermnet_train_root, transform=train_transforms)
# Second view over the same folder that applies the deterministic evaluation
# pipeline, so validation images are not randomly augmented.
dermnet_full_eval = DermNetDatasetBinary(root=dermnet_train_root, transform=valid_transforms)
dermnet_len = len(dermnet_full)
val_size = int(0.2 * dermnet_len)
train_size = dermnet_len - val_size

# Split DermNet into training and validation sets.
# A seeded permutation makes the split reproducible; both views enumerate the
# same folder, so an index refers to the same image in each of them.
_split_generator = torch.Generator().manual_seed(42)
_shuffled_indices = torch.randperm(dermnet_len, generator=_split_generator).tolist()
dermnet_train_ds = torch.utils.data.Subset(dermnet_full, _shuffled_indices[:train_size])
dermnet_val_ds = torch.utils.data.Subset(dermnet_full_eval, _shuffled_indices[train_size:])
pretrain_train_loader = DataLoader(dermnet_train_ds, batch_size=32, shuffle=True, num_workers=0)
pretrain_val_loader   = DataLoader(dermnet_val_ds, batch_size=32, shuffle=False, num_workers=0)
print(f"DermNet pretraining: {train_size} train images, {val_size} val images.")

# --- PASSION Fine-tuning Dataloaders ---
print("Creating PASSION binary dataset from CSV splits...")
passion_train_ds = PassionDatasetBinary(train_csv, passion_image_folder, transform=train_transforms, mode="train")
passion_val_ds   = PassionDatasetBinary(val_csv, passion_image_folder, transform=valid_transforms, mode="validation")

train_loader = DataLoader(passion_train_ds, batch_size=32, shuffle=True, num_workers=0)
valid_loader = DataLoader(passion_val_ds, batch_size=32, shuffle=False, num_workers=0)
print(f"PASSION binary: {len(passion_train_ds)} train images, {len(passion_val_ds)} val images.")

# ---------------------------------------------
# 6. MODEL DEFINITION WITH FITZPATRICK INTEGRATION (RESNET50)
# ---------------------------------------------

class ResNet50ModelWithFitzpatrick(nn.Module):
    """
    Custom ResNet50-based model with additional Fitzpatrick skin type embedding.

    Combines visual features from an ImageNet-pretrained ResNet50 backbone with a
    learned embedding of the Fitzpatrick skin type; the concatenation is fed to a
    small MLP classifier.

    Args:
        num_classes (int): Number of output classes (e.g., 2 for binary classification).
        fitzpatrick_vocab_size (int): Number of rows in the embedding table. Every
            Fitzpatrick index passed to ``forward`` must be in
            ``[0, fitzpatrick_vocab_size)``; e.g. values coded 1-6 need a size of 7.
        fitz_emb_dim (int): Dimensionality of the Fitzpatrick embedding vector.
        dropout (float): Dropout rate for regularization.
    """
    def __init__(self, num_classes, fitzpatrick_vocab_size, fitz_emb_dim=32, dropout=0.2):
        super().__init__()
        print("Initializing ResNet50 with Fitzpatrick embedding...")

        # `pretrained=True` is deprecated in newer torchvision releases in favour
        # of the `weights=` enum; IMAGENET1K_V1 is the checkpoint it used to load.
        # Older releases without the enum keep using the legacy flag.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            self.resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.resnet = models.resnet50(pretrained=True)
        in_features = self.resnet.fc.in_features  # Typically 2048

        self.resnet.fc = nn.Identity()  # Remove default classification layer

        # Add embedding layer for Fitzpatrick skin types
        self.fitz_emb = nn.Embedding(fitzpatrick_vocab_size, fitz_emb_dim)

        # Combined classifier: visual features + Fitzpatrick embedding
        self.classifier = nn.Sequential(
            nn.Linear(in_features + fitz_emb_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )
        print("ResNet50 model created!")

    def forward(self, x, fitz):
        """
        Classify a batch of images given their Fitzpatrick indices.

        Args:
            x: Image batch passed to the ResNet backbone.
            fitz: Integer tensor of Fitzpatrick indices, one per image.

        Returns:
            Class logits produced by the classifier head.
        """
        features = self.resnet(x)           # Extract features from images
        emb = self.fitz_emb(fitz)           # Embed Fitzpatrick skin type
        combined = torch.cat((features, emb), dim=1)  # Concatenate along the feature axis
        return self.classifier(combined)


def replace_classifier_resnet(model, num_classes_new, dropout, fitz_emb_dim=32):
    """
    Utility function to replace the classifier head of a ResNet50 model
    with a new classifier supporting Fitzpatrick embeddings.

    The backbone feature width is read from ``model.resnet.fc`` when that layer
    still exposes ``in_features``. When ``fc`` has been swapped for
    ``nn.Identity`` (which has no ``in_features``), the width is recovered from
    the current classifier's first layer minus the embedding width.

    Args:
        model (ResNet50ModelWithFitzpatrick): The base model; modified in place.
        num_classes_new (int): New number of output classes.
        dropout (float): Dropout rate to apply.
        fitz_emb_dim (int): Dimensionality of Fitzpatrick embedding. Should match
            ``model.fitz_emb`` or the model's forward pass will fail.
    """
    backbone_head = model.resnet.fc
    if hasattr(backbone_head, "in_features"):
        in_features = backbone_head.in_features
    else:
        # fc is an Identity: derive the backbone width from the existing head
        in_features = model.classifier[0].in_features - model.fitz_emb.embedding_dim

    new_head = nn.Sequential(
        nn.Linear(in_features + fitz_emb_dim, 256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, num_classes_new)
    )

    # Keep the fresh layers on the same device as the head they replace
    reference_param = next(model.classifier.parameters(), None)
    if reference_param is not None:
        new_head = new_head.to(reference_param.device)

    model.classifier = new_head
    print(f"Classifier replaced with {num_classes_new} output classes.")

# ---------------------------------------------
# 7. TRAINING FUNCTION WITH EARLY STOPPING & FAIRNESS PENALTY
# ---------------------------------------------

def train_model(model, train_loader, val_loader, device,
                num_epochs, freeze_epochs=0, fairness_lambda=0.0, patience=2, lr=1e-5,
                checkpoint_path=None):
    """
    Train a model with optional layer freezing, early stopping, and fairness-aware regularization.

    The epoch with the lowest validation loss is written to ``checkpoint_path``
    and, when training ends (normally or through early stopping), those weights
    are loaded back into ``model`` so the caller holds the best model rather
    than the last one.

    Reads the module-level ``weight_decay`` (AdamW weight decay) and, for the
    default checkpoint name, the module-level ``timestamp``.

    Args:
        model (nn.Module): The ResNet model to train. Must expose ``model.resnet``
            and accept ``model(images, fitz)``.
        train_loader (DataLoader): Dataloader for training; yields (images, labels, fitz).
        val_loader (DataLoader): Dataloader for validation; yields (images, labels, fitz).
        device (torch.device): Device to use (CPU or GPU).
        num_epochs (int): Total number of training epochs.
        freeze_epochs (int): Number of initial epochs with frozen backbone. The
            backbone is unfrozen at the start of epoch index ``freeze_epochs``,
            so exactly this many epochs run frozen. 0 disables freezing.
        fairness_lambda (float): Weight for fairness penalty between Fitzpatrick groups.
        patience (int): Early stopping patience.
        lr (float): Learning rate.
        checkpoint_path (str, optional): Where to save the best weights. Defaults
            to ``best_model_fold_5_{timestamp}.pth``.
    """
    print(f"Starting training: freeze_epochs={freeze_epochs}, fairness_lambda={fairness_lambda}")
    if checkpoint_path is None:
        checkpoint_path = f"best_model_fold_5_{timestamp}.pth"

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    best_state = None
    epochs_no_improve = 0

    # Optionally freeze ResNet backbone for initial training epochs
    backbone_frozen = freeze_epochs > 0
    if backbone_frozen:
        for param in model.resnet.parameters():
            param.requires_grad = False

    for epoch in range(num_epochs):
        # --- Unfreeze ResNet once `freeze_epochs` frozen epochs have run ---
        # Checked before the epoch trains (not after) so the frozen phase is not
        # one epoch too long, and skipped entirely when nothing was frozen.
        if backbone_frozen and epoch >= freeze_epochs:
            print("Unfreezing ResNet backbone for full fine-tuning.")
            for param in model.resnet.parameters():
                param.requires_grad = True
            # Fresh optimizer/scheduler covering the remaining epochs
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = CosineAnnealingLR(optimizer, T_max=(num_epochs - freeze_epochs))
            backbone_frozen = False

        model.train()
        total_loss, correct, total = 0, 0, 0

        # --- Training loop ---
        for batch_idx, (images, labels, fitz) in enumerate(train_loader):
            images, labels, fitz = images.to(device), labels.to(device), fitz.to(device)
            optimizer.zero_grad()
            outputs = model(images, fitz)
            loss = criterion(outputs, labels)

            # Fairness-aware regularization: penalize the gap between the mean loss
            # of lighter (fitz <= 3) and darker (fitz > 3) samples in the batch.
            # Only applied when both groups are present.
            light_mask = (fitz <= 3)
            dark_mask = (fitz > 3)
            if fairness_lambda > 0 and light_mask.any() and dark_mask.any():
                loss_light = criterion(outputs[light_mask], labels[light_mask])
                loss_dark = criterion(outputs[dark_mask], labels[dark_mask])
                loss = loss + fairness_lambda * torch.abs(loss_light - loss_dark)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

            if (batch_idx + 1) % 10 == 0:
                print(f"Batch {batch_idx+1}/{len(train_loader)}: Loss = {loss.item():.4f}")

        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        scheduler.step()

        # --- Validation loop ---
        model.eval()
        val_loss_sum, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels, fitz in val_loader:
                images, labels, fitz = images.to(device), labels.to(device), fitz.to(device)
                outputs = model(images, fitz)
                loss = criterion(outputs, labels)
                val_loss_sum += loss.item()
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_loss_sum / len(val_loader)
        val_acc = val_correct / val_total

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f} | Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}")

        # --- Early stopping check ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            # In-memory CPU snapshot so the best weights can be restored at the end
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            print("Model improved; saved best model.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")
            if epochs_no_improve >= patience:
                print("Early stopping triggered!")
                break

    # Leave the caller with the best-validation weights, not the last epoch's
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored best weights (val loss = {best_val_loss:.4f}).")

# ---------------------------------------------
# 8. FUNCTION TO APPLY PRUNING & QUANTIZATION
# ---------------------------------------------

def apply_pruning(model, amount=0.3, make_permanent=True):
    """
    Applies global unstructured L1 pruning to all linear layers in the model.

    The smallest-magnitude weights are selected across all ``nn.Linear`` layers
    jointly (not per layer) and set to zero. The model is modified in place.

    Args:
        model (nn.Module): The model to prune.
        amount (float): Fraction of weights to prune (e.g., 0.3 = 30% sparsity).
        make_permanent (bool): When True (default), remove the pruning
            re-parametrization afterwards so each layer again holds a plain
            ``weight`` parameter with the zeros baked in. This keeps
            ``state_dict`` keys standard and leaves no pruning forward hooks on
            the modules, which matters when layers are later swapped or
            quantized. Pass False to keep ``weight_orig``/``weight_mask``.
    """
    print("Applying global unstructured pruning to model...")
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, nn.Linear)
    ]

    # global_unstructured cannot operate on an empty parameter list
    if not parameters_to_prune:
        print("No linear layers found; nothing to prune.")
        return

    prune.global_unstructured(parameters_to_prune,
                              pruning_method=prune.L1Unstructured,
                              amount=amount)

    if make_permanent:
        for module, param_name in parameters_to_prune:
            prune.remove(module, param_name)
    print("Pruning applied.")

# ---------------------------------------------
# 9. STAGE 1: PRETRAINING ON DERMNET (BINARY CLASSIFICATION)
# ---------------------------------------------

print("=== Stage 1: Pretraining on DermNet (binary re-labeling) ===")

# Two outputs: index 0 = non-eczema, index 1 = eczema
num_dermnet_classes = 2

# The DermNet dataset emits a constant Fitzpatrick index of 0, so a
# single-row embedding table is sufficient for this stage.
model = ResNet50ModelWithFitzpatrick(
    num_dermnet_classes,
    1,
    fitz_emb_dim=32,
    dropout=dropout_rate,
)
model = model.to(device)

print(f"Pretraining using {num_dermnet_classes} classes from DermNet...")

# fairness_lambda stays at 0.0 here: with a dummy Fitzpatrick value there are
# no skin-tone groups to balance. The backbone is trained from the first epoch.
train_model(
    model=model,
    train_loader=pretrain_train_loader,
    val_loader=pretrain_val_loader,
    device=device,
    num_epochs=num_epochs_pretrain,
    freeze_epochs=0,
    fairness_lambda=0.0,
    patience=patience_pretrain,
    lr=lr,
)

print("Pretraining complete on DermNet!")
# Persist the weights so Stage 2 can initialise from them
torch.save(model.state_dict(), f"dermnet_pretrained_model_{timestamp}.pth")
print(f"DermNet pretrained model saved to dermnet_pretrained_model_{timestamp}.pth")
print("--------------------------------------------------\n")

# ---------------------------------------------
# 10. STAGE 2: FINE-TUNING ON PASSION (WITH FAIRNESS REGULARIZATION)
# ---------------------------------------------

def replace_classifier(model, num_classes_new, dropout, fitz_emb_dim=32):
    """
    Replace the model’s classification head for transfer learning.

    The backbone feature width is taken from ``model.resnet.fc.in_features``
    when available. If ``fc`` has been replaced by ``nn.Identity`` (which has no
    ``in_features`` attribute), it is derived from the current classifier's
    first layer minus the width of ``model.fitz_emb``.

    Args:
        model (ResNet50ModelWithFitzpatrick): The model to modify (in place).
        num_classes_new (int): New number of output classes.
        dropout (float): Dropout rate.
        fitz_emb_dim (int): Fitzpatrick embedding dimension. Should equal the
            width of ``model.fitz_emb`` for the forward pass to work.
    """
    backbone_head = model.resnet.fc
    if hasattr(backbone_head, "in_features"):
        in_features = backbone_head.in_features
    else:
        # fc is an Identity: recover the backbone width from the existing head
        in_features = model.classifier[0].in_features - model.fitz_emb.embedding_dim

    new_head = nn.Sequential(
        nn.Linear(in_features + fitz_emb_dim, 256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, num_classes_new)
    )

    # Place the new layers on the device of the head being replaced
    reference_param = next(model.classifier.parameters(), None)
    if reference_param is not None:
        new_head = new_head.to(reference_param.device)

    model.classifier = new_head
    print(f"Classifier replaced with {num_classes_new} output classes.")


print("=== Stage 2: Fine-tuning on PASSION (binary, with fairness) ===")
num_passion_classes = 2  # Binary: eczema vs non-eczema

# Create new model for PASSION with Fitzpatrick support. The embedding table
# has 7 rows so Fitzpatrick values can index it directly
# (presumably coded 1-6, leaving row 0 unused — confirm against the CSV).
model = ResNet50ModelWithFitzpatrick(
    num_classes=num_passion_classes,
    fitzpatrick_vocab_size=7,
    fitz_emb_dim=32,
    dropout=dropout_rate
).to(device)

# Load pretrained weights from Stage 1 (DermNet), ignoring mismatches
print("Loading pretrained weights from DermNet...")
pretrained_weights = torch.load(f"dermnet_pretrained_model_{timestamp}.pth", map_location=device)
model_dict = model.state_dict()

# Only load weights that match in shape (the 1-row DermNet embedding table
# does not match the 7-row table here and is therefore skipped)
filtered_dict = {
    k: v for k, v in pretrained_weights.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
print("Pretrained weights loaded (excluding mismatches).")

# Replace classifier for PASSION binary classification
replace_classifier(model, num_passion_classes, dropout_rate)
model = model.to(device)

# Fine-tune model with fairness-aware training on PASSION
train_model(
    model,
    train_loader,
    valid_loader,
    device,
    num_epochs=num_epochs_finetune,
    freeze_epochs=freeze_epochs_passion,
    fairness_lambda=fairness_lambda_passion,
    patience=patience_finetune,
    lr=lr
)

print("Fine-tuning complete on PASSION!")
# train_model checkpoints the lowest-validation-loss epoch to this path.
# Reload it rather than overwriting it with the last-epoch weights, so the
# file and the in-memory model used by later stages both hold the best weights.
model.load_state_dict(torch.load(f"best_model_fold_5_{timestamp}.pth", map_location=device))
print(f"Fine-tuned PASSION model saved to best_model_fold_5_{timestamp}.pth")
print("--------------------------------------------------\n")

# ---------------------------------------------
# 11. APPLY PRUNING & DYNAMIC QUANTIZATION AFTER FINE-TUNING
# ---------------------------------------------

print("Applying pruning and dynamic quantization...")

# Dynamic quantization targets CPU inference: move the model off the GPU and
# switch to eval mode before converting it.
model = model.to("cpu")
model.eval()

# Apply unstructured pruning to linear layers (30% of weights)
apply_pruning(model, amount=0.3)

# Make pruning permanent on any layer that still carries the pruning
# re-parametrization, so the quantizer sees plain `weight` parameters and no
# pruning forward hooks are carried over to the quantized modules.
for module in model.modules():
    if isinstance(module, nn.Linear) and hasattr(module, "weight_orig"):
        prune.remove(module, "weight")

# Quantize linear layers dynamically to 8-bit precision for efficiency
# This reduces model size and can improve inference speed
# NOTE: inplace=True converts `model` itself; model_quantized is the same object.
model_quantized = quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8, inplace=True
)

# Save the quantized model
torch.save(model_quantized.state_dict(), f"resnet_quantized_model_{timestamp}.pth")
print(f"Quantized pruned model saved to resnet_quantized_model_{timestamp}.pth")

# ---------------------------------------------
# 12. OPTUNA HYPERPARAMETER OPTIMIZATION (PASSION VALIDATION SET)
# ---------------------------------------------

# Path of the best PASSION checkpoint; every Optuna trial starts from these weights
model5_path = f"best_model_fold_5_{timestamp}.pth"

# --- Helper: Train for one epoch (used during Optuna search) ---
def train_one_epoch_optuna(model, loader, criterion, optimizer, device, fairness_lambda):
    """
    Run one optimisation pass over ``loader``.

    Each batch contributes the criterion loss plus, when ``fairness_lambda`` is
    positive and the batch holds both Fitzpatrick groups (<= 3 and > 3),
    ``fairness_lambda * |loss_light - loss_dark|``.

    Args:
        model: Module called as ``model(images, fitz)``; put in train mode.
        loader: Iterable of ``(images, labels, fitz)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to.
        fairness_lambda (float): Weight of the group-loss gap penalty.
    """
    model.train()
    for batch in loader:
        images, labels, fitz = (tensor.to(device) for tensor in batch)
        optimizer.zero_grad()
        logits = model(images, fitz)
        batch_loss = criterion(logits, labels)

        # Fairness term only makes sense when both skin-tone groups are present
        is_light = fitz <= 3
        is_dark = fitz > 3
        if fairness_lambda > 0 and is_light.any() and is_dark.any():
            group_gap = (
                criterion(logits[is_light], labels[is_light])
                - criterion(logits[is_dark], labels[is_dark])
            )
            batch_loss = batch_loss + fairness_lambda * torch.abs(group_gap)

        batch_loss.backward()
        optimizer.step()

# --- Helper: Evaluate accuracy on validation set ---
def evaluate_optuna(model, loader, device=None):
    """
    Compute classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(images, fitz)``; switched to eval mode.
        loader: Iterable of ``(images, labels, fitz)`` batches.
        device (torch.device, optional): Device the batches are moved to. When
            omitted, the device of the model's first parameter is used (CPU if
            the model has no parameters), so the function no longer depends on
            a module-level ``device`` variable.

    Returns:
        Accuracy as returned by ``sklearn.metrics.accuracy_score``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels, fitz in loader:
            images, labels, fitz = images.to(device), labels.to(device), fitz.to(device)
            outputs = model(images, fitz)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return accuracy_score(all_labels, all_preds)

trial_log = []  # For tracking trial results

# --- Optuna Objective Function ---
def objective(trial):
    """
    Optuna objective: briefly fine-tune from the best PASSION checkpoint with
    sampled hyperparameters and return the validation accuracy.

    The model is trained on ``train_loader`` and scored on ``valid_loader``.
    Training and scoring on the same validation data would reward memorising
    the validation set and make the trial scores meaningless for model selection.

    Args:
        trial (optuna.trial.Trial): Trial used to sample hyperparameters.

    Returns:
        Validation accuracy of the fine-tuned model (maximised by the study).
    """
    # Sample hyperparameters from defined search spaces
    lr_trial = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    weight_decay_trial = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    dropout_trial = trial.suggest_float("dropout", 0.1, 0.5)
    fairness_lambda_trial = trial.suggest_float("fairness_lambda", 0.0, 1.0)

    # Initialize and load model weights
    model_opt = ResNet50ModelWithFitzpatrick(
        num_classes=num_passion_classes,
        dropout=dropout_trial,
        fitzpatrick_vocab_size=7
    ).to(device)
    model_opt.load_state_dict(torch.load(model5_path, map_location=device))

    optimizer = optim.AdamW(model_opt.parameters(), lr=lr_trial, weight_decay=weight_decay_trial)
    criterion = nn.CrossEntropyLoss()
    scheduler = CosineAnnealingLR(optimizer, T_max=5)

    # Train for 5 short epochs on the training split ...
    for _ in range(5):
        train_one_epoch_optuna(model_opt, train_loader, criterion, optimizer, device, fairness_lambda_trial)
        scheduler.step()

    # ... and score on the held-out validation split
    acc = evaluate_optuna(model_opt, valid_loader)

    # Log trial details
    trial_log.append({
        "trial": trial.number,
        "lr": lr_trial,
        "weight_decay": weight_decay_trial,
        "dropout": dropout_trial,
        "fairness_lambda": fairness_lambda_trial,
        "accuracy": acc
    })

    # Free up GPU memory
    del model_opt
    torch.cuda.empty_cache()
    return acc

# Run Optuna study
study = optuna.create_study(direction="maximize")
for _ in tqdm(range(20), desc="Running Optuna Trials"):
    study.optimize(objective, n_trials=1)  # Sequential trials to avoid memory overload

# Save Optuna results to CSV
df_log = pd.DataFrame(trial_log)
df_log.to_csv("optuna_trials.csv", index=False)

# Plot accuracy per trial
plt.figure(figsize=(10, 5))
plt.plot(df_log["trial"], df_log["accuracy"], marker="o")
plt.title("Optuna Trials: Accuracy vs Trial Number")
plt.xlabel("Trial")
plt.ylabel("Accuracy")
plt.grid(True)
plt.tight_layout()
plt.savefig("optuna_accuracy_vs_trial.png", dpi=300)
plt.show()

# Report best hyperparameters
print("Best parameters:", study.best_params)
best_params = study.best_params

# --- Retrain best model with best Optuna hyperparameters ---
model_best = ResNet50ModelWithFitzpatrick(
    num_classes=num_passion_classes,
    dropout=best_params["dropout"],
    fitzpatrick_vocab_size=7
).to(device)

# Start from the best fine-tuned PASSION checkpoint
model_best.load_state_dict(torch.load(model5_path, map_location=device))
optimizer = optim.AdamW(model_best.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
criterion = nn.CrossEntropyLoss()
scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Final fine-tuning with optimal settings.
# Uses the training split: fitting the final model on the validation loader
# would leave no held-out data on which its quality can be trusted.
for _ in range(10):
    train_one_epoch_optuna(model_best, train_loader, criterion, optimizer, device, best_params["fairness_lambda"])
    scheduler.step()

torch.save(model_best.state_dict(), f"resnet_finetuned_optuna_{timestamp}.pth")
print(f"Model saved to resnet_finetuned_optuna_{timestamp}.pth")


# In [None]:
# ===================================================
# Vision Transformer (ViT) for Fairness-Aware Eczema Diagnosis
# Author: Domante Rabasauskaite
# Date: 09 April 2025
# ===================================================

import os
import glob
import cv2
import torch
import timm
import numpy as np
import pandas as pd
from PIL import Image
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from torch.nn.utils import prune
from torch.quantization import quantize_dynamic
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
from tqdm import tqdm
import optuna
import matplotlib.pyplot as plt

# Tag embedded in the file name of every artifact this script writes
timestamp = "05_05_2025"

# ---------------------------------------------------
# 1. HYPERPARAMETERS & PATHS
# ---------------------------------------------------

# --- Stage 1: DermNet pretraining ---
lr = 1e-5
weight_decay = 1e-2
dropout_rate = 0.2
num_epochs_pretrain = 5           # Upper bound on pretraining epochs
patience_pretrain = 2             # Early-stopping patience during pretraining

# --- Stage 2: PASSION fine-tuning ---
num_epochs_finetune = 15          # Upper bound on fine-tuning epochs
patience_finetune = 3             # Early-stopping patience during fine-tuning
freeze_epochs_passion = 5         # Initial epochs with the base encoder frozen
fairness_lambda_passion = 0.1     # Weight of the fairness loss term

# --- Dataset locations (edit to match the local machine) ---
passion_csv_path = r"C:\Users\PASSION_MICCAI_2024\label.csv"
passion_image_folder = r"C:\Users\PASSION_MICCAI_2024\images"

dermnet_train_root = r"C:\Users\DermNet\train"
dermnet_unseen_root = r"C:\Users\DermNet\test"  # Optional held-out set

# --- Compute device: CUDA when present, CPU otherwise ---
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---------------------------------------------------
# 2. CSV SPLITTING FOR PASSION (Stratified Group K-Fold)
# ---------------------------------------------------

print("Loading PASSION CSV...")
df_full = pd.read_csv(passion_csv_path)
# Trim stray whitespace from the header so the column lookups below are reliable
df_full.columns = df_full.columns.str.strip()
print(f"PASSION CSV loaded: {len(df_full)} rows.")

# Stratify on the condition label while grouping rows by subject_id
print("Splitting PASSION dataset using StratifiedGroupKFold...")
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

# Take the first of the five folds as the train/validation partition
train_idx, val_idx = next(
    sgkf.split(df_full, df_full["conditions_PASSION"], groups=df_full["subject_id"])
)
train_df, val_df = df_full.iloc[train_idx], df_full.iloc[val_idx]

# Write both partitions to disk so the exact split can be reloaded later
train_csv, val_csv = "train_split.csv", "val_split.csv"
train_df.to_csv(train_csv, index=False)
val_df.to_csv(val_csv, index=False)
print(f"PASSION split complete: {len(train_df)} train rows, {len(val_df)} val rows.")

# ---------------------------------------------------
# 3. IMAGE TRANSFORMATIONS (DADA Proxy Augmentation)
# ---------------------------------------------------

print("Defining image transformations...")

class DADATransform:
    """
    RandAugment wrapper used as a proxy for DADA-style augmentation.

    Args:
        num_ops (int): Number of random operations forwarded to RandAugment.
        magnitude (int): Augmentation strength forwarded to RandAugment.
    """
    def __init__(self, num_ops=2, magnitude=9):
        rand_augment = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude)
        self.augment = rand_augment

    def __call__(self, img):
        # Delegate straight to the wrapped RandAugment instance
        augmented = self.augment(img)
        return augmented

# Augmented pipeline applied to training images
train_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),               # Bring every image to 224x224
    DADATransform(num_ops=2, magnitude=9),            # RandAugment-based DADA proxy
    transforms.RandomHorizontalFlip(),                # Mirror at random
    transforms.RandomRotation(degrees=15),            # Rotate within +/-15 degrees
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1,
    ),
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),  # Crop 80-100% of the area, resize back
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Deterministic pipeline for validation/test: resize, tensor conversion, normalization
valid_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

print("Image transformations defined!")

# ---------------------------------------------------
# 4. DATASET CLASSES (BINARY RE-LABELING)
# ---------------------------------------------------

class DermNetDatasetBinary(Dataset):
    """
    Binary (eczema vs. rest) view over a DermNet folder tree.

    The folder-per-class labels produced by ``ImageFolder`` are collapsed to:
        - 1 when the folder name is "eczema photos" or "atopic dermatitis photos"
          (compared after stripping whitespace and lower-casing)
        - 0 for every other folder

    DermNet has no Fitzpatrick annotation, so each sample carries a constant 0
    in that slot.

    Args:
        root (str): Path to DermNet image folders.
        transform (callable): Transformations to apply to each image.
    """
    def __init__(self, root, transform=None):
        print(f"Initializing DermNetDatasetBinary from {root} ...")
        self.transform = transform
        self.dataset = datasets.ImageFolder(root=root, transform=transform)

        eczema_names = {"eczema photos", "atopic dermatitis photos"}
        # ImageFolder class index -> binary label
        self.binary_label_map = {
            class_idx: int(class_name.strip().lower() in eczema_names)
            for class_idx, class_name in enumerate(self.dataset.classes)
        }

        print(f"Found {len(self.dataset)} images across {len(self.dataset.classes)} classes.")
        print(f"Binary mapping: {self.binary_label_map}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, folder_label = self.dataset[idx]
        target = torch.tensor(self.binary_label_map[folder_label], dtype=torch.long)
        # Placeholder skin-type value: DermNet provides none
        placeholder_fitz = torch.tensor(0, dtype=torch.long)
        return image, target, placeholder_fitz


class PassionDatasetBinary(Dataset):
    """
    Dataset for binary eczema classification from PASSION.
    Binary labels and Fitzpatrick scores are derived from the CSV.

    Each sample is one image file named ``<subject_id>_*.jpg`` inside
    ``image_folder``; its label and Fitzpatrick value come from the first CSV
    row with that ``subject_id``.

    Args:
        csv_file (str): CSV with ``subject_id``, ``conditions_PASSION`` and
            ``fitzpatrick`` columns.
        image_folder (str): Path to PASSION image files.
        transform (callable): Transformations to apply to each image.
        mode (str): 'train' or 'validation' (for logging/debugging).
    """
    def __init__(self, csv_file, image_folder, transform=None, mode="train"):
        print(f"Initializing PassionDatasetBinary for {mode}...")
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform
        self.mode = mode

        # Create binary labels: 1 for "eczema" (case/whitespace-insensitive), 0 otherwise
        self.data["binary_label"] = self.data["conditions_PASSION"].apply(
            lambda x: 1 if x.strip().lower() == "eczema" else 0
        )

        # Build the subject -> (label, fitzpatrick) lookup once, so __getitem__
        # does not have to scan the whole DataFrame for every sample.
        # keep="first" mirrors the previous `.values[0]` behavior.
        first_rows = self.data.drop_duplicates(subset="subject_id", keep="first")
        self._subject_meta = dict(
            zip(
                first_rows["subject_id"],
                zip(first_rows["binary_label"], first_rows["fitzpatrick"]),
            )
        )

        # Match images to subject IDs.
        # NOTE(review): a subject that appears in several CSV rows has its
        # images listed once per row — confirm whether that is intended.
        self.image_files = []
        for subject_id in self.data["subject_id"]:
            pattern = os.path.join(image_folder, f"{subject_id}_*.jpg")
            # glob order is OS-dependent; sort so sample order is reproducible
            subject_imgs = sorted(glob.glob(pattern))
            self.image_files.extend([(img, subject_id) for img in subject_imgs])

        print(f"PASSION Binary Dataset loaded: {len(self.image_files)} images for {mode}.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Return ``(image, label, fitzpatrick)`` for sample ``idx``.

        Raises:
            ValueError: If the sample's subject is missing from the CSV lookup.
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path, subject_id = self.image_files[idx]

        try:
            label_raw, fitz_raw = self._subject_meta[subject_id]
        except KeyError:
            raise ValueError(f"Subject {subject_id} not found in CSV.") from None

        label = int(label_raw)
        fitz = int(fitz_raw)

        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        # OpenCV loads BGR; convert to RGB before handing to PIL / torchvision
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long), torch.tensor(fitz, dtype=torch.long)

# ---------------------------------------------------
# 5. CREATE DATALOADERS
# ---------------------------------------------------

# --- DermNet Pretraining ---
print("Creating DermNet dataset for pretraining (binary re-labeling)...")
dermnet_full = DermNetDatasetBinary(root=dermnet_train_root, transform=train_transforms)
# Second view of the same folder with the deterministic validation transforms.
# Without it, the validation subset would inherit the random training
# augmentations and make validation loss / early stopping noisy.
dermnet_full_eval = DermNetDatasetBinary(root=dermnet_train_root, transform=valid_transforms)
dermnet_len = len(dermnet_full)
val_size = int(0.2 * dermnet_len)
train_size = dermnet_len - val_size

# Random split of DermNet: the indices are drawn once, then the validation
# indices are re-pointed at the non-augmented view (both views are built from
# the same root, so an index refers to the same file in each).
dermnet_train_ds, dermnet_val_split = random_split(dermnet_full, [train_size, val_size])
dermnet_val_ds = torch.utils.data.Subset(dermnet_full_eval, dermnet_val_split.indices)

pretrain_train_loader = DataLoader(dermnet_train_ds, batch_size=32, shuffle=True, num_workers=0)
pretrain_val_loader   = DataLoader(dermnet_val_ds, batch_size=32, shuffle=False, num_workers=0)

print(f"DermNet pretraining: {train_size} train images, {val_size} val images.")

# --- PASSION Fine-tuning ---
print("Creating PASSION binary dataset from CSV splits...")
passion_train_ds = PassionDatasetBinary(train_csv, passion_image_folder, transform=train_transforms, mode="train")
passion_val_ds   = PassionDatasetBinary(val_csv, passion_image_folder, transform=valid_transforms, mode="validation")

train_loader = DataLoader(passion_train_ds, batch_size=32, shuffle=True, num_workers=0)
valid_loader = DataLoader(passion_val_ds, batch_size=32, shuffle=False, num_workers=0)

print(f"PASSION binary: {len(passion_train_ds)} train images, {len(passion_val_ds)} val images.")

# ---------------------------------------------------
# 6. MODEL DEFINITION WITH FITZPATRICK INTEGRATION (ViT)
# ---------------------------------------------------

class ViTModelWithFitzpatrick(nn.Module):
    """
    ViT-Base/16 image encoder fused with a learned Fitzpatrick skin-type embedding.

    The backbone's own classification head is replaced by an identity, its
    features are concatenated with the embedding of the Fitzpatrick index, and
    a small MLP maps the fused vector to class logits.

    Args:
        num_classes (int): Number of output classes (e.g., 2 for binary classification).
        fitzpatrick_vocab_size (int): Number of rows in the Fitzpatrick embedding table (e.g., 7).
        fitz_emb_dim (int): Width of the Fitzpatrick embedding.
        dropout (float): Dropout probability inside the classification MLP.
    """
    def __init__(self, num_classes, fitzpatrick_vocab_size, fitz_emb_dim=32, dropout=0.2):
        super().__init__()
        print("Initializing ViT with Fitzpatrick embedding...")

        # Pretrained ViT-Base/16 from timm with its head stripped off
        backbone = timm.create_model("vit_base_patch16_224", pretrained=True)
        feature_dim = backbone.head.in_features  # expected to be 768 for ViT-Base
        backbone.head = nn.Identity()
        self.vit = backbone

        self.fitz_emb = nn.Embedding(fitzpatrick_vocab_size, fitz_emb_dim)

        # MLP over the concatenated [image features | skin-type embedding]
        fused_dim = feature_dim + fitz_emb_dim
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )
        print("ViT model created!")

    def forward(self, x, fitz):
        image_feats = self.vit(x)            # backbone features, one row per image
        skin_feats = self.fitz_emb(fitz)     # one embedding row per Fitzpatrick index
        fused = torch.cat([image_feats, skin_feats], dim=1)
        logits = self.classifier(fused)
        return logits


def replace_classifier(model, num_classes_new, dropout, fitz_emb_dim=32):
    """
    Utility function to replace the ViT classifier block with new output dimensions.

    The new head is freshly initialised (no weights are carried over) and is
    placed on the same device as the model's existing parameters.

    Args:
        model (ViTModelWithFitzpatrick): ViT model with Fitzpatrick embedding.
        num_classes_new (int): New number of output classes.
        dropout (float): Dropout probability.
        fitz_emb_dim (int): Dimensionality of Fitzpatrick embedding.
    """
    # Prefer the feature width reported by the backbone (timm models expose
    # `num_features`); fall back to 768, the ViT-Base/16 width.
    backbone = getattr(model, "vit", None)
    in_features = getattr(backbone, "num_features", 768)

    new_head = nn.Sequential(
        nn.Linear(in_features + fitz_emb_dim, 256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, num_classes_new)
    )

    # A module built here lives on CPU; move it next to the rest of the model so
    # the caller does not have to remember a follow-up `.to(device)`.
    existing_params = model.parameters() if hasattr(model, "parameters") else iter(())
    anchor = next(existing_params, None)
    if anchor is not None:
        new_head.to(anchor.device)

    model.classifier = new_head
    print(f"Classifier replaced with {num_classes_new} output classes.")

# ---------------------------------------------------
# 7. TRAINING FUNCTION WITH EARLY STOPPING & FAIRNESS PENALTY
# ---------------------------------------------------

def train_model(model, train_loader, val_loader, device, 
                num_epochs, freeze_epochs=0, fairness_lambda=0.0, patience=2, lr=1e-5,
                checkpoint_path=None):
    """
    Trains the ViT model with optional backbone freezing, fairness loss regularization, and early stopping.

    The backbone (`model.vit`) is frozen for exactly `freeze_epochs` epochs. When it
    is unfrozen, the optimizer and cosine schedule are rebuilt for the remaining
    epochs. With `freeze_epochs=0` nothing is frozen and the optimizer is never rebuilt.

    Whenever the validation loss improves, the weights are written to
    `checkpoint_path`; before returning, the best saved weights are loaded back
    into `model` so the caller holds the best epoch rather than the last one.

    Weight decay is read from the module-level `weight_decay`.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.
        device (torch.device): Target device (CPU or GPU).
        num_epochs (int): Total number of epochs.
        freeze_epochs (int): Number of epochs to keep the backbone frozen.
        fairness_lambda (float): Fairness loss weight (penalizes subgroup disparities).
        patience (int): Early stopping patience.
        lr (float): Learning rate.
        checkpoint_path (str | None): File for the best weights. Defaults to
            f"best_model_fold_5_{timestamp}.pth" using the module-level `timestamp`.
    """
    if checkpoint_path is None:
        checkpoint_path = f"best_model_fold_5_{timestamp}.pth"

    print(f"Starting training: freeze_epochs={freeze_epochs}, fairness_lambda={fairness_lambda}")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_saved = False  # True once this call has written checkpoint_path

    # Optional: Freeze ViT backbone initially
    backbone_frozen = freeze_epochs > 0
    if backbone_frozen:
        for param in model.vit.parameters():
            param.requires_grad = False

    for epoch in range(num_epochs):
        # Unfreeze at the START of epoch `freeze_epochs`, so epochs
        # 0..freeze_epochs-1 are the frozen ones. Skipped entirely when nothing
        # was frozen, so the optimizer state is not thrown away for no reason.
        if backbone_frozen and epoch == freeze_epochs:
            print("Unfreezing ViT backbone for full fine-tuning.")
            for param in model.vit.parameters():
                param.requires_grad = True
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = CosineAnnealingLR(optimizer, T_max=(num_epochs - freeze_epochs))
            backbone_frozen = False

        model.train()
        total_loss, correct, total = 0, 0, 0

        # --- Training loop ---
        for batch_idx, (images, labels, fitz) in enumerate(train_loader):
            images, labels, fitz = images.to(device), labels.to(device), fitz.to(device)
            optimizer.zero_grad()

            outputs = model(images, fitz)
            loss = criterion(outputs, labels)

            # Fairness penalty: |loss(fitz <= 3) - loss(fitz > 3)|, applied only
            # when the batch contains samples from both groups
            if fairness_lambda > 0:
                light_mask = (fitz <= 3)
                dark_mask = (fitz > 3)
                if light_mask.sum() > 0 and dark_mask.sum() > 0:
                    loss_light = criterion(outputs[light_mask], labels[light_mask])
                    loss_dark = criterion(outputs[dark_mask], labels[dark_mask])
                    loss = loss + fairness_lambda * torch.abs(loss_light - loss_dark)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

            if (batch_idx + 1) % 10 == 0:
                print(f"Batch {batch_idx+1}/{len(train_loader)}: Loss = {loss.item():.4f}")

        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        scheduler.step()

        # --- Validation loop ---
        model.eval()
        val_loss_sum, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels, fitz in val_loader:
                images, labels, fitz = images.to(device), labels.to(device), fitz.to(device)
                outputs = model(images, fitz)
                loss = criterion(outputs, labels)
                val_loss_sum += loss.item()
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_loss_sum / len(val_loader)
        val_acc = val_correct / val_total

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f} | Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}")

        # --- Early stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            best_saved = True
            print("Model improved; saved best model.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")
            if epochs_no_improve >= patience:
                print("Early stopping triggered!")
                break

    # Hand back the best epoch, not whatever the last epoch happened to be
    if best_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Restored best weights (val loss = {best_val_loss:.4f}).")










# ---------------------------------------------------
# 8. FUNCTION TO APPLY GLOBAL PRUNING (DYNAMIC QUANTIZATION IS APPLIED IN SECTION 11)
# ---------------------------------------------------

def apply_pruning(model, amount=0.3):
    """
    Applies global unstructured L1 pruning to all linear layers in the model.

    The pruning masks are made permanent afterwards, so each pruned layer ends
    up with a plain `weight` parameter again (zeros where weights were pruned).
    A model without any `nn.Linear` layer is left untouched.

    Args:
        model (nn.Module): Model to prune (modified in place).
        amount (float): Proportion of weights to prune (e.g., 0.3 = 30%).
    """
    print("Applying global unstructured pruning to model...")

    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, nn.Linear)
    ]

    # global_unstructured cannot operate on an empty parameter list
    if not parameters_to_prune:
        print("No nn.Linear layers found; nothing to prune.")
        return

    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount
    )

    print("Pruning applied.")

    # Remove pruning hooks to convert pruned layers back to standard form
    for module, name in parameters_to_prune:
        try:
            prune.remove(module, name)
        except ValueError as e:
            # prune.remove raises ValueError when the parameter was not pruned
            print(f"Could not remove pruning from module {module}: {e}")

# ---------------------------------------------------
# 9. STAGE 1: PRETRAINING ON DERMNET (BINARY)
# ---------------------------------------------------

print("=== Stage 1: Pretraining on DermNet (binary re-labeling) ===")
num_dermnet_classes = 2  # eczema vs. everything else

# The DermNet dataset only ever yields Fitzpatrick index 0, so a one-row
# embedding table is enough for this stage
model = ViTModelWithFitzpatrick(
    num_classes=num_dermnet_classes,
    fitzpatrick_vocab_size=1,
    fitz_emb_dim=32,
    dropout=dropout_rate,
)
model = model.to(device)

print(f"Pretraining using {num_dermnet_classes} classes from DermNet...")

# No frozen phase and no fairness term during pretraining
pretrain_settings = dict(
    num_epochs=num_epochs_pretrain,
    freeze_epochs=0,
    fairness_lambda=0.0,
    patience=patience_pretrain,
    lr=lr,
)
train_model(model, pretrain_train_loader, pretrain_val_loader, device, **pretrain_settings)

torch.save(model.state_dict(), f"dermnet_pretrained_model_{timestamp}.pth")
print(f"DermNet pretrained model saved to dermnet_pretrained_model_{timestamp}.pth")
print("--------------------------------------------------\n")

# ---------------------------------------------------
# 10. STAGE 2: FINE-TUNING ON PASSION (WITH FAIRNESS)
# ---------------------------------------------------

def replace_classifier(model, num_classes_new, dropout, fitz_emb_dim=32):
    """
    Replace the ViT classifier block for fine-tuning with new output classes.

    The new head is freshly initialised (no weights are carried over) and is
    placed on the same device as the model's existing parameters.

    NOTE(review): this re-definition shadows an earlier `replace_classifier`
    with the same signature in this file; consider keeping only one of them.

    Args:
        model (nn.Module): ViT model.
        num_classes_new (int): Output class count.
        dropout (float): Dropout rate.
        fitz_emb_dim (int): Fitzpatrick embedding dimension.
    """
    # Prefer the feature width reported by the backbone (timm models expose
    # `num_features`); fall back to 768, the ViT Base Patch16 output size.
    backbone = getattr(model, "vit", None)
    in_features = getattr(backbone, "num_features", 768)

    new_head = nn.Sequential(
        nn.Linear(in_features + fitz_emb_dim, 256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, num_classes_new)
    )

    # A module built here lives on CPU; move it next to the rest of the model so
    # the caller does not have to remember a follow-up `.to(device)`.
    existing_params = model.parameters() if hasattr(model, "parameters") else iter(())
    anchor = next(existing_params, None)
    if anchor is not None:
        new_head.to(anchor.device)

    model.classifier = new_head
    print(f"Classifier replaced with {num_classes_new} output classes.")

print("=== Stage 2: Fine-tuning on PASSION (binary, with fairness) ===")
num_passion_classes = 2  # Still binary (eczema vs non-eczema)

# Create new model with Fitzpatrick integration (7 skin types)
model = ViTModelWithFitzpatrick(
    num_classes=num_passion_classes,
    fitzpatrick_vocab_size=7,
    fitz_emb_dim=32,
    dropout=dropout_rate
).to(device)

# Load pretrained weights from DermNet
print("Loading pretrained weights from DermNet...")
pretrained_weights = torch.load(f"dermnet_pretrained_model_{timestamp}.pth", map_location=device)
model_dict = model.state_dict()

# Match only compatible layers (same key and same tensor shape); e.g. the
# one-row Fitzpatrick embedding from pretraining does not fit the 7-row table
filtered_dict = {
    k: v for k, v in pretrained_weights.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
print("Pretrained weights loaded (excluding mismatches).")

# Replace classifier for PASSION binary classification
replace_classifier(model, num_passion_classes, dropout_rate)
model = model.to(device)

# train_model writes its best-epoch weights to this path. Remove any file left
# over from an earlier call (e.g. the pretraining stage) so that whatever exists
# after training is guaranteed to come from this fine-tuning run.
best_ckpt_path = f"best_model_fold_5_{timestamp}.pth"
if os.path.exists(best_ckpt_path):
    os.remove(best_ckpt_path)

# Train with fairness-aware penalty on PASSION
train_model(
    model,
    train_loader,
    valid_loader,
    device,
    num_epochs=num_epochs_finetune,
    freeze_epochs=freeze_epochs_passion,
    fairness_lambda=fairness_lambda_passion,
    patience=patience_finetune,
    lr=lr
)

# Do NOT overwrite the best-epoch checkpoint with last-epoch weights: reload the
# best weights into `model` instead. Only fall back to saving the current state
# when training never produced a checkpoint.
if os.path.exists(best_ckpt_path):
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
else:
    torch.save(model.state_dict(), best_ckpt_path)
print(f"Fine-tuned PASSION model saved to best_model_fold_5_{timestamp}.pth")
print("--------------------------------------------------\n")

# ---------------------------------------------------
# 11. APPLY PRUNING & DYNAMIC QUANTIZATION AFTER FINE-TUNING
# ---------------------------------------------------

print("Applying pruning and dynamic quantization...")

# Apply global unstructured L1 pruning to all linear layers
apply_pruning(model, amount=0.3)

# Dynamic quantization targets CPU inference, so move the model off the GPU
# and switch to eval mode before converting it.
model = model.to("cpu")
model.eval()

# Quantize linear layers to int8 precision (dynamic quantization).
# inplace=True: `model` itself is converted and `model_quantized` refers to it.
model_quantized = quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8, inplace=True
)

# Save compressed model
torch.save(model_quantized.state_dict(), f"vit_quantized_model_{timestamp}.pth")
print(f"Quantized pruned model saved to vit_quantized_model_{timestamp}.pth")
print("--------------------------------------------------\n")

# ---------------------------------------------------
# 12. OPTUNA HYPERPARAMETER OPTIMIZATION (PASSION VALIDATION SET)
# ---------------------------------------------------

model5_path = f"best_model_fold_5_{timestamp}.pth"  # Checkpoint used as starting point

# --- Helper: One Epoch of Fairness-Aware Training ---
def train_one_epoch_optuna(model, loader, criterion, optimizer, device, fairness_lambda):
    """
    Run one training pass over `loader`, updating `model` in place.

    Each batch is an (images, labels, fitz) triple. The loss is `criterion`
    on the whole batch plus, when `fairness_lambda` > 0 and the batch holds
    both groups, `fairness_lambda * |loss(fitz <= 3) - loss(fitz > 3)|`.
    """
    model.train()
    penalty_enabled = fairness_lambda > 0

    for batch in loader:
        images, labels, fitz = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        logits = model(images, fitz)
        batch_loss = criterion(logits, labels)

        # Gap between the two Fitzpatrick groups' losses, only when both are present
        is_light = fitz <= 3
        is_dark = fitz > 3
        if penalty_enabled and is_light.sum() > 0 and is_dark.sum() > 0:
            light_loss = criterion(logits[is_light], labels[is_light])
            dark_loss = criterion(logits[is_dark], labels[is_dark])
            gap = torch.abs(light_loss - dark_loss)
            batch_loss += fairness_lambda * gap

        batch_loss.backward()
        optimizer.step()

# --- Helper: Evaluation Metric (Accuracy) ---
def evaluate_optuna(model, loader, device=None):
    """
    Return the accuracy of `model` over every batch in `loader`.

    Args:
        model (nn.Module): Model called as `model(images, fitz)`.
        loader (iterable): Yields (images, labels, fitz) batches.
        device (torch.device | None): Device the batches are moved to. When
            omitted, the device of the model's first parameter is used, so the
            helper no longer depends on a module-level `device` variable.

    Returns:
        The value of `accuracy_score(labels, predictions)`.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels, fitz in loader:
            images, labels, fitz = images.to(device), labels.to(device), fitz.to(device)
            outputs = model(images, fitz)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return accuracy_score(all_labels, all_preds)

# --- Optuna Objective Function ---
trial_log = []

def objective(trial):
    """
    Optuna objective: fine-tune briefly with sampled hyperparameters and
    return validation accuracy.

    Each trial starts from the checkpoint at `model5_path`, trains for 5 epochs
    on `train_loader`, and is scored on `valid_loader`. Training and scoring on
    the same loader would reward memorisation rather than generalisation, so
    the two are kept separate. The sampled values and the resulting accuracy
    are appended to the module-level `trial_log`.

    Args:
        trial (optuna.Trial): Trial used to sample lr, weight decay, dropout
            and the fairness weight.

    Returns:
        Accuracy on `valid_loader` (the study maximises it).
    """
    lr_trial = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    weight_decay_trial = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    dropout_trial = trial.suggest_float("dropout", 0.1, 0.5)
    fairness_lambda_trial = trial.suggest_float("fairness_lambda", 0.0, 1.0)

    model_opt = ViTModelWithFitzpatrick(
        num_classes=num_passion_classes,
        dropout=dropout_trial,
        fitzpatrick_vocab_size=7
    ).to(device)
    model_opt.load_state_dict(torch.load(model5_path, map_location=device))

    optimizer = optim.AdamW(model_opt.parameters(), lr=lr_trial, weight_decay=weight_decay_trial)
    criterion = nn.CrossEntropyLoss()
    scheduler = CosineAnnealingLR(optimizer, T_max=5)

    for _ in range(5):
        # Fit on the training split only; the validation split is reserved for scoring
        train_one_epoch_optuna(model_opt, train_loader, criterion, optimizer, device, fairness_lambda_trial)
        scheduler.step()

    acc = evaluate_optuna(model_opt, valid_loader)

    trial_log.append({
        "trial": trial.number,
        "lr": lr_trial,
        "weight_decay": weight_decay_trial,
        "dropout": dropout_trial,
        "fairness_lambda": fairness_lambda_trial,
        "accuracy": acc
    })

    # Drop the per-trial model before the next trial allocates its own
    del model_opt
    torch.cuda.empty_cache()
    return acc

# --- Run Optuna Tuning Loop ---
study = optuna.create_study(direction="maximize")
# One trial per optimize() call, wrapped in tqdm for a 20-step progress bar
for _ in tqdm(range(20), desc="Running Optuna Trials"):
    study.optimize(objective, n_trials=1)

# Persist the per-trial records collected by the objective
df_log = pd.DataFrame(trial_log)
df_log.to_csv("optuna_trials.csv", index=False)

# Accuracy per trial, saved to disk and shown
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(df_log["trial"], df_log["accuracy"], marker="o")
ax.set_title("Optuna Trials: Accuracy vs Trial Number")
ax.set_xlabel("Trial")
ax.set_ylabel("Accuracy")
ax.grid(True)
fig.tight_layout()
fig.savefig("optuna_accuracy_vs_trial.png", dpi=300)
plt.show()

# --- Best Trial Results ---
best_params = study.best_params
print("Best parameters:", best_params)

# ---------------------------------------------------
# Retrain Best Model Using Optimal Hyperparameters
# ---------------------------------------------------

# Rebuild the network with the winning dropout and start from the same
# checkpoint the trials started from
model_best = ViTModelWithFitzpatrick(
    num_classes=num_passion_classes,
    fitzpatrick_vocab_size=7,
    dropout=best_params["dropout"],
)
model_best = model_best.to(device)
model_best.load_state_dict(torch.load(model5_path, map_location=device))

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model_best.parameters(),
    weight_decay=best_params["weight_decay"],
    lr=best_params["lr"],
)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Ten further epochs on the validation loader with the best trial's settings.
# NOTE(review): this trains on `valid_loader`, so that split can no longer
# serve as held-out data for the saved model — confirm this is intended.
best_fairness_lambda = best_params["fairness_lambda"]
for _ in range(10):
    train_one_epoch_optuna(model_best, valid_loader, criterion, optimizer, device, best_fairness_lambda)
    scheduler.step()

torch.save(model_best.state_dict(), f"vit_finetuned_optuna_{timestamp}.pth")
print(f"Model saved to vit_finetuned_optuna_{timestamp}.pth")
