In [2]:
from google.colab import drive
drive.mount('/content/drive')
!unzip "/content/drive/MyDrive/Image Directory Zips/aligned_faces_and_embeddings_v3.zip"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/12860.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/83711.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/69948.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/45552.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/46132.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/41974.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/42904.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/77438.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/33937.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/23595.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/70009.jpg  
  inflating: content/FairFace/aligned_faces_v3/Middle Eastern/64940.jpg  
  inflating: content/FairFace/aligned_faces_v3/

In [3]:
# @title Imports

# ===== Standard Library =====
import copy
import gc  # Only if used later
import math
import os
import random
import sys
import time
import logging
import traceback  # Only if used later
from collections import Counter, defaultdict
from contextlib import nullcontext
from typing import Dict, List, Optional
# ===== Third-party =====
import cv2
import numpy as np
import pandas as pd
import seaborn as sns  # Remove if not used in the notebook
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, UnidentifiedImageError
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import (accuracy_score, auc, classification_report, confusion_matrix,
                             f1_score, precision_score, recall_score, roc_curve)
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder, label_binarize
from sklearn.utils.class_weight import compute_class_weight
from torch.cuda.amp import autocast, GradScaler  # Consolidated amp import
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.optim import Optimizer
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from skimage import color # Added for color space conversions

# Optional: Add specific imports for plotting
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick  # For advanced tick formatting in plots

In [4]:
 # @title Configurations
# --- Environment / device ---
def get_device():
    try:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
    except Exception:
        pass
    return torch.device("cpu")

DEVICE = get_device()

# --- Environment flags ---
ON_COLAB = os.path.exists("/content")

# --- Paths per environment ---
if ON_COLAB:
    BASE_DATASET_DIR = "/content/content/FairFace/aligned_faces_v3"
    RESULTS_DIR = '/content/drive/MyDrive/personal_research_project/results_dir'
    EMBED_NPY = "/content/drive/MyDrive/Image Directory Zips/fairface_embeddings_v3.pt"
else:
    BASE_DATASET_DIR = "/Users/ticaurisstokes/Desktop/research utilities/filtered_fairface_by_race/"
    RESULTS_DIR = '/Users/ticaurisstokes/Desktop/research utilities/save_dir'
    EMBED_NPY = "/Users/ticaurisstokes/Desktop/research utilities/fairface_embeddings_v3.pt"

In [5]:
# @title Load Images

def load_img_from_dir(dir_path, max_images_per_class=None):
    image_paths, labels = [], []

    for class_name in sorted(os.listdir(dir_path)):
        class_path = os.path.join(dir_path, class_name)
        if not os.path.isdir(class_path):
            continue

        count = 0
        for img_name in os.listdir(class_path):
            if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue

            img_path = os.path.join(class_path, img_name)
            if not os.path.isfile(img_path):
                continue

            image_paths.append(img_path)
            labels.append(class_name)
            count += 1

            if max_images_per_class and count >= max_images_per_class:
                break

    print(f"✅ Loaded {len(image_paths)} pre-filtered images.")
    return image_paths, labels

In [6]:
# @title Transformers

# Refined transformations for each type of augmentation.
def base_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

def aggressive_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.RandomAffine(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2)
    ])

def specific_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6),
            transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomPerspective(distortion_scale=0.5, p=0.5)
        ], p=0.9),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.3)
    ])

# Define the Class-Based Augmentation Schedule with more dynamic decisions
class ClassBasedAugmentationSchedule:
    def __init__(self, class_policy_map=None, num_classes=0):
        self.class_policy_map = class_policy_map or {}
        self.class_performance = {i: 1.0 for i in range(num_classes)}  # Initialize performance

    def update_performance(self, y_true, y_pred):
        recalls = recall_score(y_true, y_pred, average=None, labels=np.unique(y_true), zero_division=0)
        recall_map = dict(zip(np.unique(y_true), recalls))

        for class_idx in self.class_performance:
            if class_idx in recall_map:
                self.class_performance[class_idx] = recall_map[class_idx]
        print(f"📊 Updated Augmentation Performance Metrics: {self.class_performance}")

    def get_transform(self, epoch, class_label):
        class_label = int(class_label)

        # Use base policy for warmup (epochs < 5)
        if epoch < 5:
            return "base_transform"

        # Get base policy for class
        base_policy = self.class_policy_map.get(class_label, "base_transform")

        # Dynamic augmentation strategy based on performance
        if self.class_performance.get(class_label, 1.0) < 0.5:
            print(f"Applying more aggressive transform for class {class_label} due to low recall.")
            return "aggressive_transform"
        elif self.class_performance.get(class_label, 1.0) < 0.75:
            return "specific_transform"
        return base_policy  # Default case (e.g., class-specific or basic transform)


In [7]:
# @title Custom Dataset

class CustomDataset(Dataset):
    """
    An optimized custom dataset class for image classification.
    - Applies augmentations using the high-performance Albumentations library.
    - Pre-processes and stores metadata in the __init__ method to avoid
      re-reading data and to make epoch starts much faster.
    - Handles training and validation/testing modes cleanly based on whether a
      dynamic 'class_policy_map' or a static 'transform_name' is provided.
    """
    def __init__(
        self,
        image_paths,
        labels,
        metadata=None, # This 'metadata' is a list of color_metrics dicts from z_train
        transform_name=None, # For validation/test, pass the name of the transform, e.g., "standard_transform"
        include_skin_vec=False,
        triplet_embedding_dict=None,
        class_policy_map=None, # For training, pass the policy map
        num_classes=None,
    ):
        self.triplet_embedding_dict = triplet_embedding_dict or {}
        self.metadata_input = metadata if metadata is not None else [] # Renamed to avoid conflict
        self.is_train = class_policy_map is not None # True if a policy map is provided (training mode)

        # Initialize epoch, will be updated by set_epoch for dynamic transforms
        self.epoch = 0

        if self.is_train:
            self.aug_schedule = ClassBasedAugmentationSchedule(class_policy_map, num_classes)
        # Store the name for static transform (validation/test mode)
        self.transform_name_for_val_test = transform_name

        # Transformation map (from `VsCu1UxyW4p4`) - ensures these are defined globally
        self.transform_map = {
            "base_transform": base_transform(),
            "standard_transform": base_transform(),
            "aggressive_transform": aggressive_transform(),
            "specific_transform": specific_transform(),
        }

        # --- Pre-processing Loop ---
        # This loop runs once to gather all necessary data, making __getitem__ much faster.
        self.data = []
        print("Pre-processing and caching dataset metadata...")
        for i, (img_path, label) in enumerate(tqdm(zip(image_paths, labels), total=len(image_paths))):

            embedding = self.triplet_embedding_dict.get(os.path.basename(img_path).lower())
            if embedding is None:
                # Use a zero tensor if an embedding is missing
                embedding = np.zeros(512, dtype=np.float32)

            # Retrieve the raw metadata (color_metrics dict) for this item
            raw_meta_for_item = self.metadata_input[i] if i < len(self.metadata_input) else {}

            skin_vec = np.zeros(12, dtype=np.float32)
            if include_skin_vec and raw_meta_for_item:
                try:
                    # Assumes raw_meta_for_item is a dictionary of metrics
                    skin_vec = build_skin_vector(raw_meta_for_item)
                except Exception as e:
                    print(f"Warning: Could not build skin vector for {img_path}. Using zeros. Error: {e}")

            # Store all necessary components, including the new ones
            self.data.append({
                "path": img_path,
                "label": label,
                "skin_vec": skin_vec,
                "embedding": embedding,
                "raw_metadata": raw_meta_for_item, # Store the full color_metrics dict
                "mst_bin": raw_meta_for_item.get("MST", 0), # Extract MST bin
                "skin_group": bin_mst_to_skin_group(raw_meta_for_item.get("MST", 0)) # Extract skin group
            })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Retrieve the pre-processed data for the item
        item_data = self.data[idx]
        img_path = item_data["path"]

        try:
            img = Image.open(img_path).convert("RGB") # PIL Image

            # 2. Determine and apply the correct transform
            if self.is_train:
                # Dynamic transform for training
                transform_key = self.aug_schedule.get_transform(self.epoch, item_data["label"])
                current_transform = self.transform_map.get(transform_key, base_transform())
            else:
                # Static transform for validation/test
                transform_key = self.transform_name_for_val_test if self.transform_name_for_val_test else "base_transform"
                current_transform = self.transform_map.get(transform_key, base_transform())

            img_tensor = current_transform(img) # Apply torchvision transform to PIL Image

        except Exception as e:
            print(f"⚠️  ERROR: Failed to load or process image {img_path}: {e}")
            # On error, return a placeholder tensor to prevent the training loop from crashing
            return (
                torch.zeros(3, 224, 224),
                torch.tensor(item_data["label"]),
                torch.zeros(12),
                torch.zeros(512),
                torch.tensor(item_data["mst_bin"], dtype=torch.int64),
                item_data["skin_group"],
                item_data["raw_metadata"]
            )

        # Convert numpy arrays to tensors
        skin_vec_tensor = torch.tensor(item_data["skin_vec"], dtype=torch.float32)
        embedding = item_data["embedding"]
        embedding_tensor = embedding if isinstance(embedding, torch.Tensor) else torch.tensor(embedding, dtype=torch.float32)

        # Retrieve the additional items to return
        mst_bin_val = torch.tensor(item_data["mst_bin"], dtype=torch.int64)
        skin_group_val = item_data["skin_group"]
        raw_meta_dict = item_data["raw_metadata"]

        # Return 7 items as expected by evaluate_model
        return img_tensor, item_data["label"], skin_vec_tensor, embedding_tensor, mst_bin_val, skin_group_val, raw_meta_dict

    def set_epoch(self, epoch):
        """
        Sets the current epoch for the dataset, used by the augmentation schedule.
        """
        self.epoch = epoch

In [8]:
# @title Color Calculation

def bin_mst_to_skin_group(mst_value: int) -> str:
    return f"MST_{mst_value}" if 1 <= mst_value <= 10 else "unknown"

def normalize_color_features(L, h):
    L_scaled = L / 100.0
    h_scaled = h / 360.0
    return L_scaled, h_scaled

def extract_color_metrics(image_path):
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        return None, None

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    image_rgb = image_rgb / 255.0

    lab = color.rgb2lab(image_rgb)
    l = lab[:, :, 0]
    a = lab[:, :, 1]
    b = lab[:, :, 2]
    h = np.degrees(np.arctan2(b, a)) % 360

    skin_pixels = l > 0
    avg_L = np.mean(l[skin_pixels])
    avg_h = np.mean(h[skin_pixels])

    return avg_L, avg_h

def estimate_mst_from_ita(ita_value):
    if ita_value > 55: return 1
    elif ita_value > 41: return 2
    elif ita_value > 28: return 3
    elif ita_value > 19: return 4
    elif ita_value > 10: return 5
    elif ita_value > 0: return 6
    elif ita_value > -10: return 7
    elif ita_value > -20: return 8
    elif ita_value > -30: return 9
    else: return 10

def extract_color_metrics_and_estimate_mst(image_path):
    avg_L, avg_h = extract_color_metrics(image_path)
    if avg_L is None or avg_h is None:
        return None

    ita = np.degrees(np.arctan((avg_L - 50) / avg_h))
    mst_bin = estimate_mst_from_ita(ita)

    return {
        "L": avg_L,
        "h": avg_h,
        "MST": mst_bin
    }

def normalize_ita_hue(ita, hue):
    ita_scaled = (ita + 60) / 120
    hue_scaled = hue / 360.0
    return ita_scaled, hue_scaled

def one_hot_encode_mst(mst_bin, num_classes=10):
    one_hot = np.zeros(num_classes)
    if 1 <= mst_bin <= num_classes:
        one_hot[mst_bin - 1] = 1.0
    return one_hot

def build_skin_vector(color_metrics):
    if color_metrics is None:
        return None

    L = color_metrics["L"]
    h = color_metrics["h"]
    ita = np.degrees(np.arctan((L - 50) / h))
    ita_scaled, hue_scaled = normalize_ita_hue(ita, h)

    mst_onehot = one_hot_encode_mst(color_metrics["MST"])
    return np.array([ita_scaled, hue_scaled], dtype=np.float32).tolist() + mst_onehot.tolist()

In [9]:
# @title Oversampling and Balancing

# Dynamic oversampling target counts
def calculate_dynamic_target_counts(y_encoded, z, oversample_percentage=1.2):
    combo_counts = defaultdict(int)
    for label_encoded, metadata in zip(y_encoded, z):
        mst_group = bin_mst_to_skin_group(metadata.get("MST")) # Use the same binning as balance_data_to_targets
        if mst_group != "unknown": # Only consider known MST groups
            combo_counts[(label_encoded, mst_group)] += 1

    if not combo_counts:
        return {} # Return empty if no valid combos found

    max_combo_count = max(combo_counts.values())

    dynamic_target_counts = {}
    for (label_encoded, mst_group), count in combo_counts.items():
        # Target all combos to be at least (max_combo_count * oversample_percentage)
        dynamic_target_counts[(label_encoded, mst_group)] = int(max_combo_count * oversample_percentage)

    return dynamic_target_counts

# Definition for balance_data_to_targets (copied for self-sufficiency)
def balance_data_to_targets(X, y, z, target_counts):
    print("\n⚖️ Balancing dataset to meet fairness targets...")

    # Print class distribution before oversampling
    print("\n📊 Class Distribution Before Oversampling:")
    pre_oversample_distribution = defaultdict(Counter)
    for label, metadata in zip(y, z):
        mst_value = metadata.get("MST")
        if mst_value is not None:
            pre_oversample_distribution[label][mst_value] += 1

    for label, mst_counts in sorted(pre_oversample_distribution.items()):
        print(f"\n--- Class: {label} ---")
        for mst_value, count in sorted(mst_counts.items()):
            print(f"  MST {mst_value}: {count} samples")

    grouped_indices = defaultdict(list)
    for i, (label, metadata) in enumerate(zip(y, z)):
        mst_group = bin_mst_to_skin_group(metadata.get("MST"))
        if mst_group != "unknown":
            grouped_indices[(label, mst_group)].append(i)

    balanced_indices = []
    for (label, mst_group), target_count in target_counts.items():
        available_indices = grouped_indices.get((label, mst_group), [])
        if not available_indices:
            continue

        if target_count > len(available_indices):
            chosen_indices = random.choices(available_indices, k=target_count)
        else:
            chosen_indices = random.sample(available_indices, k=target_count)

        balanced_indices.extend(chosen_indices)

    X_bal = [X[i] for i in balanced_indices]
    y_bal = [y[i] for i in balanced_indices]
    z_bal = [z[i] for i in balanced_indices]

    # Print class distribution after oversampling
    print("\n📊 Class Distribution After Oversampling:")
    post_oversample_distribution = defaultdict(Counter)
    for label, metadata in zip(y_bal, z_bal):
        mst_value = metadata.get("MST")
        if mst_value is not None:
            post_oversample_distribution[label][mst_value] += 1

    for label, mst_counts in sorted(post_oversample_distribution.items()):
        print(f"\n--- Class: {label} ---")
        for mst_value, count in sorted(mst_counts.items()):
            print(f"  MST {mst_value}: {count} samples")

    # Ensure that the dataset is not empty after balancing
    if len(X_bal) == 0 or len(y_bal) == 0 or len(z_bal) == 0:
        raise ValueError("Oversampled dataset is empty. Please check your oversampling logic.")

    return X_bal, y_bal, z_bal

