In [None]:

"""
Compares fine-tuned ViT and ResNet models for binary eczema classification on the PASSION MICCAI 2024 dataset.
Evaluates classification metrics and fairness across Fitzpatrick skin types.

Author: Domante Rabasauskaite
Date: 13_04_2025
"""

# ================================
# 1. IMPORTS AND CONFIGURATION
# ================================

import os
import glob
import cv2
import torch
import timm
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import (
    confusion_matrix, accuracy_score, classification_report,
    precision_score, recall_score, f1_score
)

from torchvision import transforms, models
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Optional imports for scheduling, pruning, quantization, hyperparameter tuning
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedGroupKFold
from torch.nn.utils import prune
from torch.quantization import quantize_dynamic
import optuna
from tqdm import tqdm

# Date tag intended for naming output files.
# NOTE(review): the checkpoint names used further down are hard-coded with a
# different date ("09_04_2025") and do not read this variable.
timestamp = "13_04_2025"

# Hyperparameters carried over from training (if needed for advanced setups).
# The evaluation functions in this section do not read them.
lr = 1e-5
weight_decay = 1e-2
dropout_rate = 0.2

# Location of the PASSION MICCAI label CSV and image folder (machine-specific,
# absolute Windows paths; adjust before running elsewhere).
csv_path = r"C:\Users\PASSION_MICCAI_2024\label.csv"
image_folder = r"C:\Users\PASSION_MICCAI_2024\images"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================================
# 2. PASSION DATASET LOADER
# ================================

class PassionDataset(Dataset):
    """
    Custom PyTorch dataset for PASSION MICCAI 2024.

    Each sample contains:
        - An image (after ``transform`` is applied, if one is given)
        - A binary label (1 if the subject's condition equals ``target_condition``, else 0)
        - The subject's Fitzpatrick skin type value

    Images are discovered by globbing ``<image_folder>/<subject_id>_*.jpg`` for
    every ``subject_id`` row in the CSV, so one subject can contribute several
    samples. Matches are sorted so sample order is reproducible across runs
    and file systems.

    Parameters:
        csv_file (str): Path to the PASSION label CSV file. Must contain the
            columns ``subject_id``, ``conditions_PASSION`` and ``fitzpatrick``
            (surrounding whitespace in header names is stripped).
        image_folder (str): Path to the folder containing subject images.
        target_condition (str): The condition to classify (default: 'eczema').
            Compared case-insensitively.
        transform (callable): Image preprocessing pipeline applied to the PIL image.
    """
    def __init__(self, csv_file, image_folder, target_condition="eczema", transform=None):
        self.data = pd.read_csv(csv_file)
        self.data.columns = self.data.columns.str.strip()  # Remove stray spaces in column headers
        self.image_folder = image_folder
        self.transform = transform

        # Generate binary label: 1 if condition matches the target, else 0
        self.data["label"] = (self.data["conditions_PASSION"].str.lower() == target_condition.lower()).astype(int)
        self.label_map = {0: f"Not {target_condition}", 1: target_condition}

        # Build the subject -> (label, fitzpatrick) lookup once, so __getitem__
        # does not have to scan the whole DataFrame for every sample.
        # setdefault keeps the first row of a repeated subject_id, which is the
        # row the previous per-item `.values[0]` lookup selected.
        self._subject_info = {}
        for subject_id, label, fitz in zip(
            self.data["subject_id"], self.data["label"], self.data["fitzpatrick"]
        ):
            self._subject_info.setdefault(subject_id, (label, fitz))

        # Match subject IDs to images (sorted: glob order is otherwise unspecified)
        self.image_files = []
        for subject_id in self.data["subject_id"]:
            pattern = os.path.join(image_folder, f"{subject_id}_*.jpg")
            self.image_files.extend((img, subject_id) for img in sorted(glob.glob(pattern)))

    def __len__(self):
        """Return the number of (image, subject) samples found on disk."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            tuple: (image, label tensor [long], Fitzpatrick tensor [long]).

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path, subject_id = self.image_files[idx]
        label, fitz = self._subject_info[subject_id]

        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        # OpenCV decodes to BGR; convert to RGB before handing over to PIL/torchvision
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long), torch.tensor(fitz, dtype=torch.long)

# ================================
# 3. IMAGE TRANSFORMATIONS
# ================================

# Standard preprocessing for evaluation: deterministic (no augmentation).
# Images are resized to 224x224, converted to a tensor, then normalized
# per channel.
# NOTE(review): confirm these mean/std values match the normalization that
# was used when the evaluated checkpoints were trained.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225)
    )
])
# ================================
# 4. MODEL DEFINITIONS
# ================================

# ---- Vision Transformer (ViT) with Fitzpatrick Embedding ----

class ViTModelWithFitzpatrick(nn.Module):
    """
    ViT-Base/16 image encoder fused with a learned Fitzpatrick embedding.

    The image goes through a pretrained timm ViT whose own head is replaced by
    an identity, the Fitzpatrick index goes through an embedding table, and
    the two vectors are concatenated and fed to a small MLP classifier.

    Args:
        num_classes (int): Number of output classes (e.g. 2 for binary classification).
        fitzpatrick_vocab_size (int): Number of rows in the Fitzpatrick embedding table.
        fitz_emb_dim (int): Width of each Fitzpatrick embedding vector.
        dropout (float): Dropout probability inside the classifier head.
    """
    def __init__(self, num_classes, fitzpatrick_vocab_size, fitz_emb_dim=32, dropout=0.2):
        super().__init__()
        print("Initializing ViT-based model with Fitzpatrick embedding...")

        # Pretrained backbone; remember the width of its features before
        # swapping the stock head for a pass-through.
        backbone = timm.create_model("vit_base_patch16_224", pretrained=True)
        feature_dim = backbone.head.in_features
        backbone.head = nn.Identity()
        self.vit = backbone

        # Lookup table: Fitzpatrick index -> dense vector
        self.fitz_emb = nn.Embedding(fitzpatrick_vocab_size, fitz_emb_dim)

        # MLP head operating on [image features | Fitzpatrick embedding]
        head_layers = [
            nn.Linear(feature_dim + fitz_emb_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        print("ViT model created.")

    def forward(self, x, fitz):
        """Return class logits for images ``x`` and Fitzpatrick indices ``fitz``."""
        fused = torch.cat([self.vit(x), self.fitz_emb(fitz)], dim=1)
        return self.classifier(fused)


# ---- ResNet50 with Fitzpatrick Embedding ----

class ResNet50ModelWithFitzpatrick(nn.Module):
    """
    ResNet50 image encoder fused with a learned Fitzpatrick embedding.

    The image goes through a pretrained torchvision ResNet50 whose ``fc``
    layer is replaced by an identity, the Fitzpatrick index goes through an
    embedding table, and the two vectors are concatenated and fed to a small
    MLP classifier.

    Args:
        num_classes (int): Number of output classes.
        fitzpatrick_vocab_size (int): Number of rows in the Fitzpatrick embedding table.
        fitz_emb_dim (int): Width of each Fitzpatrick embedding vector.
        dropout (float): Dropout probability inside the classifier head.
    """
    def __init__(self, num_classes, fitzpatrick_vocab_size, fitz_emb_dim=32, dropout=0.2):
        super().__init__()
        print("Initializing ResNet50-based model with Fitzpatrick embedding...")

        # Pretrained backbone; remember the width of its pooled features before
        # swapping the stock fully-connected layer for a pass-through.
        backbone = models.resnet50(pretrained=True)
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.resnet = backbone

        # Lookup table: Fitzpatrick index -> dense vector
        self.fitz_emb = nn.Embedding(fitzpatrick_vocab_size, fitz_emb_dim)

        # MLP head operating on [image features | Fitzpatrick embedding]
        head_layers = [
            nn.Linear(feature_dim + fitz_emb_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        print("ResNet50 model created.")

    def forward(self, x, fitz):
        """Return class logits for images ``x`` and Fitzpatrick indices ``fitz``."""
        fused = torch.cat([self.resnet(x), self.fitz_emb(fitz)], dim=1)
        return self.classifier(fused)


In [None]:
# ================================
# 5. EVALUATION FUNCTION
# ================================

def evaluate_model(model_class, model_path, dataloader, num_classes, device, fitzpatrick_vocab_size):
    """
    Loads a saved model checkpoint and evaluates it on a test dataloader.

    Args:
        model_class (nn.Module): The model architecture class to instantiate (e.g. ViTModelWithFitzpatrick).
            It is called as ``model_class(num_classes=..., fitzpatrick_vocab_size=...)`` and its
            forward pass as ``model(images, fitz)``.
        model_path (str): Path to the model weights (.pth file containing a state dict).
        dataloader (DataLoader): Yields ``(images, labels, fitz)`` batches.
        num_classes (int): Number of output classes (e.g. 2 for binary).
        device (torch.device): CUDA or CPU.
        fitzpatrick_vocab_size (int): Vocabulary size of Fitzpatrick embeddings.

    Returns:
        - acc (float): Overall accuracy.
        - cm (ndarray): Confusion matrix, always ``num_classes x num_classes``.
        - prec (float): Precision score.
        - rec (float): Recall score.
        - f1 (float): F1 score.
        - spec (float): Specificity (recall for class 0).
        - report (dict): Classification report (as dictionary).
        - fitz_acc (dict): Accuracy per Fitzpatrick skin type, keyed
          ``"Fitzpatrick <value>"`` and ordered by ascending value.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    # Instantiate model and load weights
    model = model_class(num_classes=num_classes, fitzpatrick_vocab_size=fitzpatrick_vocab_size)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model_state = model.state_dict()

    # Drop final-layer tensors whose shape differs from the freshly built model
    # (e.g. a checkpoint trained with a different number of classes).
    for key in list(state_dict.keys()):
        if key.startswith("classifier.3") and key in model_state:
            if state_dict[key].shape != model_state[key].shape:
                print(f"Skipping loading {key} due to shape mismatch.")
                state_dict.pop(key)

    # strict=False tolerates key mismatches, so report them: otherwise a wrong
    # checkpoint would be evaluated with (partly) untrained weights unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} model parameter(s) not found in "
              f"{model_path}: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} checkpoint key(s) ignored from "
              f"{model_path}: {load_result.unexpected_keys}")
    model.to(device)
    model.eval()

    # Store all predictions and ground truths
    all_preds, all_labels, all_fitz = [], [], []

    with torch.no_grad():
        for images, labels, fitz in dataloader:
            images, fitz = images.to(device), fitz.to(device)
            outputs = model(images, fitz)
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_fitz.extend(fitz.cpu().tolist())

    if not all_labels:
        raise ValueError("dataloader yielded no samples; nothing to evaluate")

    # Classification metrics
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    spec = recall_score(all_labels, all_preds, pos_label=0)
    # Pin the label set so the matrix keeps its full shape even when a class
    # never occurs in the labels or predictions (callers plot it with one tick
    # label per class).
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    report = classification_report(all_labels, all_preds, output_dict=True)

    # Fitzpatrick group-wise accuracy (fairness check): [correct, total] per group
    fitz_groups = {}
    for pred, true, group in zip(all_preds, all_labels, all_fitz):
        counts = fitz_groups.setdefault(group, [0, 0])
        counts[0] += int(pred == true)
        counts[1] += 1

    # Sorted so the per-group listing is in a stable, readable order
    fitz_acc = {
        f"Fitzpatrick {group}": correct / total
        for group, (correct, total) in sorted(fitz_groups.items())
    }

    return acc, cm, prec, rec, f1, spec, report, fitz_acc


# ================================
# 6. MAIN COMPARISON SCRIPT
# ================================

def main():
    """
    Runs the evaluation pipeline:
    - Loads the PASSION dataset (every image matched by the label CSV).
    - Evaluates both ViT and ResNet checkpoints on classification metrics.
    - Prints accuracy per Fitzpatrick group (fairness check).
    - Saves and shows an accuracy bar chart and side-by-side confusion matrices.
    """

    # === Load PASSION test dataset ===
    dataset = PassionDataset(
        csv_file=csv_path,
        image_folder=image_folder,
        target_condition="eczema",
        transform=test_transforms
    )
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Classification settings
    num_classes = 2
    fitzpatrick_vocab_size = 7

    # Checkpoints to evaluate (can add more models here)
    vit_model_files = ["vit_finetuned_optuna_09_04_2025.pth"]
    resnet_model_files = ["resnet_finetuned_optuna_09_04_2025.pth"]
    vit_labels = ["ViT Finetuned Optuna"]
    resnet_labels = ["ResNet Finetuned Optuna"]

    # One (family, architecture, checkpoint, display name) entry per model,
    # ViT entries first, so a single loop handles both families.
    candidates = (
        [("ViT", ViTModelWithFitzpatrick, path, name)
         for path, name in zip(vit_model_files, vit_labels)]
        + [("ResNet", ResNet50ModelWithFitzpatrick, path, name)
           for path, name in zip(resnet_model_files, resnet_labels)]
    )

    # Storage for the results that are reported or plotted below
    model_names, accuracies, conf_matrices = [], [], []
    fitz_metrics_all = []

    # === Evaluate all models ===
    for family, model_class, model_file, label in candidates:
        print(f"Evaluating {family} model: {model_file} ...")
        acc, cm, prec, rec, f1, spec, _report, fitz_acc = evaluate_model(
            model_class, model_file, dataloader,
            num_classes, device, fitzpatrick_vocab_size
        )
        model_names.append(label)
        accuracies.append(acc)
        conf_matrices.append(cm)
        fitz_metrics_all.append(fitz_acc)

        print(f"{label} Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, "
              f"F1: {f1:.4f}, Specificity: {spec:.4f}")

    # === Fairness Metrics: Fitzpatrick Skin Type Accuracy ===
    print("\n--- Fitzpatrick Group Accuracy (Fairness Check) ---")
    for name, fitz_acc in zip(model_names, fitz_metrics_all):
        print(f"\n{name}:")
        for group, group_acc in fitz_acc.items():
            print(f"  {group}: {group_acc:.4f}")

    # === Plot Accuracy Bar Chart ===
    plt.figure(figsize=(8, 6))
    bars = plt.bar(model_names, accuracies, color=["skyblue", "seagreen"])
    plt.ylim(0, 1.1)
    plt.ylabel("Accuracy")
    plt.title("Eczema Classification Accuracy Comparison (ViT vs. ResNet)")

    # Write each accuracy value just above its bar
    for bar, model_acc in zip(bars, accuracies):
        plt.text(bar.get_x() + bar.get_width() / 2, model_acc + 0.02, f"{model_acc:.4f}",
                 ha="center", va="bottom", fontsize=11)

    plt.tight_layout()
    plt.savefig("eczema_comparison_accuracy.png", dpi=300)
    plt.show()

    # === Plot Side-by-Side Confusion Matrices ===
    # squeeze=False always returns a 2-D axes array, so a single model needs
    # no special-casing.
    _, axes = plt.subplots(1, len(conf_matrices), figsize=(6 * len(conf_matrices), 6),
                           squeeze=False)
    class_names = list(dataset.label_map.values())

    for ax, cm, name in zip(axes[0], conf_matrices, model_names):
        sns.heatmap(
            cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax, cbar=False
        )
        ax.set_title(f"Confusion Matrix - {name}", fontsize=13)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")

    plt.tight_layout()
    plt.savefig("eczema_comparison_conf_matrices.png", dpi=300)
    plt.show()


# Entrypoint for script execution
if __name__ == "__main__":
    main()


In [None]:
###############################################
# 5. EVALUATION FUNCTION
###############################################

def evaluate_model(model_class, model_path, dataloader, num_classes, device, fitzpatrick_vocab_size):
    """
    Evaluates a given model on a dataloader and computes:
      - accuracy, precision, recall, F1 score, specificity
      - confusion matrix (always num_classes x num_classes)
      - full classification report
      - group-wise Fitzpatrick accuracy (for fairness analysis)

    Args:
        model_class: Architecture class, called as
            ``model_class(num_classes=..., fitzpatrick_vocab_size=...)``; its
            forward pass is called as ``model(images, fitz)``.
        model_path (str): Path to a saved state dict (.pth).
        dataloader: Yields ``(images, labels, fitz)`` batches.
        num_classes (int): Number of output classes.
        device (torch.device): Device to run on.
        fitzpatrick_vocab_size (int): Size of the Fitzpatrick embedding table.

    Returns:
        tuple: ``(acc, cm, prec, rec, f1, spec, report, fitz_acc)`` where
        ``fitz_acc`` maps ``"Fitzpatrick <value>"`` to accuracy, ordered by
        ascending value.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model = model_class(num_classes=num_classes, fitzpatrick_vocab_size=fitzpatrick_vocab_size)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model_state = model.state_dict()

    # Handle shape mismatches (e.g., classifier layers) by skipping incompatible keys
    for key in list(state_dict.keys()):
        if key.startswith("classifier.3") and key in model_state:
            if state_dict[key].shape != model_state[key].shape:
                print(f"Skipping loading {key} due to shape mismatch.")
                state_dict.pop(key)

    # strict=False tolerates key mismatches, so report them: otherwise a wrong
    # checkpoint would be evaluated with (partly) untrained weights unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} model parameter(s) not found in "
              f"{model_path}: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} checkpoint key(s) ignored from "
              f"{model_path}: {load_result.unexpected_keys}")
    model.to(device)
    model.eval()

    all_preds, all_labels, all_fitz = [], [], []

    with torch.no_grad():
        for images, labels, fitz in dataloader:
            images, fitz = images.to(device), fitz.to(device)
            outputs = model(images, fitz)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            all_fitz.extend(fitz.cpu().tolist())

    if not all_labels:
        raise ValueError("dataloader yielded no samples; nothing to evaluate")

    # Compute metrics
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds)
    rec = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    spec = recall_score(all_labels, all_preds, pos_label=0)
    # Pin the label set so the matrix keeps its full shape even when a class
    # never occurs in the labels or predictions (callers plot it with one tick
    # label per class).
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    report = classification_report(all_labels, all_preds, output_dict=True)

    # Fitzpatrick group-wise accuracy: [correct, total] per group
    fitz_groups = {}
    for pred, true, group in zip(all_preds, all_labels, all_fitz):
        counts = fitz_groups.setdefault(group, [0, 0])
        counts[0] += int(pred == true)
        counts[1] += 1

    # Sorted so the per-group listing is in a stable, readable order
    fitz_acc = {f"Fitzpatrick {group}": correct / total
                for group, (correct, total) in sorted(fitz_groups.items())}

    return acc, cm, prec, rec, f1, spec, report, fitz_acc


In [None]:
###############################################
# 6. MAIN EVALUATION SCRIPT FOR BOTH DATASETS
###############################################

def main():
    """
    Evaluates both ViT and ResNet models on two datasets:
    - PASSION (real validation split with Fitzpatrick scores)
    - DermNet (domain-shifted test set with dummy Fitzpatrick info)

    For every (dataset, model) pair it prints classification and fairness
    metrics, then saves and shows a confusion-matrix heatmap.
    """

    # Pretrained model checkpoints and their readable labels
    vit_model_files = ["vit_finetuned_optuna_09_04_2025.pth"]
    resnet_model_files = ["resnet_finetuned_optuna_09_04_2025.pth"]
    vit_labels = ["ViT Finetuned Optuna"]
    resnet_labels = ["ResNet Finetuned Optuna"]

    # One (architecture, checkpoint, display name) entry per model, ViT
    # entries first, so a single loop handles both families.
    candidates = (
        [(ViTModelWithFitzpatrick, path, name)
         for path, name in zip(vit_model_files, vit_labels)]
        + [(ResNet50ModelWithFitzpatrick, path, name)
           for path, name in zip(resnet_model_files, resnet_labels)]
    )

    # Define datasets for testing (PASSION and DermNet)
    # NOTE(review): both dataset classes are defined outside this cell and are
    # assumed to expose a `label_map` dict (used for the heatmap tick labels).
    datasets_info = [
        ("PASSION", PassionDatasetBinary(
            csv_file=passion_csv_path,
            image_folder=passion_image_folder,
            transform=valid_transforms,
            mode="validation"
        )),
        ("DermNet", DermNetDatasetBinary(
            root=dermnet_unseen_root,
            transform=valid_transforms
        ))
    ]

    num_classes = 2
    fitzpatrick_vocab_size = 7

    for ds_name, dataset in datasets_info:
        print(f"\n=== Evaluating on {ds_name} dataset ===")
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        for model_class, model_file, label in candidates:
            print(f"\nTesting {label} using checkpoint {model_file} ...")
            acc, cm, prec, rec, f1, spec, _report, fitz_acc = evaluate_model(
                model_class, model_file, dataloader, num_classes, device, fitzpatrick_vocab_size
            )
            print(f"{label} on {ds_name}:")
            print(f"  Accuracy:    {acc:.4f}")
            print(f"  Precision:   {prec:.4f}")
            print(f"  Recall:      {rec:.4f}")
            print(f"  F1 Score:    {f1:.4f}")
            print(f"  Specificity: {spec:.4f}")
            print("  Fitzpatrick Group Accuracies:")
            for group, group_acc in fitz_acc.items():
                print(f"    {group}: {group_acc:.4f}")

            # Plot confusion matrix
            class_names = list(dataset.label_map.values())
            plt.figure(figsize=(6, 5))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                        xticklabels=class_names,
                        yticklabels=class_names)
            plt.title(f"Confusion Matrix - {label} on {ds_name}")
            plt.xlabel("Predicted")
            plt.ylabel("True")
            plt.tight_layout()
            plt.savefig(f"confusion_matrix_{label.replace(' ', '_')}_{ds_name}.png", dpi=300)
            plt.show()

if __name__ == "__main__":
    main()


In [None]:
"""
Baseline ViT Model Training (Without Fairness Embedding)
--------------------------------------------------------

This script trains a baseline Vision Transformer (ViT) model for binary eczema classification
using the PASSION dataset. It excludes fairness-aware features such as Fitzpatrick skin type 
embedding, serving as a control/baseline for comparison against fairness-integrated models.

Key Features:
- Binary label generation: Eczema vs. Non-Eczema
- Image augmentation pipeline using RandAugment and ColorJitter
- StratifiedGroupKFold for patient-wise split to prevent data leakage
- Early stopping based on validation loss
- Model saving with timestamp for reproducibility

Notes:
- This model excludes Fitzpatrick skin tone information.
- A similar script can be used for ResNet by replacing the ViT model class with a ResNet50
  counterpart, i.e. `ResNet50ModelWithFitzpatrick` with its Fitzpatrick embedding removed.
- This baseline is useful for measuring the effect of fairness-aware training components 
  in more advanced experiments.

Dependencies: PyTorch, torchvision, timm, sklearn, pandas, numpy, OpenCV

"""


import os
import glob
import cv2
import torch
import timm
import numpy as np
import pandas as pd
from PIL import Image
from datetime import datetime
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import confusion_matrix

# Set device for training: GPU when CUDA is available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###############################################
# 1. HYPERPARAMETERS & PATHS
###############################################
lr = 1e-5                   # AdamW learning rate passed to train_model_nofair
weight_decay = 1e-2         # AdamW weight decay; read as a global by train_model_nofair
dropout_rate = 0.2          # NOTE(review): not passed to ViTWithoutFitz below, which uses its own default
num_epochs_finetune = 15    # maximum number of fine-tuning epochs
patience_finetune = 3       # early-stopping patience (epochs without val-loss improvement)
freeze_epochs_passion = 5   # freeze_epochs argument: initial epochs with the ViT backbone frozen

# Dataset paths (machine-specific, absolute Windows paths)
# NOTE(review): dermnet_train_root is not used by the training code in this cell.
dermnet_train_root = r"C:\Users\DermNet\train"
passion_csv_path = r"C:\Users\PASSION_MICCAI_2024\label.csv"
passion_image_folder = r"C:\Users\PASSION_MICCAI_2024\images"

###############################################
# 2. CSV SPLITTING FOR PASSION
###############################################
# Read the full label table and normalise header names (strip stray spaces).
df_full = pd.read_csv(passion_csv_path)
df_full.columns = df_full.columns.str.strip()

# Subject-grouped, condition-stratified 5-fold splitter; only the first fold
# is used, giving a single train/validation split in which no subject
# appears on both sides.
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
train_idx, val_idx = next(
    sgkf.split(df_full, df_full["conditions_PASSION"], groups=df_full["subject_id"])
)
train_df = df_full.iloc[train_idx]
val_df = df_full.iloc[val_idx]

# Persist both partitions so the dataset class can load them by path.
train_csv = "train_split.csv"
val_csv = "val_split.csv"
train_df.to_csv(train_csv, index=False)
val_df.to_csv(val_csv, index=False)

###############################################
# 3. IMAGE TRANSFORMATIONS
###############################################
class DADATransform:
    """Callable wrapper that applies torchvision's RandAugment to an image.

    Args:
        num_ops (int): Number of augmentation operations applied per image.
        magnitude (int): Strength setting forwarded to RandAugment.
    """

    def __init__(self, num_ops=2, magnitude=9):
        self.augment = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude)

    def __call__(self, img):
        """Return ``img`` after passing it through the wrapped augmentation."""
        augmented = self.augment(img)
        return augmented

# Training pipeline: resize, then random augmentations (RandAugment via
# DADATransform, horizontal flip, rotation up to 15 degrees, colour jitter,
# random resized crop back to 224x224), then tensor conversion and
# per-channel normalisation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    DADATransform(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Validation pipeline: deterministic (no augmentation), with the same resize
# and normalisation as training.
valid_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

###############################################
# 4. DATASET CLASS DEFINITIONS
###############################################
class PassionDataset(Dataset):
    """
    PASSION dataset for binary eczema classification.

    Reads a label CSV with the columns ``subject_id``, ``conditions_PASSION``
    and ``fitzpatrick`` and pairs every subject with the image files matching
    ``<image_folder>/<subject_id>_*.jpg``. Each sample is
    ``(image, label, fitzpatrick)`` where ``label`` is 1 for eczema and 0
    otherwise.

    Args:
        csv_file (str): Path to the (split) label CSV.
        image_folder (str): Folder containing the subject images.
        transform (callable, optional): Applied to the PIL image.
        mode (str): Stored on the instance; not otherwise used by this class.
    """
    def __init__(self, csv_file, image_folder, transform=None, mode="train"):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform
        self.mode = mode
        self.data["binary_label"] = self.data["conditions_PASSION"].apply(
            lambda x: 1 if str(x).strip().lower() == "eczema" else 0)

        # Build the subject -> (label, fitzpatrick) lookup once, so __getitem__
        # does not have to scan the whole DataFrame for every sample.
        # setdefault keeps the first row of a repeated subject_id, which is the
        # row the previous per-item `.values[0]` lookup selected.
        self._subject_info = {}
        for subject_id, label, fitz in zip(
            self.data["subject_id"], self.data["binary_label"], self.data["fitzpatrick"]
        ):
            self._subject_info.setdefault(subject_id, (label, fitz))

        # Sorted: glob order is otherwise unspecified, which would make the
        # sample order differ between runs and file systems.
        self.image_files = []
        for subject_id in self.data["subject_id"]:
            pattern = os.path.join(image_folder, f"{subject_id}_*.jpg")
            self.image_files.extend((img, subject_id) for img in sorted(glob.glob(pattern)))

    def __len__(self):
        """Return the number of (image, subject) samples found on disk."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            tuple: (image, label tensor [long], Fitzpatrick tensor [long]).

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path, subject_id = self.image_files[idx]
        raw_label, raw_fitz = self._subject_info[subject_id]
        label = int(raw_label)
        fitz = int(raw_fitz)  # Not used in this model
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None; fail with a clear
        # message instead of an opaque error from cvtColor.
        if image is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        # OpenCV decodes to BGR; convert to RGB before handing over to PIL/torchvision
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long), torch.tensor(fitz, dtype=torch.long)

###############################################
# 5. DATALOADERS
###############################################
# Training split uses the augmenting pipeline, validation split the deterministic one.
passion_train_ds = PassionDataset(train_csv, passion_image_folder, transform=train_transforms)
passion_val_ds = PassionDataset(val_csv, passion_image_folder, transform=valid_transforms)
# Only the training loader shuffles; num_workers=0 loads batches in the main process.
train_loader = DataLoader(passion_train_ds, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(passion_val_ds, batch_size=32, shuffle=False, num_workers=0)

###############################################
# 6. BASELINE MODEL: ViT WITHOUT FAIRNESS
###############################################
class ViTWithoutFitz(nn.Module):
    """
    Baseline ViT-Base/16 classifier that uses the image only.

    The pretrained timm backbone has its own head replaced by an identity and
    its features are fed to a small MLP classifier; no Fitzpatrick input is
    involved.

    Args:
        num_classes (int): Number of output classes.
        dropout (float): Dropout probability inside the classifier head.
    """
    def __init__(self, num_classes, dropout=0.2):
        super().__init__()
        backbone = timm.create_model("vit_base_patch16_224", pretrained=True)
        feature_dim = backbone.head.in_features
        backbone.head = nn.Identity()  # expose raw features instead of the stock logits
        self.vit = backbone

        head_layers = [
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.classifier(self.vit(x))

###############################################
# 7. TRAINING FUNCTION (NO FAIRNESS)
###############################################
def train_model_nofair(model, train_loader, val_loader, device,
                       num_epochs=5, freeze_epochs=0, patience=2, lr=1e-5,
                       checkpoint_path="vit_nofairness_best.pth"):
    """
    Train the baseline (image-only) model with early stopping on validation loss.

    The backbone (``model.vit``) is kept frozen for exactly the first
    ``freeze_epochs`` epochs; it is then unfrozen and the optimizer and cosine
    schedule are rebuilt for the remaining epochs. Whenever the validation
    loss improves, the model's state dict is written to ``checkpoint_path``.

    Args:
        model: Module with a ``vit`` sub-module, called as ``model(images)``.
        train_loader: Yields ``(images, labels, fitz)`` batches; ``fitz`` is ignored.
        val_loader: Same batch layout as ``train_loader``.
        device (torch.device): Device the batches are moved to.
        num_epochs (int): Maximum number of epochs.
        freeze_epochs (int): Number of initial epochs with the backbone frozen
            (0 disables freezing).
        patience (int): Stop after this many consecutive epochs without a
            validation-loss improvement.
        lr (float): AdamW learning rate.
        checkpoint_path (str): Where the best state dict is saved.

    Note:
        The AdamW weight decay is read from the module-level ``weight_decay``.
        On return the model holds the weights of the last epoch run, which is
        not necessarily the best epoch; the best weights are in
        ``checkpoint_path``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    epochs_no_improve = 0

    backbone_frozen = freeze_epochs > 0
    if backbone_frozen:
        for param in model.vit.parameters():
            param.requires_grad = False

    for epoch in range(num_epochs):
        # Unfreeze *before* training epoch `freeze_epochs`, so the backbone is
        # frozen for exactly `freeze_epochs` epochs. The optimizer is rebuilt
        # only on this transition, and the new cosine schedule spans exactly
        # the epochs that remain.
        if backbone_frozen and epoch >= freeze_epochs:
            for param in model.vit.parameters():
                param.requires_grad = True
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = CosineAnnealingLR(optimizer, T_max=(num_epochs - epoch))
            backbone_frozen = False

        model.train()
        total_loss, correct, total = 0, 0, 0
        for images, labels, _ in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
        scheduler.step()

        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels, _ in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == labels).sum().item()
                val_total += labels.size(0)

        # Mean per-batch validation loss and sample-level accuracy
        val_loss /= len(val_loader)
        val_acc = val_correct / val_total

        print(f"Epoch {epoch+1}: Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break

###############################################
# 8. EXECUTE TRAINING
###############################################
vit_nofair = ViTWithoutFitz(num_classes=2).to(device)
train_model_nofair(vit_nofair, train_loader, val_loader, device,
                   num_epochs=num_epochs_finetune,
                   freeze_epochs=freeze_epochs_passion,
                   patience=patience_finetune,
                   lr=lr)

# After training, the in-memory model holds the weights of the LAST epoch run
# (with early stopping, that is `patience` epochs past the best one). Reload
# the best-validation-loss checkpoint written during training, so the
# date-stamped "best" file really contains the best weights.
best_checkpoint = "vit_nofairness_best.pth"
if os.path.exists(best_checkpoint):
    vit_nofair.load_state_dict(torch.load(best_checkpoint, map_location=device))

torch.save(vit_nofair.state_dict(), "vit_nofairness_best_22APR2025.pth")
print(" Model re-saved as vit_nofairness_best_22APR2025.pth")



In [None]:
###############################################
# ROC Curve Plotting for Model Comparison
###############################################

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

def plot_roc_curve(model, loader, label, color):
    """
    Draw one model's ROC curve on the current matplotlib axes.

    The model is moved to `device`, switched to eval mode and run over every
    batch of `loader`; the softmax probability of class 1 is used as the
    ranking score.

    Args:
        model: Trained model called as `model(images, fitzpatrick)`.
        loader: DataLoader yielding `(images, labels, fitzpatrick)` batches.
        label: Legend entry for this curve (the AUC is appended to it).
        color: Matplotlib line color.

    Returns:
        Area under the ROC curve (float).
    """
    model.to(device).eval()

    y_scores = []
    y_true = []
    with torch.no_grad():
        for batch_imgs, batch_labels, batch_fitz in loader:
            logits = model(batch_imgs.to(device), batch_fitz.to(device))
            # Column 1 of the softmax is the probability of class 1 (eczema).
            eczema_prob = torch.softmax(logits, dim=1)[:, 1]
            y_scores.extend(eczema_prob.cpu().numpy())
            y_true.extend(batch_labels.numpy())

    # ROC operating points and the area underneath them.
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    plt.plot(fpr, tpr, color=color, lw=2, label=f'{label} (AUC = {roc_auc:.2f})')
    return roc_auc
###############################################
# Initialize Models & Load Checkpoints
###############################################

# Ensure these model classes are already defined:
# - ViTModelWithFitzpatrick
# - ResNet50ModelWithFitzpatrick

# Instantiate models
vit_model = ViTModelWithFitzpatrick(num_classes=2, fitzpatrick_vocab_size=7)
resnet_model = ResNet50ModelWithFitzpatrick(num_classes=2, fitzpatrick_vocab_size=7)

# Load trained model weights (Optuna-tuned checkpoints).
# strict=False tolerates key mismatches, but every key it skips leaves the
# corresponding layer at its random initialisation, which would silently
# distort the ROC comparison. Report any mismatch instead of hiding it.
for _roc_name, _roc_model, _roc_ckpt in [
    ("ViT", vit_model, "vit_finetuned_optuna_09_04_2025.pth"),
    ("ResNet", resnet_model, "resnet_finetuned_optuna_09_04_2025.pth"),
]:
    # weights_only=True restricts unpickling to tensors/containers.
    _roc_state = torch.load(_roc_ckpt, map_location=device, weights_only=True)
    _roc_load = _roc_model.load_state_dict(_roc_state, strict=False)
    if _roc_load.missing_keys or _roc_load.unexpected_keys:
        print(f"Warning: {_roc_name} checkpoint {_roc_ckpt} did not match the model exactly.")
        print(f"  missing keys   : {list(_roc_load.missing_keys)}")
        print(f"  unexpected keys: {list(_roc_load.unexpected_keys)}")
###############################################
# Plot ROC Curves for Both Models
###############################################

# 'loader' must be defined beforehand and contain the test or validation dataset

roc_fig = plt.figure(figsize=(8, 6))

# Each call draws one curve on the current axes and returns its AUC.
vit_auc = plot_roc_curve(vit_model, loader, label="ViT", color="dimgray")
resnet_auc = plot_roc_curve(resnet_model, loader, label="ResNet", color="gray")

roc_ax = plt.gca()

# Diagonal = random-chance classifier.
roc_ax.plot([0, 1], [0, 1], 'k--', lw=1)

# Axis labels, title, legend and grid.
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curve: ViT vs ResNet (Optuna-Tuned)")
roc_ax.legend(loc="lower right")
roc_ax.grid(True, linestyle='--', alpha=0.6)
roc_fig.tight_layout()

# Write the figure to disk before showing it.
roc_fig.savefig("roc_comparison_vit_vs_resnet.png", dpi=300)
plt.show()


In [None]:
"""
Evaluation Script: Calibration, Statistical Testing, and Fairness Assessment
----------------------------------------------------------------------------

This script evaluates two pretrained binary eczema classification models:
- Vision Transformer (ViT)
- ResNet-50

It performs a comprehensive analysis across multiple evaluation dimensions:
1. Classification performance (F1 score, Recall, AUC, Brier score)
2. Statistical significance testing using McNemar’s test
3. Confidence interval estimation of AUC differences via bootstrapping
4. Calibration assessment and correction using Temperature Scaling
5. Fairness analysis based on:
   - Demographic Parity (DP)
   - Equalized Odds (EO)

Usage Requirements:
- Provide trained model checkpoint paths (Optuna-tuned .pth files)
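- Ensure availability of a compatible DataLoader (`loader`) using the `PassionDatasetBinary` class
- Ensure the underlying `dataset` object is in scope and that `loader` iterates it in order (no shuffling): the fairness section reads Fitzpatrick groups from `dataset.data` and pairs them positionally with the predictions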
- Verify `valid_transforms` is defined and used for preprocessing
- This script assumes binary classification: eczema vs. non-eczema

Note:
- Although the script is centered on evaluating ViT, ResNet evaluation is included.
- All methods are generalizable; similar analysis can be extended to additional models.

Dependencies: PyTorch, torchvision, timm, sklearn, matplotlib, scipy, numpy
"""


In [None]:
###############################################
# Model Definitions (ViT and ResNet with fairness embedding)
###############################################
class ViTFair(nn.Module):
    """
    ViT-Base/16 classifier that also receives the Fitzpatrick skin type.

    The timm backbone's head is replaced by an identity, its output is
    concatenated with a 32-dimensional embedding of the Fitzpatrick index
    (7 possible values) and passed through a two-layer MLP producing 2 logits.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model("vit_base_patch16_224", pretrained=False)
        backbone_dim = backbone.head.in_features
        backbone.head = nn.Identity()  # expose features instead of class logits
        self.vit = backbone

        # Fitzpatrick scale embedding
        self.emb = nn.Embedding(7, 32)

        hidden = nn.Linear(backbone_dim + 32, 256)
        output = nn.Linear(256, 2)
        self.classifier = nn.Sequential(hidden, nn.ReLU(), output)

    def forward(self, x, fitz):
        """Return class logits for images `x` and Fitzpatrick indices `fitz`."""
        image_features = self.vit(x)
        skin_features = self.emb(fitz)
        fused = torch.cat([image_features, skin_features], dim=1)
        return self.classifier(fused)


class ResNetFair(nn.Module):
    """
    ResNet-50 classifier that also receives the Fitzpatrick skin type.

    The backbone's final fully connected layer is replaced by an identity, its
    output is concatenated with a 32-dimensional embedding of the Fitzpatrick
    index (7 possible values) and passed through a two-layer MLP producing
    2 logits.
    """

    def __init__(self):
        super().__init__()
        # `weights=None` is the current torchvision spelling of the deprecated
        # `pretrained=False`; the weights are loaded from a checkpoint anyway.
        self.res = models.resnet50(weights=None)
        feat = self.res.fc.in_features
        self.res.fc = nn.Identity()  # expose features instead of class logits
        self.emb = nn.Embedding(7, 32)  # Fitzpatrick scale embedding
        self.classifier = nn.Sequential(
            nn.Linear(feat + 32, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def forward(self, x, fitz):
        """Return class logits for images `x` and Fitzpatrick indices `fitz`."""
        f = self.res(x)
        e = self.emb(fitz)
        return self.classifier(torch.cat([f, e], dim=1))
###############################################
# Evaluation Helpers
###############################################
def evaluate_preds(model, loader):
    """
    Run inference over `loader` and collect labels, predictions and metrics.

    Args:
        model: Model called as `model(images, fitzpatrick)`; it is put in eval
            mode here but is expected to already live on `device`.
        loader: DataLoader yielding `(images, labels, fitzpatrick)` batches.

    Returns:
        Tuple `(labels, preds, probs, metrics)` where the first three are NumPy
        arrays (true labels, class-1 decisions at a 0.5 threshold, class-1
        softmax probabilities) and `metrics` is a dict with the keys
        "f1", "recall", "auc" and "brier".
    """
    # Imported here because this function is called before the statements
    # that import these two names at file level are executed.
    from sklearn.metrics import brier_score_loss, roc_auc_score

    model.eval()
    labs, preds, probs = [], [], []
    with torch.no_grad():
        for imgs, y, fitz in loader:
            imgs, fitz = imgs.to(device), fitz.to(device)
            out = model(imgs, fitz)
            # Probability of class 1 (column 1 of the softmax output).
            p = torch.softmax(out, dim=1)[:, 1].cpu().numpy()
            labs.extend(y.numpy())
            preds.extend((p >= 0.5).astype(int))
            probs.extend(p)

    y_true = np.array(labs)
    y_pred = np.array(preds)
    y_prob = np.array(probs)
    metrics = {
        "f1": f1_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_prob),
        "brier": brier_score_loss(y_true, y_prob)
    }
    return y_true, y_pred, y_prob, metrics
###############################################
# Model Loading & Evaluation
###############################################
# Model name -> (model class, checkpoint file).
models_info = {
    "ViT": (ViTFair, "vit_finetuned_optuna_09_04_2025.pth"),
    "ResNet": (ResNetFair, "resnet_finetuned_optuna_09_04_2025.pth")
}

results = {}  # name -> metrics dict
arrays = {}   # name -> (labels, predictions, probabilities)

for name, (Cls, ckpt) in models_info.items():
    model = Cls().to(device)
    state = torch.load(ckpt, map_location=device, weights_only=True)

    # Map checkpoint parameter names onto the attribute names used by the
    # model classes in this file.
    key_renames = (
        ("fitz_emb.weight", "emb.weight"),
        ("classifier.3.weight", "classifier.2.weight"),
        ("classifier.3.bias", "classifier.2.bias"),
    )
    for old_key, new_key in key_renames:
        if old_key in state:
            state[new_key] = state.pop(old_key)
    if name == "ResNet":
        state = {k.replace("resnet.", "res."): v for k, v in state.items()}

    # strict=True: fail loudly if any key is still missing or unexpected.
    model.load_state_dict(state, strict=True)

    y_true, y_pred, y_prob, metrics = evaluate_preds(model, loader)
    results[name] = metrics
    arrays[name] = (y_true, y_pred, y_prob)
    print(f"{name} → {metrics}")

###############################################
# McNemar’s Test for Statistical Comparison
###############################################
from scipy.stats import chi2

# Labels and hard predictions of both models. The ResNet labels get a
# descriptive name on purpose: binding them to `lr` would overwrite the
# module-level learning rate `lr` that the training code passes as `lr=lr`.
lv, pv, _ = arrays["ViT"]
labels_resnet, pr, _ = arrays["ResNet"]
if not np.array_equal(lv, labels_resnet):
    raise ValueError(
        "ViT and ResNet were evaluated on different label sequences; "
        "McNemar's test requires paired predictions."
    )

# Discordant pairs: exactly one of the two models is correct.
N01 = np.sum((pv == lv) & (pr != lv))  # ViT correct, ResNet wrong
N10 = np.sum((pv != lv) & (pr == lv))  # ViT wrong, ResNet correct
n = N01 + N10

if n == 0:
    print("No discordant pairs.")
else:
    # Chi-square statistic with continuity correction (the "- 1" term).
    chi2_stat = (abs(N10 - N01) - 1)**2 / n
    # Survival function instead of 1 - cdf: keeps precision for tiny p-values.
    p_val = chi2.sf(chi2_stat, df=1)
    print(f"McNemar’s χ²={chi2_stat:.3f}, p={p_val:.4e}  (b={N10}, c={N01})")

###############################################
# AUC Confidence Interval via Bootstrapping
###############################################
from sklearn.metrics import roc_auc_score

# Paired bootstrap: both models are scored on the same resampled indices.
probs_vit = arrays["ViT"][2]
probs_resnet = arrays["ResNet"][2]

diffs = []
skipped = 0
rng = np.random.default_rng(0)  # fixed seed for reproducible intervals
n = len(lv)

for _ in range(2000):
    idx = rng.integers(0, n, n)
    y_boot = lv[idx]
    # AUC is undefined when a resample contains a single class; skip it
    # rather than letting roc_auc_score fail on it.
    if np.unique(y_boot).size < 2:
        skipped += 1
        continue
    auc_v = roc_auc_score(y_boot, probs_vit[idx])
    auc_r = roc_auc_score(y_boot, probs_resnet[idx])
    diffs.append(auc_v - auc_r)

print(f"AUC difference (ViT–ResNet): {results['ViT']['auc'] - results['ResNet']['auc']:.4f}")
if diffs:
    ci_low, ci_high = np.percentile(diffs, [2.5, 97.5])
    print(f"95% CI: [{ci_low:.4f}, {ci_high:.4f}]")
else:
    ci_low = ci_high = float("nan")
    print("95% CI: undefined (no bootstrap resample contained both classes)")
if skipped:
    print(f"Skipped {skipped} single-class bootstrap resamples.")

###############################################
# Temperature Scaling for Calibration
###############################################
import torch.nn.functional as F
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

class TemperatureScaler(nn.Module):
    """Divide logits by a single learnable temperature (initialised to 1.5)."""

    def __init__(self):
        super().__init__()
        initial_temperature = torch.full((1,), 1.5)
        self.temperature = nn.Parameter(initial_temperature)

    def forward(self, logits):
        """Return `logits` divided by the current temperature."""
        scaled = logits / self.temperature
        return scaled

def temperature_scale(model, logits, labels):
    """
    Fit a single temperature on pre-computed logits and return it as a float.

    A fresh `TemperatureScaler` is optimised with one L-BFGS step call
    (lr=0.01, up to 50 inner iterations) to minimise the cross-entropy between
    the scaled logits and `labels`. `model` is only switched to eval mode.
    """
    model.eval()
    scaler = TemperatureScaler().to(device)
    lbfgs = torch.optim.LBFGS([scaler.temperature], lr=0.01, max_iter=50)

    def closure():
        # L-BFGS re-evaluates the objective through this closure.
        lbfgs.zero_grad()
        nll = F.cross_entropy(scaler(logits), labels)
        nll.backward()
        return nll

    lbfgs.step(closure)
    fitted_temperature = scaler.temperature.item()
    return fitted_temperature

def apply_temperature(logits, T):
    """Return `logits` divided by the temperature `T`."""
    scaled_logits = logits / T
    return scaled_logits

def evaluate_calibration(logits, labels):
    """Return the Brier score of the class-1 softmax probabilities vs `labels`."""
    positive_prob = torch.softmax(logits, dim=1)[:, 1]
    y_prob = positive_prob.cpu().numpy()
    y_true = labels.cpu().numpy()
    return brier_score_loss(y_true, y_prob)

print("\n--- Temperature Scaling Calibration ---")
for name, (Cls, ckpt) in models_info.items():
    # Rebuild the model and load its checkpoint, mapping checkpoint parameter
    # names onto the attribute names used by the model classes in this file.
    model = Cls().to(device)
    state = torch.load(ckpt, map_location=device, weights_only=True)
    for old_key, new_key in (
        ("fitz_emb.weight", "emb.weight"),
        ("classifier.3.weight", "classifier.2.weight"),
        ("classifier.3.bias", "classifier.2.bias"),
    ):
        if old_key in state:
            state[new_key] = state.pop(old_key)
    if name == "ResNet":
        state = {k.replace("resnet.", "res."): v for k, v in state.items()}
    model.load_state_dict(state, strict=True)
    model.eval()

    # Collect raw logits and labels for the whole loader (kept on `device`).
    logit_batches = []
    label_batches = []
    with torch.no_grad():
        for imgs, labs, fitz in loader:
            imgs = imgs.to(device)
            labs = labs.to(device)
            fitz = fitz.to(device)
            logit_batches.append(model(imgs, fitz))
            label_batches.append(labs)
    logits = torch.cat(logit_batches)
    labels = torch.cat(label_batches)

    # NOTE(review): the temperature is fitted and scored on the same `loader`
    # data here — confirm this is intended.
    brier_before = evaluate_calibration(logits, labels)
    T = temperature_scale(model, logits, labels)
    logits_scaled = apply_temperature(logits, T)
    brier_after = evaluate_calibration(logits_scaled, labels)

    print(f"\n{name} - Optimal Temp: {T:.4f}")
    print(f"Brier Score Before: {brier_before:.4f}")
    print(f"Brier Score After : {brier_after:.4f}")

    # Reliability data for the temperature-scaled probabilities (10 bins).
    probs_scaled = torch.softmax(logits_scaled, dim=1)[:, 1].cpu().numpy()
    labels_np = labels.cpu().numpy()
    prob_true, prob_pred = calibration_curve(labels_np, probs_scaled, n_bins=10)

    # One calibration figure per model, saved to disk and then shown.
    cal_fig, cal_ax = plt.subplots(figsize=(6, 5))
    cal_ax.plot(prob_pred, prob_true, label=f"{name} (T={T:.2f})", lw=2)
    cal_ax.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")
    cal_ax.set_xlabel("Mean Predicted Probability")
    cal_ax.set_ylabel("Empirical Accuracy")
    cal_ax.set_title(f"Calibration Curve After Temperature Scaling – {name}")
    cal_ax.grid(True, linestyle="--", alpha=0.6)
    cal_ax.legend()
    cal_fig.tight_layout()
    cal_fig.savefig(f"{name.lower()}_calibration_temp_scaled.png", dpi=300)
    plt.show()

###############################################
# Fairness Metrics (DP & EO)
###############################################
def fairness_metrics(labels, preds, groups):
    """
    Compute group-fairness gaps for binary predictions.

    Args:
        labels: True binary labels (1 = positive), one per sample.
        preds: Predicted binary labels, one per sample.
        groups: Group id per sample (e.g. Fitzpatrick type).

    Returns:
        Tuple `(dp_diff, eo_diff)`:
        - `dp_diff`: largest minus smallest positive-prediction rate across
          groups (demographic parity difference).
        - `eo_diff`: largest minus smallest true-positive rate across the
          groups that contain at least one positive sample. Only the TPR is
          compared here; false-positive rates are not part of this value.
        A gap is NaN when it is undefined (no samples, or no group with
        positive samples) instead of raising on an empty sequence.
    """
    labels, preds, groups = np.asarray(labels), np.asarray(preds), np.asarray(groups)
    groups_unique = np.unique(groups)

    def _gap(rates):
        # max()/min() raise on an empty list, so report "undefined" as NaN.
        if not rates:
            return float("nan")
        return np.abs(max(rates) - min(rates))

    # Demographic Parity (DP): share of positive predictions per group.
    positive_rates = [preds[groups == g].mean() for g in groups_unique]
    dp_diff = _gap(positive_rates)

    # Equalized Odds (EO), TPR component: recall per group, skipping groups
    # without positive samples (their TPR is undefined).
    tpr_rates = []
    for g in groups_unique:
        mask = (groups == g) & (labels == 1)
        if mask.any():
            tpr_rates.append(preds[mask].mean())
    eo_diff = _gap(tpr_rates)

    return dp_diff, eo_diff

print("\n--- Fairness Metrics ---")
# Group ids come straight from the dataset table and are paired positionally
# with the stored predictions, so they are computed once and their length is
# checked against the labels below.
# NOTE(review): this also assumes `loader` iterated `dataset` in order
# (no shuffling, no subset) — confirm against the loader definition.
fitz_groups = np.clip(dataset.data["fitzpatrick"].values, 0, 6)
for name in models_info:
    y_true, y_pred, _ = arrays[name]
    if len(fitz_groups) != len(y_true):
        raise ValueError(
            f"{name}: {len(y_true)} predictions but {len(fitz_groups)} "
            "Fitzpatrick entries; groups cannot be aligned with predictions."
        )
    dp, eo = fairness_metrics(y_true, y_pred, fitz_groups)
    print(f"\n{name} - Demographic Parity Diff: {dp:.4f}")
    print(f"{name} - Equalized Odds Diff: {eo:.4f}")