In [10]:
# @title Grad-Cam

# Utility to clear all forward hooks to prevent retain_grad issues
def clear_all_forward_hooks(model: torch.nn.Module):
    """Clear all forward hooks to prevent memory leaks and prevent retain_grad."""
    if hasattr(model, "_forward_hooks"):
        model._forward_hooks.clear()
    for m in model.modules():
        if hasattr(m, "_forward_hooks"):
            m._forward_hooks.clear()

class SafeGradCAM:
    """Safe Grad-CAM++ implementation that uses hooks to calculate and visualize gradients and activations for Grad-CAM++ heatmaps."""
    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self._handles = []

    def _fw_hook(self, module, inputs, output):
        """Forward hook to capture activations."""
        self.activations = output
        if isinstance(output, torch.Tensor) and output.requires_grad:
            output.retain_grad()  # Retain gradient if required

    def _bw_hook(self, module, grad_input, grad_output):
        """Backward hook to capture gradients."""
        self.gradients = grad_output[0]  # Gradients w.r.t the target layer

    def _register(self):
        """Register hooks to capture gradients and activations."""
        self._handles.append(self.target_layer.register_forward_hook(self._fw_hook))
        self._handles.append(self.target_layer.register_full_backward_hook(self._bw_hook))

    def _remove(self):
        """Remove all registered hooks."""
        for h in self._handles:
            try:
                h.remove()
            except Exception:
                pass
        self._handles = []

    @torch.no_grad()
    def _overlay(self, heatmap: np.ndarray, image_path: str):
        """Overlay the Grad-CAM++ heatmap onto the original image."""
        src = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if src is None:
            return None
        hm = (255 * (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)).astype(np.uint8)
        hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
        out = cv2.addWeighted(src, 0.5, hm, 0.5, 0)
        return out

    def generate(self, input_tensor: torch.Tensor, skin_vec: torch.Tensor, target_class: int = None):
        """Generate Grad-CAM++ heatmap for a single image."""
        self.model.eval()
        self._register()

        try:
            with torch.enable_grad():
                input_tensor = input_tensor.requires_grad_(True)
                out = self.model(input_tensor, skin_vec)
                if target_class is None:
                    target_class = int(out.argmax(dim=1).item())
                loss = out[0, target_class]
                self.model.zero_grad(set_to_none=True)
                loss.backward()

                A = self.activations   # [1,C,H,W]
                G = self.gradients     # [1,C,H,W]
                assert A is not None and G is not None, "Hooks did not capture activations/gradients."

                posG = F.relu(G)
                alpha_num = G.pow(2)
                alpha_den = 2 * alpha_num + (A * G.pow(3)).sum(dim=(2,3), keepdim=True)
                alpha_den = torch.where(alpha_den != 0, alpha_den, torch.ones_like(alpha_den))
                alpha = alpha_num / alpha_den
                weights = (alpha * posG).sum(dim=(2,3))   # [1,C]
                cam = (weights.unsqueeze(-1).unsqueeze(-1) * A).sum(dim=1)[0]  # [H,W]
                cam = F.relu(cam)
                cam = cam / (cam.max() + 1e-6)
                return cam.detach().cpu().numpy()
        finally:
            self._remove()

class GradCAM:
    """Grad-CAM++ style heatmap for a specific target layer."""
    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module):
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None   # A
        self.gradients = None     # dY/dA

        # keep handles so you can remove later if needed
        self._h_fwd = target_layer.register_forward_hook(self._fwd_hook)
        self._h_bwd = target_layer.register_full_backward_hook(self._bwd_hook)

    def _fwd_hook(self, module, inputs, output):
        # DO NOT detach; we need the graph. Also, retain grad on activations.
        self.activations = output
        if isinstance(self.activations, torch.Tensor):
            self.activations.retain_grad()

    def _bwd_hook(self, module, grad_input, grad_output):
        # grad_output is a tuple; take the gradient w.r.t. the layer output
        self.gradients = grad_output[0]

    @torch.enable_grad()  # make sure grads are enabled even during eval sections
    def generate(
        self,
        input_tensor: torch.Tensor,
        skin_vec: torch.Tensor,
        target_class: Optional[int] = None,
        use_amp: bool = False,
        device_type: str = "cuda",
    ) -> np.ndarray:
        """Returns a (H, W) numpy heatmap in [0,1]."""
        self.model.zero_grad(set_to_none=True)

        if use_amp:
            from torch.amp import autocast
            ctx = autocast(device_type=device_type, dtype=torch.float16)
        else:
            from contextlib import nullcontext
            ctx = nullcontext()

        with ctx:
            output = self.model(input_tensor, skin_vec)  # [1, C]
            if target_class is None:
                target_class = int(output.argmax(dim=1).item())
            score = output[0, target_class]

        # Backprop dY/dA
        score.backward(retain_graph=False)

        A = self.activations            # [B, K, H, W]
        dYdA = self.gradients           # [B, K, H, W]
        assert A is not None and dYdA is not None, "Hooks did not capture activations/gradients."

        # Grad-CAM++ weights
        dYdA_pos = F.relu(dYdA)
        alpha_num = dYdA.pow(2)
        alpha_den = 2 * alpha_num + (A * dYdA.pow(3)).sum(dim=(2, 3), keepdim=True)
        alpha_den = torch.where(alpha_den != 0, alpha_den, torch.ones_like(alpha_den))
        alpha = alpha_num / alpha_den
        weights = (alpha * dYdA_pos).sum(dim=(2, 3))  # [B, K]

        # Weighted sum of activation maps (use sample 0)
        A0 = A[0]                     # [K, H, W]
        w0 = weights[0].view(-1, 1, 1)
        heatmap = (w0 * A0).sum(dim=0)  # [H, W]

        heatmap = F.relu(heatmap)
        heatmap = heatmap / (heatmap.max() + 1e-6)
        return heatmap.detach().cpu().numpy()

def get_gradcam_layer(model, model_name, use_gradcam=False):
    """Heuristically find the last convolutional layer for Grad-CAM."""
    if not use_gradcam:
        print("Skipping Grad-CAM for this run.")
        return None  # Simply return None if Grad-CAM is not needed

    # Existing Grad-CAM logic for model with layers (EfficientNet, etc.)
    if hasattr(model, "get_gradcam_target_layer"):
        return model.get_gradcam_target_layer()

    # Handle EfficientNet models specifically
    if isinstance(model, EfficientNetWithAttention):
        # EfficientNet models might have features instead of conv_head
        if hasattr(model, "conv_head"):
            return model.conv_head  # In case conv_head exists for Grad-CAM
        elif hasattr(model, "features"):
            # Handle EfficientNet's features (the last layer in features)
            return model.features[-1]  # EfficientNet models usually have the last convolutional layer in features
        else:
            raise ValueError(f"EfficientNet model '{model_name}' does not have a conv_head or features layer for Grad-CAM.")

    # EfficientNet model family check (in case EfficientNet isn't using EfficientNetWithAttention)
    if "efficientnet" in model_name.lower():
        if hasattr(model, "features"):
            return model.features[-1]  # The last layer in EfficientNet is usually in 'features'
        else:
            raise ValueError(f"EfficientNet model '{model_name}' does not have the expected layer for Grad-CAM.")

    # Define layers for other model families
    layers_map = {
        "resnet": lambda model: model.base.layer4,  # ResNet: The last block of layer4
        "vgg": lambda model: model.base.features[-1],  # VGG: The last feature layer
        "mobilenet": lambda model: model.base.features[-1],  # MobileNet: The last feature layer
        "inception_v3": lambda model: model.base.Mixed_7c,  # Inception V3: Last mixed block
    }

    # Check which model family we are dealing with
    for model_prefix, get_layer_func in layers_map.items():
        if model_name.lower().startswith(model_prefix):
            try:
                return get_layer_func(model)
            except AttributeError as e:
                raise ValueError(f"Model '{model_name}' does not have the expected attribute for Grad-CAM: {e}")

    # If model name is not recognized
    raise ValueError(f"Unsupported model for Grad-CAM: {model_name}")


In [11]:
# @title Multi Layer Perceptron

class Skin_Multi_Layer_Perceptron(nn.Module):
    def __init__(self, input_dim=12, hidden_dim1=32, hidden_dim2=16, output_dim=8):
        super(Skin_Multi_Layer_Perceptron, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.BatchNorm1d(hidden_dim1),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.BatchNorm1d(hidden_dim2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim2, output_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.mlp(x)

In [12]:
# @title Two Layer Classifier Head

class TwoLayerClassifierHead(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, output_dim=7):  # or 4 for FairFace4
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.head(x)

In [13]:
# @title CBAM

class CBAM(nn.Module):
    def __init__(self, channels, reduction=16, kernel_size=7, use_film=False, film_in_dim=12):
        super().__init__()
        assert isinstance(channels, int) and channels > 0, \
            f"❌ CBAM init error: channels must be int>0, got {channels}"

        self.channels = channels
        self.use_film = use_film

        # Channel Attention
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
            nn.Sigmoid()
        )

        # Spatial Attention
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.Sigmoid()
        )

        # Optional FiLM
        if self.use_film:
            self.film = FiLM(film_in_dim, channels)

    def forward(self, x, skin_vec=None):
        """
        Forward pass through the CBAM block.

        Parameters:
        - x (torch.Tensor): Input tensor of shape [B, C, H, W].
        - skin_vec (torch.Tensor, optional): Skin tone vector, used only if `use_film=True`.

        Returns:
        - out (torch.Tensor): Attention-modulated output tensor.
        """
        # === Debug: Check shapes early ===
        self._check_input(x, skin_vec)

        # === Channel Attention ===
        ca = self.channel_attn(x)  # [B, C, 1, 1]
        ca = ca * x  # Apply channel attention

        # === Optional FiLM (Feature-wise Linear Modulation) ===
        if self.use_film:
            ca = self.film(ca, skin_vec)

        # === Spatial Attention ===
        sa = self.spatial_attention(ca)  # Apply spatial attention

        out = sa * ca  # Apply spatial attention to the channel attention output
        return out

    def _check_input(self, x, skin_vec):
        """Helper method to check input tensor shapes and raise informative errors."""
        if x.dim() != 4:
            raise ValueError(f"❌ Expected 4D input [B, C, H, W], got {x.shape}")
        B, C, H, W = x.shape
        if C != self.channels:
            raise ValueError(f"❌ Channel mismatch: got {C}, expected {self.channels}")

        if self.use_film:
            if skin_vec is None:
                raise ValueError("❌ use_film=True but skin_vec is None")
            if skin_vec.shape[0] != B:
                raise ValueError(f"❌ Batch mismatch: skin_vec batch {skin_vec.shape[0]} vs input {B}")

    def spatial_attention(self, x):
        """Apply spatial attention using max and average pooling."""
        avg_out = torch.mean(x, dim=1, keepdim=True)  # [B, 1, H, W]
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # [B, 1, H, W]
        sa_input = torch.cat([avg_out, max_out], dim=1)  # [B, 2, H, W]
        sa = self.spatial_attn(sa_input)  # [B, 1, H, W]
        return sa

In [14]:
# @title FiLM
class FiLM(nn.Module):
    def __init__(self, in_features, feature_map_channels):
        """
        Feature-wise Linear Modulation (FiLM) Layer.

        Args:
            in_features (int): Input size of conditioning vector (e.g., 12)
            feature_map_channels (int): Number of channels in the input feature map
        """
        super().__init__()

        # Modify in_features to accept 12 instead of 10 for skin vector
        self.gamma_fc = nn.Linear(in_features, feature_map_channels)
        self.beta_fc = nn.Linear(in_features, feature_map_channels)

        # Initialize gamma and beta
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Initialize gamma and beta to specific values."""
        nn.init.constant_(self.gamma_fc.weight, 1)  # Initialize gamma weights to 1
        nn.init.constant_(self.beta_fc.weight, 0)   # Initialize beta weights to 0
        nn.init.constant_(self.gamma_fc.bias, 0)    # Initialize gamma bias to 0
        nn.init.constant_(self.beta_fc.bias, 0)     # Initialize beta bias to 0

    def forward(self, x, cond):
        """
        Args:
            x (torch.Tensor): Feature map tensor of shape [B, C, H, W]
            cond (torch.Tensor): Conditioning vector (e.g., skin vector) of shape [B, in_features]

        Returns:
            torch.Tensor: FiLM-modulated feature map of shape [B, C, H, W]
        """
        # Check dimensions
        B, C, H, W = x.size()
        assert cond.size(0) == B, f"Batch size mismatch: x ({B}) vs cond ({cond.size(0)})"
        assert cond.size(1) == self.gamma_fc.in_features, \
            f"Feature size mismatch: cond ({cond.size(1)}) vs in_features ({self.gamma_fc.in_features})"

        # Get the gamma and beta from the conditioning vector
        gamma = self.gamma_fc(cond).view(B, C, 1, 1)  # [B, C, 1, 1]
        beta = self.beta_fc(cond).view(B, C, 1, 1)    # [B, C, 1, 1]

        # FiLM modulation: scale and shift the feature map
        return gamma * x + beta

In [15]:
# @title ResNet

class ResNetWithAttention(nn.Module):
    def __init__(
        self,
        num_classes,
        backbone_name="resnet152d",
        attention_type="none",
        drop_path_rate=0.2,
        dropout_rate=0.6,
        use_film_before=False,
        use_film_in_cbam=False,
        use_triplet_embedding=False,
        triplet_embedding_dim=512,
        include_skin_vec=True,
        fusion_mode="concat",
        fusion_hidden_dim=128,
    ):
        super().__init__()

        self.use_film_before = use_film_before
        self.use_film_in_cbam = use_film_in_cbam
        self.use_triplet_embedding = use_triplet_embedding
        self.triplet_embedding_dim = triplet_embedding_dim
        self.include_skin_vec = include_skin_vec
        self.fusion_mode = fusion_mode
        self.fusion_hidden_dim = fusion_hidden_dim

        self.base = timm.create_model(backbone_name, pretrained=True, num_classes=0)
        C = self.base.num_features
        self._feat_dim = C

        # ✅ FiLM now uses 12-dim conditioning
        if self.use_film_before:
            self.film = FiLM(in_features=12, feature_map_channels=C)

        # ✅ CBAM now uses 12-dim cond
        if attention_type == "cbam":
            self.attn = CBAM(C, use_film=self.use_film_in_cbam, film_in_dim=12)
        elif attention_type == "self":
            self.attn = SelfAttentionBlock(C)
        else:
            self.attn = nn.Identity()

        # ✅ Skin MLP now 12-dim
        self.skin_mlp = Skin_Multi_Layer_Perceptron(input_dim=12)

        # Fusion + classifier logic same pattern as above if you're using fusion_mode
        # (if you also have fusion here, apply the same "final_in_dim then classifier" pattern)


        if fusion_mode in ["mlp", "gated"]:
            self.image_proj = nn.Linear(C, fusion_hidden_dim)
            self.skin_proj = nn.Linear(8, fusion_hidden_dim)
            if self.use_triplet_embedding:
                self.triplet_proj = nn.Linear(triplet_embedding_dim, fusion_hidden_dim)

            if fusion_mode == "gated":
                gate_input_dim = C + 8 + (triplet_embedding_dim if self.use_triplet_embedding else 0)
                self.gate = nn.Sequential(
                    nn.Linear(gate_input_dim, 3),
                    nn.Softmax(dim=1)
                )
            final_in_dim = fusion_hidden_dim
        else:
            final_in_dim = C + 8 + (triplet_embedding_dim if self.use_triplet_embedding else 0)

        self.expected_final_dim = final_in_dim
        self.dropout = nn.Dropout(p=dropout_rate)
        self.classifier = TwoLayerClassifierHead(input_dim=final_in_dim, output_dim=num_classes)

    def forward_features(self, x, skin_vec):
        x = self.base.forward_features(x)
        if self.use_film_before:
            x = self.film(x, skin_vec)
        x = self.attn(x, skin_vec) if isinstance(self.attn, CBAM) else self.attn(x)
        return x

    def forward(self, x, skin_vec=None, triplet_embedding=None, return_features=False):
        B = x.size(0)
        if self.include_skin_vec and skin_vec is None:
            skin_vec = torch.zeros((B, 10), device=x.device)

        feat = self.forward_features(x, skin_vec)
        feat = F.adaptive_avg_pool2d(feat, 1).view(B, -1)
        features = feat

        skin_feat = self.skin_mlp(skin_vec)

        if self.use_triplet_embedding:
            if triplet_embedding is None:
                triplet_embedding = torch.zeros((B, self.triplet_embedding_dim), device=x.device, dtype=feat.dtype)

        if self.fusion_mode == "concat":
            parts = [feat, skin_feat] + ([triplet_embedding] if self.use_triplet_embedding else [])
            final_feat = torch.cat(parts, dim=1)

        elif self.fusion_mode == "mlp":
            feat_proj = self.image_proj(feat)
            skin_proj = self.skin_proj(skin_feat)
            if self.use_triplet_embedding:
                triplet_proj = self.triplet_proj(triplet_embedding)
                final_feat = feat_proj + skin_proj + triplet_proj
            else:
                final_feat = feat_proj + skin_proj

        elif self.fusion_mode == "gated":
            gate_in = torch.cat([feat, skin_feat] + ([triplet_embedding] if self.use_triplet_embedding else []), dim=1)
            weights = self.gate(gate_in)
            feat_proj = self.image_proj(feat)
            skin_proj = self.skin_proj(skin_feat)
            if self.use_triplet_embedding:
                triplet_proj = self.triplet_proj(triplet_embedding)
                final_feat = weights[:, 0:1] * feat_proj + weights[:, 1:2] * skin_proj + weights[:, 2:3] * triplet_proj
            else:
                final_feat = weights[:, 0:1] * feat_proj + weights[:, 1:2] * skin_proj
        else:
            raise ValueError(f"Unknown fusion mode: {self.fusion_mode}")
        #print(f"Fusion: {self.fusion_mode}, Final feature shape: {final_feat.shape}")
        final_feat = self.dropout(final_feat)
        logits = self.classifier(final_feat)

        return (logits, features) if return_features else logits

In [16]:
# @title EfficientNet

class EfficientNetWithAttention(nn.Module):
    def __init__(
        self,
        num_classes,
        attention_type="none",
        pretrained=True,
        use_film=False,
        use_film_before=False,
        use_film_in_cbam=False,
        use_triplet_embedding=False,
        triplet_embedding_dim=512,
        include_skin_vec=True,
        efficientnet_variant="efficientnet_b0",
        dropout_rate=0.6,  # <-- This argument
        fusion_mode="concat",  # "concat", "mlp", or "gated"
        fusion_hidden_dim=128,
        device=None,
    ):
        super().__init__()

        # Device
        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        print(f"Using device: {self.device}")

        self.use_film_before = use_film_before
        self.use_film_in_cbam = use_film_in_cbam
        self.use_triplet_embedding = use_triplet_embedding
        self.triplet_embedding_dim = triplet_embedding_dim
        self.include_skin_vec = include_skin_vec
        self.fusion_mode = fusion_mode
        self.fusion_hidden_dim = fusion_hidden_dim

        # Backbone
        self.base = timm.create_model(
            efficientnet_variant, pretrained=pretrained, num_classes=0
        )
        C = self.base.num_features
        print(f"Feature dimension (C): {C}")
        if not isinstance(C, int) or C <= 0:
            raise ValueError(f"Invalid feature dimension C: {C}")
        self._feat_dim = C

        # FiLM before attention (skin_vec is 12-dim: ITA, hue, 10 MST one-hot)
        if self.use_film_before:
            self.film = FiLM(in_features=12, feature_map_channels=C)
        else:
            self.film = nn.Identity() # Added for consistency

        # Attention
        if attention_type == "self":
            self.attn = SelfAttentionBlock(C)
        elif attention_type == "cbam":
            self.attn = CBAM(C, use_film=self.use_film_in_cbam, film_in_dim=12)
        else:
            self.attn = nn.Identity()

        # Skin MLP: 12 -> 8
        self.skin_mlp = Skin_Multi_Layer_Perceptron(input_dim=12)  # -> 8D

        # 🔑 FUSION BLOCK — this decides final_in_dim
        if self.fusion_mode in ["mlp", "gated"]:
            print(f"Fusion Dim: {fusion_hidden_dim}")
            if not isinstance(fusion_hidden_dim, int) or fusion_hidden_dim <= 0:
                raise ValueError(
                    f"Invalid fusion_hidden_dim: {fusion_hidden_dim}. It should be a positive integer."
                )

            self.image_proj = nn.Linear(C, fusion_hidden_dim).to(
                self.device, dtype=torch.float32
            )
            self.skin_proj = nn.Linear(8, fusion_hidden_dim).to(
                self.device, dtype=torch.float32
            )
            if self.use_triplet_embedding:
                self.triplet_proj = nn.Linear(
                    triplet_embedding_dim, fusion_hidden_dim
                ).to(self.device, dtype=torch.float32)

            if self.fusion_mode == "gated":
                gate_input_dim = (
                    C
                    + 8
                    + (triplet_embedding_dim if self.use_triplet_embedding else 0)
                )
                self.gate = nn.Sequential(
                    nn.Linear(gate_input_dim, 3),
                    nn.Softmax(dim=1),
                ).to(self.device, dtype=torch.float32)

            final_in_dim = fusion_hidden_dim
        else:
            # "concat" mode
            final_in_dim = C + 8 + (triplet_embedding_dim if use_triplet_embedding else 0)

        self.final_in_dim = final_in_dim

        self.dropout = nn.Dropout(p=dropout_rate) # Now uses the argument

        #CLASSIFIER CREATED AFTER fusion_dim is known
        self.classifier = TwoLayerClassifierHead(
            input_dim=self.final_in_dim, output_dim=num_classes
        )

    def forward_features(self, x, skin_vec):
        x = self.base.forward_features(x)
        # Apply film if it exists
        if hasattr(self, 'film'):
             x = self.film(x, skin_vec)
        x = self.attn(x, skin_vec) if isinstance(self.attn, CBAM) else self.attn(x)
        return x

    def forward(self, x, skin_vec=None, triplet_embedding=None, return_features=False):
        B = x.size(0)
        if self.include_skin_vec and skin_vec is None:
            skin_vec = torch.zeros((B, 12), device=x.device, dtype=x.dtype) # Use dtype of x

        feat = self.forward_features(x, skin_vec)
        feat = F.adaptive_avg_pool2d(feat, 1).view(B, -1)

        skin_feat = self.skin_mlp(skin_vec)

        if self.use_triplet_embedding:
            if triplet_embedding is None:
                triplet_embedding = torch.zeros(
                    (B, self.triplet_embedding_dim), device=x.device, dtype=feat.dtype # Use dtype of feat
                )
        else:
            triplet_embedding = None

        # Fusion logic based on fusion_mode
        if self.fusion_mode == "concat":
            parts = [feat, skin_feat] + (
                [triplet_embedding] if self.use_triplet_embedding else []
            )
            final_feat = torch.cat(parts, dim=1)
        elif self.fusion_mode == "mlp":
            feat_proj = self.image_proj(feat)
            skin_proj = self.skin_proj(skin_feat)
            if self.use_triplet_embedding:
                triplet_proj = self.triplet_proj(triplet_embedding)
                final_feat = feat_proj + skin_proj + triplet_proj
            else:
                final_feat = feat_proj + skin_proj
        elif self.fusion_mode == "gated":
            gate_in = torch.cat(
                [feat, skin_feat]
                + ([triplet_embedding] if self.use_triplet_embedding else []),
                dim=1,
            )
            weights = self.gate(gate_in)
            feat_proj = self.image_proj(feat)
            skin_proj = self.skin_proj(skin_feat)
            if self.use_triplet_embedding:
                triplet_proj = self.triplet_proj(triplet_embedding)
                final_feat = (
                    weights[:, 0:1] * feat_proj
                    + weights[:, 1:2] * skin_proj
                    + weights[:, 2:3] * triplet_proj
                )
            else:
                final_feat = (
                    weights[:, 0:1] * feat_proj + weights[:, 1:2] * skin_proj
                )
        else:
            raise ValueError(f"Unknown fusion mode: {self.fusion_mode}")
        final_feat = self.dropout(final_feat)
        logits = self.classifier(final_feat)

        return (logits, feat) if return_features else logits

In [17]:
# @title InceptionNet V3

class InceptionV3WithAttention(nn.Module):
    def __init__(
        self,
        num_classes,
        attention_type="none",
        pretrained=True,
        use_film_before=False,
        use_film_in_cbam=False,
        use_triplet_embedding=False,
        triplet_embedding_dim=512,
        include_skin_vec=True,
        fusion_mode="concat",   # "concat", "mlp", "gated"
        fusion_hidden_dim=128,
        dropout_rate=0.8,
        **kwargs,
    ):
        super().__init__()

        self.include_skin_vec = include_skin_vec
        self.use_triplet_embedding = use_triplet_embedding
        self.fusion_mode = fusion_mode
        self.fusion_hidden_dim = fusion_hidden_dim

        # Backbone
        self.base = timm.create_model("inception_v3", pretrained=pretrained, num_classes=0)
        feature_dim = self.base.num_features

        # FiLM
        self.film = FiLM(in_features=12, feature_map_channels=feature_dim) if use_film_before else nn.Identity()

        # Attention
        if attention_type == "cbam":
            self.attn = CBAM(channels=feature_dim, use_film=use_film_in_cbam, film_in_dim=12)
        else:
            self.attn = nn.Identity()

        # Skin MLP
        self.skin_mlp = Skin_Multi_Layer_Perceptron(input_dim=12)

        # 🔑 Fusion logic
        if self.fusion_mode in ["mlp", "gated"]:
            if not isinstance(fusion_hidden_dim, int) or fusion_hidden_dim <= 0:
                raise ValueError(f"Invalid fusion_hidden_dim: {fusion_hidden_dim}")

            self.image_proj = nn.Linear(feature_dim, fusion_hidden_dim)
            self.skin_proj = nn.Linear(8, fusion_hidden_dim)
            if self.use_triplet_embedding:
                self.triplet_proj = nn.Linear(triplet_embedding_dim, fusion_hidden_dim)

            if self.fusion_mode == "gated":
                gate_input_dim = (
                    feature_dim
                    + 8
                    + (triplet_embedding_dim if self.use_triplet_embedding else 0)
                )
                self.gate = nn.Sequential(
                    nn.Linear(gate_input_dim, 3),
                    nn.Softmax(dim=1),
                )
            final_in_dim = fusion_hidden_dim
        else:
            # concat
            final_in_dim = feature_dim + 8 + (triplet_embedding_dim if self.use_triplet_embedding else 0)

        self.final_in_dim = final_in_dim

        # Was: self.dropout = nn.Dropout(p=0.3)
        self.dropout = nn.Dropout(p=dropout_rate) # Now uses the argument


        # ✅ Classifier created using final_in_dim
        self.classifier = TwoLayerClassifierHead(
            input_dim=self.final_in_dim, output_dim=num_classes
        )

    def get_gradcam_target_layer(self):
        return self.base.Mixed_7c

    def forward_features(self, x, skin_vec):
        """Extract features from the image with optional FiLM and attention."""
        x = self.base.forward_features(x)  # Extract features from the base model
        x = self.film(x, skin_vec)
        x = self.attn(x, skin_vec) if isinstance(self.attn, CBAM) else self.attn(x) # Handle Identity
        return x

    def forward(self, x, skin_vec=None, triplet_embedding=None, return_features=False):
        """Forward pass through the model, including feature extraction, fusion, and classification."""
        batch_size = x.size(0)

        if self.include_skin_vec and skin_vec is None:
            skin_vec = torch.zeros((batch_size, 12), device=x.device, dtype=x.dtype)

        feat = self.forward_features(x, skin_vec)
        feat = F.adaptive_avg_pool2d(feat, 1).view(batch_size, -1)

        skin_feat = self.skin_mlp(skin_vec)

        if self.use_triplet_embedding:
            if triplet_embedding is None:
                triplet_embedding = torch.zeros((batch_size, 512), device=x.device, dtype=feat.dtype) # Use feat dtype

        # Fusion: Concatenate or use MLP/Gated fusion
        if self.fusion_mode == "concat":
            parts = [feat, skin_feat] + ([triplet_embedding] if self.use_triplet_embedding else [])
            final_feat = torch.cat(parts, dim=1)
        elif self.fusion_mode == "mlp":
            feat_proj = self.image_proj(feat)
            skin_proj = self.skin_proj(skin_feat)
            if self.use_triplet_embedding:
                triplet_proj = self.triplet_proj(triplet_embedding)
                final_feat = feat_proj + skin_proj + triplet_proj
            else:
                final_feat = feat_proj + skin_proj
        elif self.fusion_mode == "gated":
            gate_in = torch.cat([feat, skin_feat] + ([triplet_embedding] if self.use_triplet_embedding else []), dim=1)
            weights = self.gate(gate_in)
            feat_proj = self.image_proj(feat)
            skin_proj = self.skin_proj(skin_feat)
            if self.use_triplet_embedding:
                triplet_proj = self.triplet_proj(triplet_embedding)
                final_feat = weights[:, 0:1] * feat_proj + weights[:, 1:2] * skin_proj + weights[:, 2:3] * triplet_proj
            else:
                final_feat = weights[:, 0:1] * feat_proj + weights[:, 1:2] * skin_proj
        else:
            raise ValueError(f"Unknown fusion mode: {self.fusion_mode}")
        #print(f"Fusion: {self.fusion_mode}, Final feature shape: {final_feat.shape}")
        final_feat = self.dropout(final_feat)
        logits = self.classifier(final_feat)

        return (logits, feat) if return_features else logits

In [18]:
# @title Plotting Functions

# Set up logging for error handling and info messages
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


# Ensure the directory exists
def ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)
    return p

def tensor_to_bgr_uint8(img: torch.Tensor) -> np.ndarray:
    """
    Convert a CHW tensor in [0,1] or roughly normalized to a BGR uint8 image.
    If it was normalized with ImageNet stats, this still looks acceptable after min-max.
    """
    with torch.no_grad():
        x = img.detach().float().cpu()

        # If it is a 4D tensor, take the first element (batch size)
        if x.ndimension() == 4:
            x = x[0]

        # Min-max to [0, 1] and convert to [0, 255] BGR
        x -= x.min()
        denom = (x.max() - x.min()).clamp(min=1e-6)
        x = x / denom
        x = (x * 255.0).clamp(0, 255).byte()

        # Convert CHW to HWC and then BGR (from RGB)
        x = x.numpy().transpose(1, 2, 0)
        return x[..., ::-1]  # Convert RGB to BGR


def plot_tsne(features, labels, label_encoder, graph_dir, model_name, fold_classes, attention_type="CBAM", save_plots=False):
    """Plot t-SNE visualization."""
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap='jet', s=50)
    plt.colorbar(scatter)
    plt.title(f"t-SNE visualization for {model_name}")
    if save_plots:
        plot_path = os.path.join(graph_dir, f"tsne_{model_name}.png")
        plt.savefig(plot_path)
        print(f"✅ t-SNE plot saved to: {plot_path}")
    else:
        plt.show()

def overlay_heatmap(heatmap: np.ndarray, image_path: str, alpha: float = 0.5, colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """
    Overlays a heatmap onto an image for visualization.

    Args:
        heatmap (np.ndarray): The heatmap array (grayscale, 0-1).
        image_path (str): The path to the original image.
        alpha (float): The transparency of the heatmap overlay.
        colormap (int): The OpenCV colormap to apply to the heatmap.

    Returns:
        np.ndarray: The original image with the heatmap overlayed.
    """

    try:
        # Read and resize the original image
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image at {image_path}")
        image = cv2.resize(image, (224, 224))

        # Resize the heatmap and apply the colormap
        heatmap_resized = cv2.resize(heatmap, (224, 224))
        heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), colormap)

        # Blend the image and the heatmap
        overlayed_image = cv2.addWeighted(heatmap_colored, alpha, image, 1 - alpha, 0)

        return overlayed_image
    except Exception as e:
        print(f"❌ Error in overlay_heatmap: {e}")
        # Return the original image if an error occurs
        return cv2.imread(image_path)


# --- Class and Skin Group Count Plot (Interactive Version) ---
def plot_class_skin_group_counts_interactive(y, z, group_fn=bin_mst_to_skin_group, save_path=None):
    """Generates an interactive heatmap showing class distribution across skin groups."""
    combo_counts = Counter()
    for label, skin_vec in zip(y, z):
        if not isinstance(skin_vec, dict) or "MST" not in skin_vec:
            continue
        group = group_fn(skin_vec["MST"])
        combo_counts[(label, group)] += 1

    df = pd.DataFrame([{"Class": k[0], "Skin_Group": k[1], "Count": v} for k, v in combo_counts.items()])
    pivot = df.pivot(index="Class", columns="Skin_Group", values="Count").fillna(0)

    fig = px.imshow(pivot, labels=dict(x="Skin Group", y="Class", color="Count"),
                    color_continuous_scale='Viridis', title="Samples per (Class, Skin Group)")
    fig.update_layout(title="Samples per (Class, Skin Group)", autosize=True)

    if save_path:
        fig.write_html(save_path)
        logging.info(f"Interactive class-group heatmap saved to: {save_path}")
    else:
        fig.show()


# --- MST Distribution Plot ---
def plot_mst_distribution_by_class(y_true, mst_bins, class_names, save_path=None):
    """Plots MST distribution by class in a stacked bar chart."""
    try:
        df = pd.DataFrame({
            'Class': pd.Categorical([class_names[int(y)] for y in y_true], categories=class_names, ordered=True),
            'MST Bin': mst_bins
        })
        crosstab = pd.crosstab(df['Class'], df['MST Bin'], dropna=False)
        crosstab_pct = crosstab.div(crosstab.sum(axis=1), axis=0) * 100
        ax = crosstab_pct.plot(
            kind='bar', stacked=True, figsize=(18, 10),
            colormap='viridis', width=0.8
        )
        ax.set_title("MST Distribution by Class", fontsize=20, pad=20)
        ax.set_xlabel("Class", fontsize=14)
        ax.set_ylabel("Percentage", fontsize=14)
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())
        ax.legend(title="MST Bin", bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=12)
        for container in ax.containers:
            labels = [f"{w:.1f}%" if w > 4 else "" for w in container.datavalues]
            ax.bar_label(container, labels=labels, label_type='center', color='white', weight='bold', fontsize=10)
        plt.tight_layout(rect=[0, 0, 0.88, 1])
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logging.info(f"Saved MST distribution plot to: {save_path}")
        plt.close()
    except Exception as e:
        logging.error(f"Failed to generate MST distribution plot: {e}")


# --- Training Curves Plot ---
def plot_training_curves(history, save_dir, model_name=""):
    """Plots training and validation curves (loss/accuracy)."""
    try:
        prefix = f"{model_name}_" if model_name else ""
        epochs = range(1, len(history['train_loss']) + 1)

        # Loss Plot
        plt.figure(figsize=(8, 5))
        plt.plot(epochs, history['train_loss'], 'o-', label="Train Loss")
        plt.plot(epochs, history['val_loss'], 'o-', label="Val Loss")
        plt.title(f"{model_name} - Training & Validation Loss")
        plt.legend(); plt.grid(True)
        plt.savefig(os.path.join(save_dir, f"{prefix}loss_curve.png"))
        plt.close()

        # Accuracy Plot
        plt.figure(figsize=(8, 5))
        plt.plot(epochs, history['train_acc'], 'o-', label="Train Accuracy")
        plt.plot(epochs, history['val_acc'], 'o-', label="Val Accuracy")
        plt.title(f"{model_name} - Training & Validation Accuracy")
        plt.legend(); plt.grid(True)
        plt.savefig(os.path.join(save_dir, f"{prefix}accuracy_curve.png"))
        plt.close()
        logging.info(f"Training curves saved for: {model_name}")
    except Exception as e:
        logging.error(f"Failed to generate training curves: {e}")


def plot_evaluation_results(
    model_name: str,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_probs: np.ndarray,
    confusion: np.ndarray,
    class_names: List[str],
    skin_vecs: List[Dict],
    mst_bins: List,
    skin_groups: List[str],
    output_dir: str,
    save_training_curves: bool = False,
    training_curves_data: Optional[Dict] = None,
    features: Optional[np.ndarray] = None
) -> None:
    """A master function to call all individual plotting utilities."""
    os.makedirs(output_dir, exist_ok=True)

    # --- Confusion Matrix ---
    try:
        cm_path = os.path.join(output_dir, f'{model_name}_confusion.png')
        plt.figure(figsize=(8, 6))
        sns.heatmap(confusion, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.xlabel("Predicted Label"); plt.ylabel("True Label")
        plt.title(f"{model_name} Confusion Matrix")
        plt.savefig(cm_path, dpi=300, bbox_inches='tight')
        plt.close()
        logging.info(f"Confusion matrix saved to: {cm_path}")
    except Exception as e:
        logging.error(f"Confusion matrix plot failed: {e}")

    # --- Fairness Plot ---
    try:
        fairness_df = compute_fairness_by_group(y_true, y_probs, class_names, skin_groups=skin_groups)
        fairness_plot_path = os.path.join(output_dir, f"{model_name}_fairness.png")
        plt.figure(figsize=(10, 6))
        sns.set_style("whitegrid")
        sns.scatterplot(data=fairness_df, x="Accuracy", y="F1", hue="Skin Group", s=100)
        plt.title("Accuracy vs F1 Score by Skin Group")
        plt.xlim(0, 1); plt.ylim(0, 1); plt.grid(True)
        plt.savefig(fairness_plot_path, bbox_inches='tight')
        plt.close()
        logging.info(f"Fairness plot saved to: {fairness_plot_path}")
    except Exception as e:
        logging.error(f"Fairness plot failed: {e}")

    # --- MST Distribution ---
    mst_dist_path = os.path.join(output_dir, f"{model_name}_mst_distribution.png")
    plot_mst_distribution_by_class(y_true, mst_bins, class_names, save_path=mst_dist_path)

    # --- Training Curves ---
    if save_training_curves and training_curves_data:
        plot_training_curves(
            history=training_curves_data,
            save_dir=output_dir,
            model_name=model_name
        )

In [19]:
# @title Triplet Loss

class InBatchHardTripletLoss(nn.Module):
    """
    Computes triplet loss using hard positive and hard negative mining within the batch.
    Assumes anchor, positive, and negative are all drawn from the model's learned features (feat).
    """
    def __init__(self, margin=1.0, reduction='mean'):
        super(InBatchHardTripletLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, features, labels):
        """
        Args:
            features (torch.Tensor): Tensor of shape (batch_size, embedding_dim)
                                     representing the model's learned features (feat).
            labels (torch.Tensor): Tensor of shape (batch_size,) representing the class labels.
        Returns:
            torch.Tensor: Scalar triplet loss.
        """
        if features.size(0) < 2:
            # Not enough samples in the batch to form triplets, return 0 loss
            return torch.tensor(0.0, device=features.device, requires_grad=True)

        # Calculate pairwise Euclidean distances
        # (a-b)^2 = a^2 + b^2 - 2ab
        dot_product = torch.matmul(features, features.transpose(0, 1))
        square_norm = torch.diag(dot_product) # Sum of squares of each vector
        distances = square_norm.unsqueeze(1) + square_norm.unsqueeze(0) - 2 * dot_product
        distances = torch.sqrt(F.relu(distances) + 1e-16) # Add epsilon for numerical stability

        # Initialize loss
        triplet_loss = torch.tensor(0.0, device=features.device)
        num_valid_triplets = 0

        for i in range(features.size(0)):
            anchor_feature = features[i]
            anchor_label = labels[i]

            # Find positive samples (same class as anchor, excluding anchor itself)
            positive_mask = (labels == anchor_label) & (torch.arange(features.size(0), device=features.device) != i)
            positive_distances = distances[i][positive_mask]

            # Find negative samples (different class from anchor)
            negative_mask = (labels != anchor_label)
            negative_distances = distances[i][negative_mask]

            if positive_distances.numel() > 0 and negative_distances.numel() > 0:
                # Hard positive mining: pick the farthest positive from anchor
                hard_positive_dist = torch.max(positive_distances)
                # Hard negative mining: pick the closest negative to anchor
                hard_negative_dist = torch.min(negative_distances)

                # Compute triplet loss for this anchor
                loss_i = F.relu(hard_positive_dist - hard_negative_dist + self.margin)
                if loss_i > 0: # Only accumulate if the loss is positive (i.e., violation)
                    triplet_loss += loss_i
                    num_valid_triplets += 1

        if num_valid_triplets == 0:
            # If no valid triplets could be formed in the batch, return 0 loss
            return torch.tensor(0.0, device=features.device, requires_grad=True)

        if self.reduction == 'mean':
            return triplet_loss / num_valid_triplets
        elif self.reduction == 'sum':
            return triplet_loss
        else: # reduction == 'none' or other
            return triplet_loss

In [20]:
# @title Evaluation Function

def evaluate_model(
    model, test_loader, device, label_encoder,
    save_dir, model_name="model", graph_dir=None,
    save_training_curves=False, training_curves_data=None, fold_classes=None,
    gradcam_layer=None, visualize_gradcam=False, max_gradcam_images=3,
    use_amp=False, plot_tsne_enabled=False  # Add plot_tsne_enabled flag
):
    """
    Evaluates the model on the test set and generates various evaluation metrics including accuracy,
    classification report, confusion matrix, and optional t-SNE plots.
    """
    ensure_dir(save_dir)
    if graph_dir:
        ensure_dir(graph_dir)

    model.eval()
    clear_all_forward_hooks(model)

    y_true, y_pred, y_probs = [], [], []
    all_mst_bins, all_skin_groups, all_raw_metadata = [], [], []

    print(f"\n--- Evaluating model: {model_name} ---")

    # Plain eval
    with torch.no_grad():
        features_list, labels_list = [], []  # Collect features for t-SNE

        for batch in test_loader:
            # Correctly unpack all 7 items returned by CustomDataset
            inputs, labels, skin_vecs, triplet_embeddings, mst_bins_batch, skin_groups_batch, raw_meta_batch = batch

            inputs = inputs.to(device)
            labels = labels.to(device)
            skin_vecs = skin_vecs.to(device)

            # Triplet embeddings should already be a tensor from CustomDataset
            triplet_embeddings = triplet_embeddings.to(device) if triplet_embeddings is not None else None

            # --- MODIFICATION: Pass return_features=True to get the actual features ---
            outputs, features_for_tsne = model(inputs, skin_vecs, triplet_embedding=triplet_embeddings, return_features=True)

            probs = F.softmax(outputs, dim=1)
            preds = probs.argmax(dim=1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_probs.extend(probs.cpu().numpy())
            all_mst_bins.extend(mst_bins_batch.numpy() if hasattr(mst_bins_batch, "numpy") else np.asarray(mst_bins_batch))
            all_skin_groups.extend(skin_groups_batch)
            all_raw_metadata.append(raw_meta_batch)  # Keep as-is for later use

            # Collect features for t-SNE (before classification)
            features_list.append(features_for_tsne.cpu().numpy()) # --- MODIFICATION: Use extracted features ---
            labels_list.append(labels.cpu().numpy())

    # Combine features and labels for t-SNE
    features = np.concatenate(features_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)

    # Optional t-SNE plot
    if plot_tsne_enabled:
        plot_tsne(features, labels, label_encoder, graph_dir, model_name, fold_classes, attention_type="CBAM", save_plots=True)

    # Metrics/report
    y_true = np.array(y_true); y_pred = np.array(y_pred); y_probs = np.array(y_probs)
    cls_names = [label_encoder.inverse_transform([cls])[0] for cls in sorted(np.unique(y_true))]

    acc = accuracy_score(y_true, y_pred)
    report = classification_report(y_true, y_pred, target_names=cls_names, zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=sorted(np.unique(y_true)))

    print(f"\nAccuracy: {acc * 100:.2f}%")
    print("\nClassification Report:\n", report)

    report_path = os.path.join(save_dir, f"{model_name}_report.txt")
    with open(report_path, "w") as f:
        f.write(f"Accuracy: {acc * 100:.2f}%\n\n{report}\n\nConfusion Matrix:\n{np.array2string(cm)}")
    print(f"✅ Evaluation report saved to: {report_path}")

    # Plotting entrypoint
    if graph_dir:
        try:
            plot_evaluation_results(
                model_name=model_name,
                y_true=y_true, y_pred=y_pred, y_probs=y_probs, confusion=cm,
                class_names=cls_names, skin_vecs=all_raw_metadata,
                mst_bins=all_mst_bins, skin_groups=all_skin_groups,
                output_dir=graph_dir,
                save_training_curves=save_training_curves,
                training_curves_data=training_curves_data
            )
            print(f"📊 Plots saved under: {graph_dir}")
        except Exception as e:
            print(f"⚠️ Plotting failed: {e}")

    return acc, report, cm


In [21]:
# @title Utilities

def get_output_channels(model_name):
    if model_name == "resnet18":
        return 512
    elif model_name in ["resnet50v2", "resnet101v2", "resnet101d", "resnet152d", "resnetrs101"]:
        return 2048
    elif model_name == "inception_v3": # <-- ADDED
        return 2048                  # <-- ADDED
    elif model_name in ["mobilenet_v2", "googlenet"]:
        return 1280
    elif model_name.startswith("vgg"):
        return 4096
    elif model_name == "alexnet":
        return 4096
    elif model_name == "densenet201":
        return 1920
    elif model_name.startswith("efficientnet_b"):
        tf_variant_map = {
            "efficientnet_b4": "tf_efficientnet_b4_ns",
            "efficientnet_b5": "tf_efficientnet_b5_ns",
            "efficientnet_b6": "tf_efficientnet_b6_ns",
            "efficientnet_b7": "tf_efficientnet_b7_ns",
        }
        tf_variant = tf_variant_map.get(model_name, model_name)
        backbone = timm.create_model(tf_variant, pretrained=True, num_classes=0)
        return backbone.num_features
    else:
        # Try loading with timm dynamically
        try:
            backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
            return backbone.num_features
        except Exception:
            raise ValueError(f"Unknown model: {model_name}")


def compute_classwise_alpha(
    y_true,
    y_pred,
    num_classes=4,
    normalize=True,
    clip_range=(0.1, 3.0),
    prev_alpha=None,
    beta=0.9,
    smoothing=True
):
    """
    Compute smoothed, capped alpha weights for Focal Loss based on inverse recall.

    Args:
        y_true (array): Ground truth labels.
        y_pred (array): Predicted labels.
        num_classes (int): Number of classes.
        normalize (bool): Whether to normalize alpha to sum to num_classes.
        clip_range (tuple): Min and max values to clip alpha.
        prev_alpha (np.ndarray or torch.Tensor): Previous epoch's alpha for smoothing.
        beta (float): Smoothing factor for EMA.
        smoothing (bool): Whether to apply exponential smoothing.

    Returns:
        torch.Tensor: Alpha weights.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    recalls = cm.diagonal() / (cm.sum(axis=1) + 1e-6)  # Avoid division by zero
    alphas = 1.0 / (recalls + 1e-6)

    # Clip alpha to avoid extreme weights
    alphas = np.clip(alphas, clip_range[0], clip_range[1])

    # Smooth with EMA using previous alpha
    if smoothing and prev_alpha is not None:
        if isinstance(prev_alpha, torch.Tensor):
            prev_alpha = prev_alpha.detach().cpu().numpy()
        alphas = beta * prev_alpha + (1 - beta) * alphas


    # Normalize to keep total scale constant
    if normalize:
        alphas = alphas / alphas.sum() * num_classes

    print("🔍 Dynamic Alpha (inverse recall):", np.round(alphas, 4))
    return torch.tensor(alphas, dtype=torch.float32)


class GradualUnfreezer:
    """
    Gradually unfreezes backbone layers from the end towards the beginning (high-level to low-level features).
    """
    def __init__(self, model, base_lr=0.001, start_epoch=5, unfreeze_every=5, max_blocks=None, weight_decay=1e-4):
        self.model = model
        self.base_lr = base_lr
        self.unfreeze_every = unfreeze_every
        self.weight_decay = weight_decay

        # Freeze all backbone parameters initially
        for p in self.model.base.parameters():
            p.requires_grad = False

        # Break the backbone into its main sequential child modules
        self.children = list(model.base.children())
        self.total_blocks = len(self.children)
        self.max_blocks_to_unfreeze = max_blocks if max_blocks is not None else self.total_blocks

        # Start the pointer at the LAST block for backward unfreezing
        self.next_block_to_unfreeze = self.total_blocks - 1

        self.start_epoch = start_epoch
        self.next_unfreeze_epoch = self.start_epoch

        print(f"🧊 Backbone frozen: {self.total_blocks} total blocks.")
        print(f"📅 Strategy: Unfreeze from the last block backward (high-level features first).")
        print(f"   - Starting from block #{self.next_block_to_unfreeze} at epoch {self.start_epoch}.")
        print(f"   - Unfreezing one block every {self.unfreeze_every} epoch(s).")
        print(f"   - A maximum of {self.max_blocks_to_unfreeze} blocks will be unfrozen.")

    def step(self, optimizer: Optimizer, current_epoch: int):
        """
        This method should be called at the beginning of each training epoch.
        It now correctly uses 'self.next_block_to_unfreeze' for backward unfreezing.
        """
        # 1. Check if it's the right time to unfreeze
        if current_epoch < self.start_epoch or current_epoch < self.next_unfreeze_epoch:
            return

        # 2. Check if we have already unfrozen all possible blocks
        if self.next_block_to_unfreeze < 0:
            return

        # 3. Check if we have reached the user-defined limit of blocks to unfreeze
        unfrozen_count = (self.total_blocks - 1) - self.next_block_to_unfreeze
        if unfrozen_count >= self.max_blocks_to_unfreeze:
            return

        # 4. Get the block to unfreeze
        block_to_unfreeze = self.children[self.next_block_to_unfreeze]

        # 5. Find parameters in this block that are not already in the optimizer
        param_ids_in_optimizer = {id(p) for group in optimizer.param_groups for p in group['params']}
        new_params = [p for p in block_to_unfreeze.parameters() if id(p) not in param_ids_in_optimizer]

        if not new_params:
            print(f"Epoch {current_epoch}: Block {self.next_block_to_unfreeze} params are already in optimizer. Skipping.")
        else:
            print(f"🔥 Epoch {current_epoch}: Unfreezing backbone block {self.next_block_to_unfreeze}...")
            for param in new_params:
                param.requires_grad = True

            # Add the newly trainable parameters to the optimizer with a smaller learning rate
            '''optimizer.add_param_group({
                'params': new_params,
                'lr': self.base_lr / 10, # Use a smaller LR for fine-tuning
                'weight_decay': self.weight_decay
            })'''
            # Use a much smaller LR for the fine-tuned backbone layers
            low_lr = self.base_lr / 100

            print(f"  -> Added {len(new_params)} backbone params with low LR ({low_lr}).")
            optimizer.add_param_group({
                'params': new_params,
                'lr': low_lr,
                'weight_decay': self.weight_decay
            })

            print(f"  -> Added {len(new_params)} new parameters to the optimizer.")

        # 6. Decrement the pointer to the next block and schedule the next unfreeze event
        self.next_block_to_unfreeze -= 1
        self.next_unfreeze_epoch += self.unfreeze_every

class PostWarmupLRScheduler:
    def __init__(self, optimizer, base_lr=0.001, rise_epochs=3, weight_decay=1e-4):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.rise_epochs = rise_epochs
        self.epoch_count = 0
        self.weight_decay = weight_decay

    def step(self):
        if self.epoch_count >= self.rise_epochs:
            return

        new_lr = self.base_lr * (self.epoch_count + 1) / self.rise_epochs
        for i, group in enumerate(self.optimizer.param_groups):
            group['lr'] = new_lr
        self.epoch_count += 1

        print(f"📈 LR Increase: Set LR to {new_lr:.6f}")

class HybridLRScheduler:
    def __init__(
        self,
        optimizer,
        warmup_epochs,
        total_epochs,
        mode="plateau",  # ✅ 'cosine' or 'plateau'
        plateau_patience=5,
        plateau_factor=0.5,
        min_lr=1e-6
    ):
        assert mode in ["cosine", "plateau"], "mode must be 'cosine' or 'plateau'"
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.mode = mode
        self.plateau_patience = plateau_patience
        self.plateau_factor = plateau_factor
        self.min_lr = min_lr
        self.lr_history = []

        self.current_epoch = 0
        self.best_val_acc = 0
        self.epochs_since_improvement = 0

        # ✅ Store original learning rates
        self.initial_lr = [group['lr'] for group in optimizer.param_groups]

    def step(self, val_acc=None):
        lr = self.optimizer.param_groups[0]['lr']
        self.lr_history.append(lr)

        if self.current_epoch < self.warmup_epochs:
            # 🔼 Linear Warmup
            scale = (self.current_epoch + 1) / self.warmup_epochs
            for i, group in enumerate(self.optimizer.param_groups):
                base_lr = self.initial_lr[i] if i < len(self.initial_lr) else group['lr']
                group['lr'] = base_lr * scale

        elif self.mode == "cosine":
            # 🌀 Cosine Annealing
            progress = (self.current_epoch - self.warmup_epochs) / max(1, self.total_epochs - self.warmup_epochs)
            for i, group in enumerate(self.optimizer.param_groups):
                base_lr = self.initial_lr[i] if i < len(self.initial_lr) else group['lr']
                new_lr = self.min_lr + 0.5 * (base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
                group['lr'] = new_lr

        elif self.mode == "plateau":
            # 📉 Reduce LR on Plateau
            if val_acc is not None:
                if val_acc > self.best_val_acc:
                    self.best_val_acc = val_acc
                    self.epochs_since_improvement = 0
                else:
                    self.epochs_since_improvement += 1
                    if self.epochs_since_improvement >= self.plateau_patience:
                        for i, group in enumerate(self.optimizer.param_groups):
                            new_lr = max(group['lr'] * self.plateau_factor, self.min_lr)
                            group['lr'] = new_lr
                            print(f"Plateau: Reducing LR group {i} to {new_lr:.6f}")
                        self.epochs_since_improvement = 0

        self.current_epoch += 1

    def get_lr(self):
        return [group['lr'] for group in self.optimizer.param_groups]

def setup_directories(base_path, model_name, fold=None, attention_type=None):
    """
    Create directory structure:
    base_path/fold_{N}_{model_name}_{attention_type}/[checkpoints, weights, graphs, predictions]
    """
    attention_str = str(attention_type).lower() if attention_type else "none"
    fold_str = f"fold_{fold}" if fold is not None else "fold_None"
    tag = f"{fold_str}_{model_name}_{attention_str}"

    model_root = os.path.join(base_path, tag)  # ✅ Use tag as subdirectory

    checkpoint_dir = os.path.join(model_root, "checkpoints")
    weights_dir = os.path.join(model_root, "weights")
    graph_dir = os.path.join(model_root, "graphs")
    predictions_dir = os.path.join(model_root, "predictions")

    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)
    os.makedirs(graph_dir, exist_ok=True)
    os.makedirs(predictions_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, f"{tag}_checkpoint.pth")
    best_weights_path = os.path.join(weights_dir, f"{tag}_best.pth")

    return checkpoint_path, best_weights_path, graph_dir, predictions_dir

def compute_class_mst_alpha_matrix(y_true, y_pred, mst_bins, num_classes=7, num_mst_bins=10, normalize=True):

    alpha_matrix = np.ones((num_classes, num_mst_bins), dtype=np.float32)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    mst_bins = np.array(mst_bins)

    for cls in range(num_classes):
        for mst in range(num_mst_bins):
            mask = (y_true == cls) & (mst_bins == mst)
            if mask.sum() == 0:
                alpha_matrix[cls, mst] = 1.0  # fallback
                continue

            y_true_subset = y_true[mask]
            y_pred_subset = y_pred[mask]

            recall = np.sum(y_pred_subset == cls) / (len(y_true_subset) + 1e-6)
            alpha_matrix[cls, mst] = 1.0 / (recall + 1e-6)  # inverse recall

    if normalize:
        # Normalize each row (per class) to sum to num_mst_bins
        alpha_matrix = alpha_matrix / alpha_matrix.sum(axis=1, keepdims=True) * num_mst_bins

    print("📊 Alpha Matrix (class × MST):")
    print(np.round(alpha_matrix, 2))

    return torch.tensor(alpha_matrix, dtype=torch.float32)

def freeze_backbone(model):
    for param in model.base.parameters():
        param.requires_grad = False
    #print("🧊 Backbone frozen.")

def safe_criterion_call(criterion, outputs, labels, mst_groups=None):
    try:
        return criterion(outputs, labels, mst_groups)
    except TypeError:
        return criterion(outputs, labels)

def plot_alpha_trends(alpha_history, num_classes, save_path=None):
    alpha_array = torch.stack(alpha_history).cpu().numpy()
    for class_idx in range(num_classes):
        plt.plot(alpha_array[:, class_idx], label=f"Class {class_idx}")
    plt.xlabel("Epoch")
    plt.ylabel("Alpha Weight")
    plt.title("Focal Loss Alpha Trend (Inverse Recall)")
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
        print(f"📊 Alpha plot saved to: {save_path}")
    else:
        plt.show()

def compute_fairness_by_group(y_true, y_probs, class_names, skin_groups=None):
    y_preds = np.argmax(y_probs, axis=1)
    results = []
    if skin_groups is None:
        skin_groups = ['unknown'] * len(y_true)
    unique_groups = sorted(set(skin_groups))
    for group in unique_groups:
        indices = [i for i, g in enumerate(skin_groups) if g == group]
        if not indices:
            continue
        group_y_true = [y_true[i] for i in indices]
        group_y_pred = [y_preds[i] for i in indices]
        results.append({
            "Skin Group": group,
            "Accuracy": accuracy_score(group_y_true, group_y_pred),
            "Precision": precision_score(group_y_true, group_y_pred, average='macro', zero_division=0),
            "Recall": recall_score(group_y_true, group_y_pred, average='macro', zero_division=0),
            "F1": f1_score(group_y_true, group_y_pred, average='macro', zero_division=0),
        })
    return pd.DataFrame(results)


In [22]:

# @title MixUp utilities

def soft_cross_entropy(pred, soft_targets):
    # pred: [B, C] logits; soft_targets: [B, C] probs (rows sum to 1)
    log_probs = F.log_softmax(pred, dim=1)
    return -(soft_targets * log_probs).sum(dim=1).mean()

# ---------- MixUp criterion (soft targets) ----------
def mixup_criterion(pred, y_a, y_b, lam, num_classes):
    # y_a, y_b: class indices [B]
    y_a = F.one_hot(y_a.long(), num_classes=num_classes).float()
    y_b = F.one_hot(y_b.long(), num_classes=num_classes).float()
    soft_targets = lam * y_a + (1 - lam) * y_b
    return soft_cross_entropy(pred, soft_targets)

# ---------- Safe MixUp ----------
def _to_scalar_alpha(alpha, default=0.4):
    """
    Accepts float/int/0-d tensor/1-d tensor/ndarray; returns a safe positive float.
    - If tensor/array with >1 elem (e.g., class weights), use its mean as the scalar alpha.
    - Clamps to a small positive to avoid Beta errors.
    """
    if alpha is None:
        return float(default)
    try:
        # turn anything into a tensor on CPU, flatten, take mean -> scalar
        a = torch.as_tensor(alpha).detach().float().mean().item()
    except Exception:
        try:
            # last resort
            a = float(alpha)
        except Exception:
            a = float(default)
    # clamp to safe range
    return float(max(a, 1e-6))

def mixup_data(x, y, skin_vec=None, alpha=0.4, epoch=0, warmup_epochs=5, lam_clip=(0.3, 0.7)):
    """
    x: [B, C, H, W]
    y: [B]
    skin_vec: optional [B, D] or None (if your model ignores it)
    alpha: can be scalar or any tensor/array (e.g., class-weight vector) — reduced to a scalar
    Returns: mixed_x, mixed_skin (or None), y_a, y_b, lam(float)
    """
    if x.ndim != 4:
        raise ValueError(f"Expected x [B,C,H,W], got {tuple(x.shape)}")
    if skin_vec is not None and skin_vec.ndim != 2:
        raise ValueError(f"Expected skin_vec [B,D] or None, got {None if skin_vec is None else tuple(skin_vec.shape)}")

    B = x.size(0)
    if B < 2:
        # can't permute a single sample; no-op
        lam = 1.0
        index = torch.arange(B, device=x.device)
    else:
        if epoch < int(warmup_epochs):
            lam = 1.0
        else:
            a = _to_scalar_alpha(alpha, default=0.4)
            # Beta is undefined for non-positive alpha
            if a <= 0.0:
                lam = 1.0
            else:
                lam = np.random.beta(a, a)
                if lam_clip is not None:
                    lo, hi = lam_clip
                    lam = float(np.clip(lam, lo, hi))
        index = torch.randperm(B, device=x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_skin = None
    if skin_vec is not None:
        mixed_skin = lam * skin_vec + (1 - lam) * skin_vec[index]

    y_a, y_b = y, y[index]
    return mixed_x, mixed_skin, y_a, y_b, float(lam)


def compute_classwise_alpha(
    y_true,
    y_pred,
    num_classes=4,
    normalize=True,
    clip_range=(0.1, 3.0),
    prev_alpha=None,
    beta=0.9,
    smoothing=True
):
    """
    Compute smoothed, capped alpha weights for Focal Loss based on inverse recall.

    Args:
        y_true (array): Ground truth labels.
        y_pred (array): Predicted labels.
        num_classes (int): Number of classes.
        normalize (bool): Whether to normalize alpha to sum to num_classes.
        clip_range (tuple): Min and max values to clip alpha.
        prev_alpha (np.ndarray or torch.Tensor): Previous epoch's alpha for smoothing.
        beta (float): Smoothing factor for EMA.
        smoothing (bool): Whether to apply exponential smoothing.

    Returns:
        torch.Tensor: Alpha weights.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    recalls = cm.diagonal() / (cm.sum(axis=1) + 1e-6)  # Avoid division by zero
    alphas = 1.0 / (recalls + 1e-6)

    # Clip alpha to avoid extreme weights
    alphas = np.clip(alphas, clip_range[0], clip_range[1])

    # Smooth with EMA using previous alpha
    if smoothing and prev_alpha is not None:
        if isinstance(prev_alpha, torch.Tensor):
            prev_alpha = prev_alpha.detach().cpu().numpy()
        alphas = beta * prev_alpha + (1 - beta) * alphas


    # Normalize to keep total scale constant
    if normalize:
        alphas = alphas / alphas.sum() * num_classes

    print("🔍 Dynamic Alpha (inverse recall):", np.round(alphas, 4))
    return torch.tensor(alphas, dtype=torch.float32)

In [23]:
# @title Get Model Function

def get_model_with_attention(model_name, num_classes, attention_type="none", pretrained=True,
                             fold=None, weights_root=None, resume=True, **kwargs):
    # Extract any additional configuration arguments passed through kwargs
    use_film_before = kwargs.get("use_film_before", False)
    use_film_in_cbam = kwargs.get("use_film_in_cbam", False)
    use_triplet_embedding = kwargs.get("use_triplet_embedding", False)
    triplet_embedding_dim = kwargs.get("triplet_embedding_dim", 512)
    include_skin_vec = kwargs.get("include_skin_vec", True)
    drop_path_rate = kwargs.get("drop_path_rate", 0.2)
    fusion_mode = kwargs.get("fusion_mode", "concat")  # Default to "concat" if not passed
    fusion_hidden_dim = kwargs.get("fusion_hidden_dim", 128)  # Default to 128 if not passed

    print(f"Fusion hidden dim: {fusion_hidden_dim}")  # Debugging line to check the value

    # EfficientNet models
    if model_name.startswith("efficientnet_b"):
        model = EfficientNetWithAttention(
            num_classes=num_classes,
            attention_type=attention_type,
            pretrained=pretrained,
            use_film_before=use_film_before,
            use_film_in_cbam=use_film_in_cbam,
            use_triplet_embedding=use_triplet_embedding,
            triplet_embedding_dim=triplet_embedding_dim,
            include_skin_vec=include_skin_vec,
            efficientnet_variant=model_name,
            fusion_mode=fusion_mode,  # Pass fusion_mode to EfficientNetWithAttention
            fusion_hidden_dim=fusion_hidden_dim  # Pass fusion_hidden_dim to EfficientNetWithAttention
        )

    # ResNet models
    elif model_name in ["resnet101v2", "resnet101d", "resnet152d", "resnetrs101"]:
        model = ResNetWithAttention(
            num_classes=num_classes,
            backbone_name=model_name,
            attention_type=attention_type,
            drop_path_rate=drop_path_rate,
            use_film_before=use_film_before,
            use_film_in_cbam=use_film_in_cbam,
            use_triplet_embedding=use_triplet_embedding,
            triplet_embedding_dim=triplet_embedding_dim,
            include_skin_vec=include_skin_vec,
            fusion_mode=fusion_mode,  # Pass fusion_mode to ResNetWithAttention
            fusion_hidden_dim=fusion_hidden_dim  # Pass fusion_hidden_dim to ResNetWithAttention
        )

    # InceptionV3 model
    elif model_name.lower() in ["inceptionv3", "inception_v3"]:
        model = InceptionV3WithAttention(
            num_classes=num_classes,
            attention_type=attention_type,
            pretrained=pretrained,
            use_film_before=use_film_before,
            use_film_in_cbam=use_film_in_cbam,
            use_triplet_embedding=use_triplet_embedding,
            triplet_embedding_dim=triplet_embedding_dim,
            include_skin_vec=include_skin_vec,
            fusion_mode=fusion_mode,  # Pass fusion_mode to InceptionV3WithAttention
            fusion_hidden_dim=fusion_hidden_dim  # Pass fusion_hidden_dim to InceptionV3WithAttention
        )

    else:
        raise ValueError(f"Unsupported model: {model_name}")

    return model

In [24]:
# @title Local Train

def local_train(
    train_loader, model, device, num_epochs=10, lr=0.003,
    val_loader=None, save_model_path=None, model_name="model",
    fold=None, resume_path=None, alpha=0.2, mixup_enabled=True,
    warmup_epochs=4, num_classes=4, attention_type="none",
    log_lr_each_epoch=True, y_train=None, use_gradcam=False, # Add use_gradcam flag
    triplet_loss_weight=0.1 # --- NEW PARAMETER FOR TRIPLET LOSS WEIGHT ---
):
    """
    Executes a local training loop with a clean tqdm progress bar and full checkpoint/resume logic.
    """

    # === Corrected Optimizer Setup ===

    # 1. Get all "new" parameters that are NOT in the backbone
    new_params = [
        {'params': model.classifier.parameters(), 'lr': lr, 'weight_decay': 1e-4}
    ]

    # Dynamically add other new parts if they exist
    if hasattr(model, 'skin_mlp'):
        new_params.append({'params': model.skin_mlp.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    if hasattr(model, 'image_proj'):
        new_params.append({'params': model.image_proj.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    if hasattr(model, 'skin_proj'):
        new_params.append({'params': model.skin_proj.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    if hasattr(model, 'triplet_proj'):
        new_params.append({'params': model.triplet_proj.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    if hasattr(model, 'gate'):
        new_params.append({'params': model.gate.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    if hasattr(model, 'attn') and not isinstance(model.attn, nn.Identity):
         new_params.append({'params': model.attn.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    if hasattr(model, 'film') and not isinstance(model.film, nn.Identity):
         new_params.append({'params': model.film.parameters(), 'lr': lr, 'weight_decay': 1e-4})

    print(f"✅ Optimizing {len(new_params)} new parameter groups with high LR ({lr}).")

    # 2. Create the optimizer *only* with these new parameter groups
    optimizer = torch.optim.AdamW(new_params)

    if hasattr(model, 'attn') and not isinstance(model.attn, nn.Identity):
        # Check if 'attn' parameters are already added
        # Iterate over parameters of model.attn and check if any of them are already in optimizer
        for param in model.attn.parameters():
            if not any(param is group_param for group in optimizer.param_groups for group_param in group['params']):
                optimizer.add_param_group({'params': model.attn.parameters(), 'lr': lr, 'weight_decay': 1e-4})
                break  # Add param group only once


    # History lists for training and validation metrics
    train_loss_history, val_loss_history, train_acc_history, val_acc_history, lrs_history = [], [], [], [], []

    # Class weights for loss function
    class_weights_np = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)
    weights = torch.tensor(class_weights_np, dtype=torch.float, device=device)
    criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)

    # --- NEW: Triplet Loss Criterion ---
    triplet_criterion = InBatchHardTripletLoss(margin=1.0)

    # Learning rate scheduler and utilities
    scheduler = HybridLRScheduler(optimizer, warmup_epochs=warmup_epochs, total_epochs=num_epochs, mode='cosine', min_lr=1e-6)

    # --- START FIX 1: AMP and Scaler Setup ---
    scaler = torch.amp.GradScaler() if device.type == "cuda" else None
    use_amp = (scaler is not None)
    # This context will automatically handle both CPU/GPU and enabled/disabled states
    autocast_context = torch.amp.autocast(device_type=device.type, enabled=use_amp)
    # --- END FIX 1 ---

    checkpoint_path, best_weights_path, _, _ = setup_directories(
        base_path=save_model_path, model_name=model_name, fold=fold, attention_type=attention_type or "none"
    )

    # Initial state variables for early stopping
    best_val_accuracy = 0.0
    best_model_weights = copy.deepcopy(model.state_dict())
    early_stop_counter = 0
    early_stop_patience = 30
    start_epoch = 0

    # Checkpoint resume logic
    if resume_path and os.path.isfile(resume_path):
        try:
            print(f"Attempting to resume training from checkpoint: {resume_path}")
            # --- START FIX 2: Added weights_only=False ---
            checkpoint = torch.load(resume_path, map_location=device, weights_only=False)
            # --- END FIX 2 ---
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            best_val_accuracy = checkpoint.get("best_val_accuracy", 0.0)
            start_epoch = checkpoint.get("epoch", 0) + 1
            print(f"Successfully resumed training from checkpoint at Epoch {start_epoch}")
        except Exception as e:
            print(f"⚠️ Failed to load checkpoint from {resume_path}: {e}")
            print("Starting training from scratch (Epoch 0) instead.")
            best_val_accuracy, start_epoch, early_stop_counter = 0.0, 0, 0
    else:
        print(f"No valid checkpoint found. Starting training from scratch (Epoch 0).")

    # Backbone freezing for warmup
    if start_epoch < warmup_epochs:
        freeze_backbone(model)

    # Gradual unfreezing setup
    unfreezer = GradualUnfreezer(model, base_lr=lr, start_epoch=warmup_epochs, unfreeze_every=2, max_blocks=None, weight_decay=1e-4)
    lr_riser = PostWarmupLRScheduler(optimizer, base_lr=lr, rise_epochs=3)

    # Training loop with TQDM progress bar
    epoch_pbar = tqdm(range(start_epoch, num_epochs), desc=f"Training {model_name}")


    for epoch in epoch_pbar:
        if hasattr(train_loader.dataset, 'set_epoch'):
            train_loader.dataset.set_epoch(epoch)

        unfreezer.step(optimizer, epoch + 1)
        lr_riser.step()
        lrs_history.append(scheduler.get_lr()[0])

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Training phase
        model.train()
        total_loss, correct, total = 0.0, 0, 0

        for batch in train_loader:
            # Unpack the 7 items from the dataset
            images, labels, skin_vecs, triplet, mst_bin, skin_group, metadata = batch
            images, labels, skin_vecs, triplet = images.to(device), labels.to(device), skin_vecs.to(device), triplet.to(device)

            # Mixup logic
            use_mixup = False
            if mixup_enabled:
                mix_ratio = max(1.0 - epoch / num_epochs, 0.1)
                images, skin_vecs, y_a, y_b, lam = mixup_data(images, labels, skin_vecs, weights * mix_ratio, epoch, warmup_epochs)
                use_mixup = True
            else:
                y_a, y_b, lam = labels, labels, 1.0

            optimizer.zero_grad()

            # --- START FIX 3: Conditional backward pass ---
            try:
                # Forward pass with autocast
                with autocast_context:
                    # --- MODIFICATION: Get features for Triplet Loss ---
                    out, feat = model(images, skin_vec=skin_vecs, triplet_embedding=triplet, return_features=True)
                    classification_loss = mixup_criterion(out, y_a, y_b, lam, num_classes=num_classes) if use_mixup else criterion(out, labels)

                    # --- NEW: Calculate Triplet Loss ---
                    t_loss = triplet_criterion(feat, labels)
                    loss = classification_loss + triplet_loss_weight * t_loss


                # Backward pass
                if use_amp:
                    # CUDA / AMP path
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # Standard CPU path
                    loss.backward()
                    optimizer.step()

            except Exception as e:
                print(f"🚨 STEP FAILURE: {e}")
                traceback.print_exc()
                raise
            # --- END FIX 3 ---

            total_loss += loss.item() * labels.size(0)
            correct += (out.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        # Handle case where train_loader was empty (e.g., drop_last=True)
        if total == 0:
            print(f"⚠️ Epoch {epoch}: No data processed in training loop. Check DataLoader and batch_size.")
            avg_train_loss, train_acc = 0.0, 0.0
        else:
            avg_train_loss, train_acc = total_loss / total, correct / total

        train_loss_history.append(avg_train_loss)
        train_acc_history.append(train_acc)

        # Validation phase
        avg_val_loss, val_acc = 0.0, 0.0
        if val_loader:
            model.eval()
            val_correct, val_total, val_loss_total = 0, 0, 0.0
            with torch.no_grad():
                for batch in val_loader:
                    # Unpack the 7 items for validation
                    v_images, v_labels, v_skin, v_triplet, v_mst_bin, v_skin_group, v_metadata = batch
                    v_images, v_labels, v_skin, v_triplet = v_images.to(device), v_labels.to(device), v_skin.to(device), v_triplet.to(device)

                    # Use autocast context
                    with autocast_context:
                        # --- MODIFICATION: Get features for Triplet Loss (not used in val loss, but consistency) ---
                        out, feat = model(v_images, skin_vec=v_skin, triplet_embedding=v_triplet, return_features=True)
                        val_classification_loss = criterion(out, v_labels)

                        # --- NEW: Calculate Triplet Loss for validation (for logging, not backprop) ---
                        val_t_loss = triplet_criterion(feat, v_labels)
                        val_loss = val_classification_loss + triplet_loss_weight * val_t_loss

                    val_loss_total += val_loss.item() * v_labels.size(0)
                    val_correct += (out.argmax(dim=1) == v_labels).sum().item()
                    val_total += v_labels.size(0)

            if val_total == 0:
                print(f"⚠️ Epoch {epoch}: No data processed in validation loop.")
                avg_val_loss, val_acc = 0.0, 0.0
            else:
                avg_val_loss, val_acc = val_loss_total / val_total, val_correct / val_total

            val_loss_history.append(avg_val_loss)
            val_acc_history.append(val_acc)
            scheduler.step(val_acc)

        # Update the progress bar
        epoch_pbar.set_postfix(train_loss=f"{avg_train_loss:.4f}", train_acc=f"{train_acc:.4f}", val_loss=f"{avg_val_loss:.4f}", val_acc=f"{val_acc:.4f}", lr=f"{scheduler.get_lr()[0]:.1e}")

        # Early stopping & checkpoint logic
        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            best_model_weights = copy.deepcopy(model.state_dict())
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print(f"\nEarly stopping triggered after {early_stop_patience} epochs with no improvement.")
                break

        if checkpoint_path:
            # Ensure the parent directory exists right before saving
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save({"epoch": epoch, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "best_val_accuracy": best_val_accuracy}, checkpoint_path)

    # End of training loop
    model.load_state_dict(best_model_weights)
    if best_weights_path:
        torch.save(best_model_weights, best_weights_path)
        print(f"\n✅ Saved best model weights to: {best_weights_path}")

    # Don't try to empty cache if CUDA is not available
    if device.type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    return model, {"train_loss": train_loss_history, "val_loss": val_loss_history, "train_acc": train_acc_history, "val_acc": val_acc_history, "lrs": lrs_history}


In [25]:
# @title K-Fold

'''
import os
import torch
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
'''

def kfold_cross_validation(
    X, y, z, label_encoder, model_names, attention_types, num_classes,
    transform, num_folds=5, num_epochs=10, batch_size=64,
    save_root=RESULTS_DIR, triplet_embedding_dict=None, val_size=0.3,
    triplet_loss_weight=0.1 # --- ADDED: Triplet Loss Weight parameter ---
):
    device = DEVICE
    seed = np.random.randint(0, 99999)
    print(f"Using random_state = {seed} for this k-fold trial")

    splitter = StratifiedShuffleSplit(n_splits=num_folds, test_size=val_size, random_state=seed)

    for fold_idx, (train_idx, val_idx) in enumerate(splitter.split(X, y)):
        current_fold_num = fold_idx + 1
        print(f"\n℀ Fold {current_fold_num}/{num_folds}")

        X_tr, y_tr_orig, z_tr = [X[i] for i in train_idx], [y[i] for i in train_idx], [z[i] for i in train_idx]
        X_val, y_val_orig, z_val = [X[i] for i in val_idx], [y[i] for i in val_idx], [z[i] for i in val_idx]
        print(f"Train set size: {len(X_tr)} | Val set size: {len(X_val)}")

        fold_classes = sorted(set(y_tr_orig + y_val_orig))
        fold_num_classes = len(fold_classes)
        class_mapping = {label: idx for idx, label in enumerate(fold_classes)}
        y_tr = [class_mapping[lbl] for lbl in y_tr_orig]
        y_val = [class_mapping[lbl] for lbl in y_val_orig]

        print(f"Fold {current_fold_num} Classes: {fold_classes} → Remapped to: {list(class_mapping.values())}")
        print(f"ðŸ’‰ Fold {current_fold_num} Class Distribution (Train): {dict(Counter(y_tr))}")
        print(f"ðŸ’‰ Fold {current_fold_num} Class Distribution (Val): {dict(Counter(y_val))}")

        CLASS_POLICY_MAP = {
            0: "standard_transform",
            1: "standard_transform",
            2: "standard_transform",
            3: "aggressive_transform",
            4: "standard_transform",
            5: "aggressive_transform",
            6: "standard_transform",
        }

        train_dataset = CustomDataset(
            image_paths=X_tr,
            labels=y_tr,
            metadata=z_tr,
            include_skin_vec=True,
            triplet_embedding_dict=triplet_embedding_dict,
            class_policy_map=CLASS_POLICY_MAP, # Pass class_policy_map for training
            num_classes=fold_num_classes
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            # num_workers=4,
            num_workers=0,
            pin_memory=True,
            drop_last=False
        )

        val_dataset = CustomDataset(
            image_paths=X_val,
            labels=y_val,
            metadata=z_val,
            include_skin_vec=True,
            triplet_embedding_dict=triplet_embedding_dict,
            transform_name="standard_transform", # Pass transform_name for validation
            num_classes=fold_num_classes
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            #num_workers=4,
            num_workers=0,
            pin_memory=True
        )

        train_dataset.set_epoch(0)

        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Val dataset size: {len(val_dataset)}")

        for attn_type in attention_types:
            for model_name in model_names:
                run_name = f"{model_name}_{attn_type}"
                print(f"\nTraining: {run_name.upper()} — Fold {current_fold_num}")

                checkpoint_path, best_weights_path, graph_dir, predictions_dir = setup_directories(
                    base_path=save_root,
                    model_name=model_name,
                    fold=current_fold_num,
                    attention_type=attn_type
                )

                try:
                    start_time = time.time()
                    model = get_model_with_attention(
                        model_name=model_name, num_classes=fold_num_classes, attention_type=attn_type,
                        pretrained=True, fold=current_fold_num, weights_root=save_root, resume=True,
                        use_film_before=True, use_film_in_cbam=True, use_triplet_embedding=True,
                        triplet_embedding_dim=512, fusion_mode="concat"
                    ).to(device)
                    #print(f"Before local train fusion_mode: {model.fusion_mode}, fusion_hidden_dim: {model.fusion_hidden_dim}")
                    model, training_history_data = local_train(
                        train_loader=train_loader, model=model, device=device, num_epochs=num_epochs,
                        lr=0.001, val_loader=val_loader, save_model_path=save_root,
                        model_name=model_name, fold=current_fold_num, resume_path=checkpoint_path,
                        alpha=0.3, mixup_enabled=True, warmup_epochs=5,
                        num_classes=fold_num_classes, attention_type=attn_type, y_train=y_tr,
                        triplet_loss_weight=triplet_loss_weight # --- ADDED: Pass triplet_loss_weight ---
                    )
                    #print(f"In local train fusion_mode: {model.fusion_mode}, fusion_hidden_dim: {model.fusion_hidden_dim}")
                    gradcam_layer = get_gradcam_layer(model, model_name)
                    evaluate_model(
                        model=model, test_loader=val_loader, device=device,
                        label_encoder=label_encoder,
                        save_dir=predictions_dir,
                        model_name=f"{model_name}_{attn_type}_fold{current_fold_num}",
                        visualize_gradcam=True, gradcam_layer=gradcam_layer, graph_dir=graph_dir,
                        save_training_curves=True, training_curves_data=training_history_data,
                        fold_classes=fold_classes,
                        plot_tsne_enabled=True
                    )
                    elapsed = time.time() - start_time
                    print(f"Training and evaluation time for {run_name.upper()} — Fold {current_fold_num}: {elapsed:.2f} seconds")

                except Exception as e:
                    print(f"Error — Skipping {run_name.upper()} (Fold {current_fold_num}): {e}")
                    import traceback
                    traceback.print_exc()

                finally:
                    if 'model' in locals():
                        del model
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()

In [None]:
# @title Main

if __name__ == "__main__":
    # === 🔧 Configuration ===
    train_dataset_dir = BASE_DATASET_DIR
    save_model_root = RESULTS_DIR
    triplet_path = EMBED_NPY

    num_epochs = 50
    batch_size = 64
    n_splits = 1
    model_names = ["efficientnet_b0","InceptionV3", "resnet152d"] #, "InceptionV3", "resnet152d"
    attention_types = ["cbam"]

    # --- Triplet Loss Weight ---
    triplet_loss_weight = 0.1 # Adjust this weight as needed

    # === 📂 1. Load and Process Training Data ===
    print("--- Loading and Processing Training Data ---")
    X_train_paths, y_train_labels = load_img_from_dir(train_dataset_dir, max_images_per_class=2500)
    if not X_train_paths:
        raise RuntimeError(f"No training images found in: {train_dataset_dir}")

    triplet_embedding_dict = torch.load(triplet_path, map_location='cpu')

    # Create the final data lists in a single, efficient loop
    X_train, y_train, z_train = [], [], []
    for path, label in zip(X_train_paths, y_train_labels):
        if os.path.basename(path).lower() in triplet_embedding_dict:
            color_metrics = extract_color_metrics_and_estimate_mst(path)
            if color_metrics and 1 <= color_metrics.get("MST", 0) <= 10:
                X_train.append(path)
                y_train.append(label)
                z_train.append(color_metrics)

    print(f"Total usable training images: {len(X_train)}")

    # === Encode Labels ===
    label_encoder = LabelEncoder()
    y_train_encoded = label_encoder.fit_transform(y_train)
    num_classes = len(label_encoder.classes_)

    # Calculate dynamic target counts based on a percentage of the largest (class, MST) group size
    #dynamic_target_counts = calculate_dynamic_target_counts(y_train_encoded, z_train, oversample_percentage=1.2)

    # Balance the dataset using dynamic target counts
    #X_train, y_train, z_train = balance_data_to_targets(X_train, y_train_encoded, z_train, dynamic_target_counts)

    # === 🔁 3. Run K-fold Cross-Validation on the Training Set ===
    kfold_cross_validation(
        X=X_train,
        y=y_train,
        z=z_train,
        label_encoder=label_encoder,
        model_names=model_names,
        attention_types=attention_types,
        num_classes=num_classes,
        transform=None,
        num_folds=n_splits,
        num_epochs=num_epochs,
        batch_size=batch_size,
        save_root=save_model_root,
        triplet_embedding_dict=triplet_embedding_dict,
        triplet_loss_weight=triplet_loss_weight
    )

    print("\n✅ K-fold cross-validation and graph generation complete.")

--- Loading and Processing Training Data ---
✅ Loaded 17500 pre-filtered images.
Total usable training images: 17500
Using random_state = 32395 for this k-fold trial

℀ Fold 1/1
Train set size: 12250 | Val set size: 5250
Fold 1 Classes: ['Black', 'East Asian', 'Indian', 'Latino_Hispanic', 'Middle Eastern', 'Southeast Asian', 'White'] → Remapped to: [0, 1, 2, 3, 4, 5, 6]
ðŸ’‰ Fold 1 Class Distribution (Train): {6: 1750, 1: 1750, 5: 1750, 2: 1750, 4: 1750, 3: 1750, 0: 1750}
ðŸ’‰ Fold 1 Class Distribution (Val): {4: 750, 6: 750, 3: 750, 0: 750, 2: 750, 1: 750, 5: 750}
Pre-processing and caching dataset metadata...


100%|██████████| 12250/12250 [00:00<00:00, 105976.63it/s]


Pre-processing and caching dataset metadata...


100%|██████████| 5250/5250 [00:00<00:00, 101937.81it/s]


Train dataset size: 12250
Val dataset size: 5250

Training: EFFICIENTNET_B0_CBAM — Fold 1
Fusion hidden dim: 128
Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Feature dimension (C): 1280
✅ Optimizing 4 new parameter groups with high LR (0.001).
No valid checkpoint found. Starting training from scratch (Epoch 0).
🧊 Backbone frozen: 7 total blocks.
📅 Strategy: Unfreeze from the last block backward (high-level features first).
   - Starting from block #6 at epoch 5.
   - Unfreezing one block every 2 epoch(s).
   - A maximum of 7 blocks will be unfrozen.


Training efficientnet_b0:   0%|          | 0/50 [00:00<?, ?it/s]

📈 LR Increase: Set LR to 0.000333


Training efficientnet_b0:   2%|▏         | 1/50 [00:57<47:03, 57.63s/it, lr=2.0e-04, train_acc=0.4654, train_loss=1.8648, val_acc=0.5823, val_loss=1.4922]

📈 LR Increase: Set LR to 0.000667


Training efficientnet_b0:   4%|▍         | 2/50 [01:41<39:41, 49.62s/it, lr=4.0e-04, train_acc=0.5807, train_loss=1.2709, val_acc=0.6198, val_loss=1.3572]

📈 LR Increase: Set LR to 0.001000


Training efficientnet_b0:   8%|▊         | 4/50 [03:10<35:18, 46.06s/it, lr=8.0e-04, train_acc=0.6061, train_loss=1.1814, val_acc=0.6236, val_loss=1.3414]

Epoch 5: Block 6 params are already in optimizer. Skipping.


Training efficientnet_b0:  12%|█▏        | 6/50 [04:50<35:51, 48.89s/it, lr=1.0e-03, train_acc=0.6081, train_loss=1.8799, val_acc=0.6227, val_loss=1.5060]

Epoch 7: Block 5 params are already in optimizer. Skipping.


Training efficientnet_b0:  16%|█▌        | 8/50 [06:41<36:45, 52.51s/it, lr=1.0e-03, train_acc=0.6068, train_loss=1.8281, val_acc=0.6244, val_loss=1.4719]

🔥 Epoch 9: Unfreezing backbone block 4...
  -> Added 2 backbone params with low LR (1e-05).
  -> Added 2 new parameters to the optimizer.


Training efficientnet_b0:  20%|██        | 10/50 [08:32<36:02, 54.07s/it, lr=9.8e-04, train_acc=0.6122, train_loss=1.8325, val_acc=0.6309, val_loss=1.4808]

🔥 Epoch 11: Unfreezing backbone block 3...
  -> Added 1 backbone params with low LR (1e-05).
  -> Added 1 new parameters to the optimizer.


Training efficientnet_b0:  24%|██▍       | 12/50 [10:23<34:46, 54.90s/it, lr=9.6e-04, train_acc=0.6107, train_loss=1.8361, val_acc=0.6288, val_loss=1.4831]

🔥 Epoch 13: Unfreezing backbone block 2...
  -> Added 205 backbone params with low LR (1e-05).
  -> Added 205 new parameters to the optimizer.


Training efficientnet_b0:  28%|██▊       | 14/50 [12:45<37:32, 62.56s/it, lr=9.2e-04, train_acc=0.6146, train_loss=1.8478, val_acc=0.6293, val_loss=1.5225]

🔥 Epoch 15: Unfreezing backbone block 1...
  -> Added 2 backbone params with low LR (1e-05).
  -> Added 2 new parameters to the optimizer.


Training efficientnet_b0:  32%|███▏      | 16/50 [14:53<35:44, 63.07s/it, lr=8.8e-04, train_acc=0.6184, train_loss=1.8390, val_acc=0.6375, val_loss=1.4994]

🔥 Epoch 17: Unfreezing backbone block 0...
  -> Added 1 backbone params with low LR (1e-05).
  -> Added 1 new parameters to the optimizer.


Training efficientnet_b0: 100%|██████████| 50/50 [50:55<00:00, 61.12s/it, lr=2.2e-06, train_acc=0.6371, train_loss=1.8195, val_acc=0.6455, val_loss=1.4661]



✅ Saved best model weights to: /content/drive/MyDrive/personal_research_project/results_dir/fold_1_efficientnet_b0_cbam/weights/fold_1_efficientnet_b0_cbam_best.pth
Skipping Grad-CAM for this run.

--- Evaluating model: efficientnet_b0_cbam_fold1 ---
✅ t-SNE plot saved to: /content/drive/MyDrive/personal_research_project/results_dir/fold_1_efficientnet_b0_cbam/graphs/tsne_efficientnet_b0_cbam_fold1.png

Accuracy: 64.67%

Classification Report:
                  precision    recall  f1-score   support

          Black       0.83      0.81      0.82       750
     East Asian       0.70      0.70      0.70       750
         Indian       0.64      0.68      0.66       750
Latino_Hispanic       0.51      0.44      0.48       750
 Middle Eastern       0.56      0.60      0.58       750
Southeast Asian       0.59      0.62      0.61       750
          White       0.70      0.67      0.69       750

       accuracy                           0.65      5250
      macro avg       0.65      0.6

model.safetensors:   0%|          | 0.00/95.5M [00:00<?, ?B/s]

✅ Optimizing 4 new parameter groups with high LR (0.001).
No valid checkpoint found. Starting training from scratch (Epoch 0).
🧊 Backbone frozen: 21 total blocks.
📅 Strategy: Unfreeze from the last block backward (high-level features first).
   - Starting from block #20 at epoch 5.
   - Unfreezing one block every 2 epoch(s).
   - A maximum of 21 blocks will be unfrozen.


Training InceptionV3:   0%|          | 0/50 [00:00<?, ?it/s]

📈 LR Increase: Set LR to 0.000333


Training InceptionV3:   2%|▏         | 1/50 [00:45<36:53, 45.18s/it, lr=2.0e-04, train_acc=0.4064, train_loss=1.9011, val_acc=0.5981, val_loss=1.5733]

📈 LR Increase: Set LR to 0.000667


Training InceptionV3:   4%|▍         | 2/50 [01:30<36:03, 45.08s/it, lr=4.0e-04, train_acc=0.5533, train_loss=1.3622, val_acc=0.6192, val_loss=1.3744]

📈 LR Increase: Set LR to 0.001000


Training InceptionV3:   8%|▊         | 4/50 [02:59<34:23, 44.85s/it, lr=8.0e-04, train_acc=0.5795, train_loss=1.2577, val_acc=0.6270, val_loss=1.3532]

Epoch 5: Block 20 params are already in optimizer. Skipping.


Training InceptionV3:  12%|█▏        | 6/50 [04:40<35:43, 48.71s/it, lr=1.0e-03, train_acc=0.5676, train_loss=1.9164, val_acc=0.6232, val_loss=1.5498]

Epoch 7: Block 19 params are already in optimizer. Skipping.


Training InceptionV3:  16%|█▌        | 8/50 [06:33<36:58, 52.82s/it, lr=1.0e-03, train_acc=0.5718, train_loss=1.8564, val_acc=0.6238, val_loss=1.5300]

Epoch 9: Block 18 params are already in optimizer. Skipping.


Training InceptionV3:  20%|██        | 10/50 [08:26<36:26, 54.67s/it, lr=9.8e-04, train_acc=0.5776, train_loss=1.8533, val_acc=0.6270, val_loss=1.5064]

🔥 Epoch 11: Unfreezing backbone block 17...
  -> Added 27 backbone params with low LR (1e-05).
  -> Added 27 new parameters to the optimizer.


Training InceptionV3:  24%|██▍       | 12/50 [10:20<35:28, 56.01s/it, lr=9.6e-04, train_acc=0.5816, train_loss=1.8619, val_acc=0.6288, val_loss=1.5162]

🔥 Epoch 13: Unfreezing backbone block 16...
  -> Added 27 backbone params with low LR (1e-05).
  -> Added 27 new parameters to the optimizer.


Training InceptionV3:  28%|██▊       | 14/50 [12:16<34:11, 57.00s/it, lr=9.2e-04, train_acc=0.5724, train_loss=1.8505, val_acc=0.6312, val_loss=1.5139]

🔥 Epoch 15: Unfreezing backbone block 15...
  -> Added 18 backbone params with low LR (1e-05).
  -> Added 18 new parameters to the optimizer.


Training InceptionV3:  32%|███▏      | 16/50 [14:13<32:43, 57.74s/it, lr=8.8e-04, train_acc=0.5756, train_loss=1.8533, val_acc=0.6278, val_loss=1.5213]

🔥 Epoch 17: Unfreezing backbone block 14...
  -> Added 30 backbone params with low LR (1e-05).
  -> Added 30 new parameters to the optimizer.


Training InceptionV3:  36%|███▌      | 18/50 [16:11<31:07, 58.36s/it, lr=8.3e-04, train_acc=0.5800, train_loss=1.8445, val_acc=0.6337, val_loss=1.5003]

🔥 Epoch 19: Unfreezing backbone block 13...
  -> Added 30 backbone params with low LR (1e-05).
  -> Added 30 new parameters to the optimizer.


Training InceptionV3:  40%|████      | 20/50 [18:10<29:31, 59.04s/it, lr=7.8e-04, train_acc=0.5808, train_loss=1.8830, val_acc=0.6280, val_loss=1.5344]

🔥 Epoch 21: Unfreezing backbone block 12...
  -> Added 30 backbone params with low LR (1e-05).
  -> Added 30 new parameters to the optimizer.


Training InceptionV3:  44%|████▍     | 22/50 [20:11<27:50, 59.68s/it, lr=7.2e-04, train_acc=0.5860, train_loss=1.8370, val_acc=0.6305, val_loss=1.4987]

🔥 Epoch 23: Unfreezing backbone block 11...
  -> Added 30 backbone params with low LR (1e-05).
  -> Added 30 new parameters to the optimizer.


Training InceptionV3:  48%|████▊     | 24/50 [22:13<26:08, 60.34s/it, lr=6.5e-04, train_acc=0.5856, train_loss=1.8435, val_acc=0.6276, val_loss=1.4850]

🔥 Epoch 25: Unfreezing backbone block 10...
  -> Added 12 backbone params with low LR (1e-05).
  -> Added 12 new parameters to the optimizer.


Training InceptionV3:  52%|█████▏    | 26/50 [24:16<24:25, 61.06s/it, lr=5.9e-04, train_acc=0.5851, train_loss=1.8420, val_acc=0.6301, val_loss=1.4950]

🔥 Epoch 27: Unfreezing backbone block 9...
  -> Added 21 backbone params with low LR (1e-05).
  -> Added 21 new parameters to the optimizer.


Training InceptionV3:  56%|█████▌    | 28/50 [26:20<22:35, 61.61s/it, lr=5.2e-04, train_acc=0.5840, train_loss=1.8526, val_acc=0.6320, val_loss=1.5084]

🔥 Epoch 29: Unfreezing backbone block 8...
  -> Added 21 backbone params with low LR (1e-05).
  -> Added 21 new parameters to the optimizer.


Training InceptionV3:  60%|██████    | 30/50 [28:26<20:42, 62.12s/it, lr=4.5e-04, train_acc=0.5884, train_loss=1.8607, val_acc=0.6316, val_loss=1.5213]

🔥 Epoch 31: Unfreezing backbone block 7...
  -> Added 21 backbone params with low LR (1e-05).
  -> Added 21 new parameters to the optimizer.


Training InceptionV3:  64%|██████▍   | 32/50 [30:33<18:50, 62.79s/it, lr=3.8e-04, train_acc=0.5871, train_loss=1.8385, val_acc=0.6310, val_loss=1.5062]

Epoch 33: Block 6 params are already in optimizer. Skipping.


Training InceptionV3:  68%|██████▊   | 34/50 [32:40<16:52, 63.29s/it, lr=3.1e-04, train_acc=0.5968, train_loss=1.8476, val_acc=0.6352, val_loss=1.5042]

🔥 Epoch 35: Unfreezing backbone block 5...
  -> Added 3 backbone params with low LR (1e-05).
  -> Added 3 new parameters to the optimizer.


Training InceptionV3:  72%|███████▏  | 36/50 [34:48<14:52, 63.73s/it, lr=2.5e-04, train_acc=0.5922, train_loss=1.8373, val_acc=0.6324, val_loss=1.4968]

🔥 Epoch 37: Unfreezing backbone block 4...
  -> Added 3 backbone params with low LR (1e-05).
  -> Added 3 new parameters to the optimizer.


Training InceptionV3:  76%|███████▌  | 38/50 [36:58<12:50, 64.18s/it, lr=1.9e-04, train_acc=0.5913, train_loss=1.8378, val_acc=0.6345, val_loss=1.5053]

Epoch 39: Block 3 params are already in optimizer. Skipping.


Training InceptionV3:  80%|████████  | 40/50 [39:07<10:43, 64.34s/it, lr=1.4e-04, train_acc=0.5922, train_loss=1.8475, val_acc=0.6341, val_loss=1.4990]

🔥 Epoch 41: Unfreezing backbone block 2...
  -> Added 3 backbone params with low LR (1e-05).
  -> Added 3 new parameters to the optimizer.


Training InceptionV3:  84%|████████▍ | 42/50 [41:17<08:39, 64.91s/it, lr=9.6e-05, train_acc=0.5919, train_loss=1.8372, val_acc=0.6343, val_loss=1.4949]

🔥 Epoch 43: Unfreezing backbone block 1...
  -> Added 3 backbone params with low LR (1e-05).
  -> Added 3 new parameters to the optimizer.


Training InceptionV3:  88%|████████▊ | 44/50 [43:29<06:32, 65.42s/it, lr=5.9e-05, train_acc=0.5934, train_loss=1.8575, val_acc=0.6337, val_loss=1.5016]

🔥 Epoch 45: Unfreezing backbone block 0...
  -> Added 3 backbone params with low LR (1e-05).
  -> Added 3 new parameters to the optimizer.


Training InceptionV3: 100%|██████████| 50/50 [50:08<00:00, 60.16s/it, lr=2.2e-06, train_acc=0.5966, train_loss=1.8283, val_acc=0.6337, val_loss=1.5014]



✅ Saved best model weights to: /content/drive/MyDrive/personal_research_project/results_dir/fold_1_InceptionV3_cbam/weights/fold_1_InceptionV3_cbam_best.pth
Skipping Grad-CAM for this run.

--- Evaluating model: InceptionV3_cbam_fold1 ---
✅ t-SNE plot saved to: /content/drive/MyDrive/personal_research_project/results_dir/fold_1_InceptionV3_cbam/graphs/tsne_InceptionV3_cbam_fold1.png

Accuracy: 63.50%

Classification Report:
                  precision    recall  f1-score   support

          Black       0.82      0.80      0.81       750
     East Asian       0.69      0.69      0.69       750
         Indian       0.63      0.65      0.64       750
Latino_Hispanic       0.49      0.46      0.47       750
 Middle Eastern       0.55      0.58      0.56       750
Southeast Asian       0.57      0.63      0.60       750
          White       0.72      0.63      0.67       750

       accuracy                           0.64      5250
      macro avg       0.64      0.64      0.64      525

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

✅ Optimizing 4 new parameter groups with high LR (0.001).
No valid checkpoint found. Starting training from scratch (Epoch 0).
🧊 Backbone frozen: 10 total blocks.
📅 Strategy: Unfreeze from the last block backward (high-level features first).
   - Starting from block #9 at epoch 5.
   - Unfreezing one block every 2 epoch(s).
   - A maximum of 10 blocks will be unfrozen.


Training resnet152d:   0%|          | 0/50 [00:00<?, ?it/s]

📈 LR Increase: Set LR to 0.000333


Training resnet152d:   2%|▏         | 1/50 [00:48<39:25, 48.28s/it, lr=2.0e-04, train_acc=0.4704, train_loss=1.7034, val_acc=0.5859, val_loss=1.4361]

📈 LR Increase: Set LR to 0.000667


Training resnet152d:   4%|▍         | 2/50 [01:36<38:47, 48.49s/it, lr=4.0e-04, train_acc=0.5819, train_loss=1.2506, val_acc=0.6187, val_loss=1.3540]

📈 LR Increase: Set LR to 0.001000


Training resnet152d:   8%|▊         | 4/50 [03:13<36:58, 48.24s/it, lr=8.0e-04, train_acc=0.6054, train_loss=1.1868, val_acc=0.6251, val_loss=1.3396]

Epoch 5: Block 9 params are already in optimizer. Skipping.


Training resnet152d:  12%|█▏        | 6/50 [05:00<38:06, 51.97s/it, lr=1.0e-03, train_acc=0.6029, train_loss=1.9090, val_acc=0.6261, val_loss=1.5114]

Epoch 7: Block 8 params are already in optimizer. Skipping.


Training resnet152d:  16%|█▌        | 8/50 [06:58<39:06, 55.86s/it, lr=1.0e-03, train_acc=0.6106, train_loss=1.8231, val_acc=0.6293, val_loss=1.4886]

🔥 Epoch 9: Unfreezing backbone block 7...
  -> Added 30 backbone params with low LR (1e-05).
  -> Added 30 new parameters to the optimizer.


Training resnet152d:  20%|██        | 10/50 [08:59<38:45, 58.14s/it, lr=9.8e-04, train_acc=0.6113, train_loss=1.8157, val_acc=0.6286, val_loss=1.4993]

🔥 Epoch 11: Unfreezing backbone block 6...
  -> Added 327 backbone params with low LR (1e-05).
  -> Added 327 new parameters to the optimizer.


Training resnet152d:  24%|██▍       | 12/50 [11:16<40:18, 63.65s/it, lr=9.6e-04, train_acc=0.6102, train_loss=1.8420, val_acc=0.6293, val_loss=1.4894]

🔥 Epoch 13: Unfreezing backbone block 5...
  -> Added 75 backbone params with low LR (1e-05).
  -> Added 75 new parameters to the optimizer.


Training resnet152d:  28%|██▊       | 14/50 [13:40<40:43, 67.87s/it, lr=9.2e-04, train_acc=0.6136, train_loss=1.8473, val_acc=0.6320, val_loss=1.5021]

🔥 Epoch 15: Unfreezing backbone block 4...
  -> Added 30 backbone params with low LR (1e-05).
  -> Added 30 new parameters to the optimizer.


Training resnet152d:  30%|███       | 15/50 [14:54<40:42, 69.78s/it, lr=9.0e-04, train_acc=0.6130, train_loss=1.8294, val_acc=0.6360, val_loss=1.5098]