# Sigmoid vs Contrastive Loss

**Recommendation:** Use **SIGMOID loss** for your use case. Here's why:

| Aspect            | Sigmoid Loss                                   | Contrastive Loss                                  |
|-------------------|------------------------------------------------|----------------------------------------------------|
| **Best for**      | Multi-label (multiple conditions per image)    | Single-label classification                        |
| **Your data**     | ‚úÖ 20%+ multi-label samples                     | ‚ùå Treats each sample as single class              |
| **Output**        | Independent probabilities per class            | Probability distribution (sums to 1)               |
| **Training**      | Each label trained independently               | Contrasts positive vs negative pairs               |
| **Class imbalance** | ‚úÖ Handles better with per-class thresholds  | ‚ùå Can struggle with imbalanced data               |


In [None]:
#!/usr/bin/env python
"""
Train SigLIP.
Uses the 66 specific condition labels with their descriptions for more precise matching.

This approach should provide:
- Better discrimination between specific conditions
- More precise retrieval accuracy
- Direct condition-level matching without hierarchical complexity
"""
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from PIL import Image
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from tqdm import tqdm
import json
from peft import LoraConfig, get_peft_model
from datetime import datetime
from collections import defaultdict, Counter
import random

# Get the current timestamp
now = datetime.now()

timestamp_str = now.strftime("%Y%m%d%H%M")

from transformers import (
    AutoModel,
    AutoProcessor,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
)
from torch.utils.data import Dataset

@dataclass
class Config:
    # Model
    model_name: str = 'google/siglip-base-patch16-224'
    
    # Training
    batch_size: int = 32
    num_epochs: int = 15
    learning_rate: float = 1e-5 # 4e-5 seems like a good sweet spot.
    weight_decay: float = 0.01  #Was 5e-02 or 0.05. High decay prevents overfitting.
    warmup_steps: int = 100
    
    #Lora parameters
    lora_rank: int = 32
    lora_alpha: int = 32

    # Device
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    fp16: bool = torch.cuda.is_available()
    temperature: float = 0.7
    
    # Set this to "contrastive" or "sigmoid"
    loss_type: str = "sigmoid" # or "contrastive"
    
    # finetune_mode. # Options: "FULL", "LoRA"
    finetune_mode: str = "LoRA"  
    
    # Data paths
    log_dir: Path = Path('./runs')
    data_dir: Path = Path('/home/agentic-health/data')
    visual_descriptions_path: Path ='./data/condition_visual_descriptions.json'
    metadata_file: Path = './data/coarse_labeled_metadata_with_labels.csv'
    images_dir: Path = data_dir / 'scin/images'
    ICD_DISSIMILARITY_MATRIX: Path = Path('./data/icd_dissimilarity_matrix.pt')
    CONDITION_IDX_MAP: Path = Path('./data/condition_idx_map.json')
    output_dir: Path = Path(f'./siglip_{finetune_mode.lower()}_{loss_type.lower()}_{timestamp_str}')

config = Config()
config.output_dir.mkdir(parents=True, exist_ok=True)

class DualLogger:
    def __init__(self, filepath):
        self.terminal = sys.stdout
        self.log = open(filepath, "w")  # Overwrite file each run
        # Store original stdout before any redirection
        self._original_stdout = sys.__stdout__

    def write(self, message):
        # Write to Jupyter's display system
        self._original_stdout.write(message)
        self._original_stdout.flush()
        
        # Also write to file
        self.log.write(message)
        self.log.flush()

    def flush(self):
        self._original_stdout.flush()
        self.log.flush()
        
log_file_path = config.output_dir / "training_log.txt"
sys.stdout = DualLogger(log_file_path)

print("="*80)
print(f"SIGLIP {config.loss_type} LEARNING USLING {config.finetune_mode}")
print("="*80)
print(f"\nConfiguration:")
print(f"  Model: {config.model_name}")
print(f"  Output: {config.output_dir}")
print(f"  Device: {config.device}")
print(f"  Batch size: {config.batch_size}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Weight decay: {config.weight_decay}")
print(f"  Epochs: {config.num_epochs}")
print(f"  Lora_rank: {config.lora_rank}")   
print(f"  Output directory: {config.output_dir}")

# Load data
print("\n" + "="*80)
print("Loading Data")
print("="*80)

df = pd.read_csv(config.metadata_file)
print(f"Total real samples in {config.metadata_file}: {len(df)}")
print(f"conditions: {df['condition'].nunique()}")
print(f"Coarse categories: {df['coarse_category'].nunique()}")

  from .autonotebook import tqdm as notebook_tqdm


SIGLIP sigmoid LEARNING USLING LoRA

Configuration:
  Model: google/siglip-base-patch16-224
  Output: siglip_lora_sigmoid_202511221028
  Device: cuda
  Batch size: 32
  Learning rate: 1e-05
  Weight decay: 0.01
  Epochs: 15
  Lora_rank: 32
  Output directory: siglip_lora_sigmoid_202511221028

Loading Data
Total real samples in ./data/coarse_labeled_metadata_with_labels.csv: 5909
conditions: 66
Coarse categories: 16

Loading and Processing Data
Raw data loaded: 5909 samples

Balancing training data (target min: 30)...
  Added 45 samples via oversampling.
Training set size before adding synthetic data: 4147
Loaded metadata for 1354 existing synthetic samples.
Checking 32 rare classes (Target: 30)...
No new images generated.
New training set size after adding synthetic data: 5501
 DATASET ANALYSIS FOR SIGLIP FINE-TUNING

1. OVERALL DATASET STATISTICS
--------------------------------------------------------------------------------
Total samples: 5,909
  - Training:   5,501 (93.1%)
  - Vali

In [2]:
def get_labels_from_item(item):
    """
        Extract label from a dataset item (dictionary row).
        We prefer the 'description' column if available, otherwise template the condition
    """   
    if pd.notna(item['description']):
        text = item['description']
    else:
        # Fallback template if description is missing
        text = f"A dermatological photo showing {item['condition']}"
    
    return text

def stratified_split_data(all_data, val_ratio=0.15, test_ratio=0.15):
    """
    Splits data while ensuring rare classes are distributed across all sets.
    Strategy: Assigns rare classes first to ensure Val/Test coverage.
    """
    train_data, val_data, test_data = [], [], []
    
    # 1. Map each class to its samples
    class_to_samples = defaultdict(list)
    sample_to_id = {} # Track assigned samples
    
    for idx, item in enumerate(all_data):
        labels = get_labels_from_item(item)
        sample_to_id[idx] = False # False = not assigned
        for label in labels:
            class_to_samples[label].append(idx)
            
    # 2. Sort classes by rarity (rarest first)
    sorted_classes = sorted(class_to_samples.keys(), key=lambda k: len(class_to_samples[k]))
    
    for cls in sorted_classes:
        indices = class_to_samples[cls]
        # Filter for unassigned indices
        available_indices = [i for i in indices if not sample_to_id[i]]
        random.shuffle(available_indices)
        
        n = len(available_indices)
        # Ensure at least 1 sample in Val/Test if possible
        n_val = max(1, int(n * val_ratio)) if n > 1 else 0
        n_test = max(1, int(n * test_ratio)) if n > 2 else 0
        
        # Allocate indices
        val_idxs = available_indices[:n_val]
        test_idxs = available_indices[n_val:n_val+n_test]
        train_idxs = available_indices[n_val+n_test:]
        
        for i in val_idxs:
            if not sample_to_id[i]:
                val_data.append(all_data[i])
                sample_to_id[i] = True
        for i in test_idxs:
            if not sample_to_id[i]:
                test_data.append(all_data[i])
                sample_to_id[i] = True
        for i in train_idxs:
            if not sample_to_id[i]:
                train_data.append(all_data[i])
                sample_to_id[i] = True

    # 3. Assign any leftovers
    leftovers = [all_data[i] for i in range(len(all_data)) if not sample_to_id[i]]
    for item in leftovers:
        r = random.random()
        if r < val_ratio: val_data.append(item)
        elif r < val_ratio + test_ratio: test_data.append(item)
        else: train_data.append(item)

    return train_data, val_data, test_data

def balance_training_data(train_data, target_min_samples=30):
    """
    Oversamples rare classes in the training set to reach a minimum count.
    """
    print(f"\nBalancing training data (target min: {target_min_samples})...")
    class_counts = Counter()
    class_samples = defaultdict(list)
    
    for item in train_data:
        labels = get_labels_from_item(item)
        for label in labels:
            class_counts[label] += 1
            class_samples[label].append(item)
            
    balanced_data = list(train_data)
    added_count = 0
    
    for cls, count in class_counts.items():
        if count < target_min_samples:
            needed = target_min_samples - count
            available = class_samples[cls]
            if not available: continue
            
            # Resample with replacement
            extras = random.choices(available, k=needed)
            balanced_data.extend(extras)
            added_count += needed

    random.shuffle(balanced_data)
    print(f"  Added {added_count} samples via oversampling.")
    return balanced_data

In [3]:
import numpy as np
import random
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
import uuid
import pandas as pd
from pathlib import Path
import os

def generate_synthetic_data_pil(image, num_samples=5, use_edge_enhance=False):
    """
    Creates synthetic medical images using Pillow (PIL).
    
    Args:
        image (PIL.Image): Input image.
        num_samples (int): Number of synthetic versions to create.
        use_edge_enhance (bool): If True, applies edge enhancement (like Sobel).
        
    Returns:
        list: A list of augmented PIL images.
    """
    synthetic_images = []
    w, h = image.size

    for _ in range(num_samples):
        aug_img = image.copy()
        
        # --- A. GEOMETRIC TRANSFORMS ---
        
        # 1. Random Flip
        if random.random() > 0.5:
            # Mirror (Horizontal)
            if random.random() > 0.5:
                aug_img = ImageOps.mirror(aug_img)
            # Flip (Vertical) - Valid for skin lesions
            else:
                aug_img = ImageOps.flip(aug_img)

        # 2. Random Rotation (-30 to 30 degrees)
        if random.random() > 0.5:
            angle = random.uniform(-30, 30)
            aug_img = aug_img.rotate(angle, resample=Image.BICUBIC, expand=False)

        # 3. Random Resized Crop (Zoom effect)
        if random.random() > 0.3:
            scale = random.uniform(0.75, 0.95)
            new_w, new_h = int(w * scale), int(h * scale)
            
            left = random.randint(0, w - new_w)
            top = random.randint(0, h - new_h)
            
            aug_img = aug_img.crop((left, top, left + new_w, top + new_h))
            aug_img = aug_img.resize((w, h), resample=Image.BICUBIC)

        # --- B. PIXEL TRANSFORMS ---

        # 4. Gaussian Blur
        if random.random() > 0.7:
            radius = random.uniform(1, 2)
            aug_img = aug_img.filter(ImageFilter.GaussianBlur(radius))

        # 5. Gaussian Noise
        if random.random() > 0.7:
            # Convert to numpy to add noise
            img_arr = np.array(aug_img).astype(float)
            noise = np.random.normal(0, 15, img_arr.shape)
            img_arr = np.clip(img_arr + noise, 0, 255).astype('uint8')
            aug_img = Image.fromarray(img_arr)

        # 6. Color Jitter (Brightness/Contrast/Saturation)
        if random.random() > 0.5:
            # Brightness
            enhancer = ImageEnhance.Brightness(aug_img)
            aug_img = enhancer.enhance(random.uniform(0.8, 1.2))
            
            # Contrast
            enhancer = ImageEnhance.Contrast(aug_img)
            aug_img = enhancer.enhance(random.uniform(0.9, 1.1))
            
            # Saturation
            enhancer = ImageEnhance.Color(aug_img)
            aug_img = enhancer.enhance(random.uniform(0.9, 1.1))

        # --- C. STRUCTURAL TRANSFORMS ---
        
        # 7. Edge Enhancement (Alternative to Sobel)
        if use_edge_enhance and random.random() > 0.8:
            # Convert to grayscale and find edges
            gray = aug_img.convert("L").filter(ImageFilter.FIND_EDGES)
            # Convert back to RGB
            edge_img = gray.convert("RGB")
            # Blend with original (50%)
            aug_img = Image.blend(aug_img, edge_img, alpha=0.5)

        synthetic_images.append(aug_img)

    return synthetic_images

In [4]:
import pandas as pd
from pathlib import Path
import uuid
from PIL import Image

def augment_rare_classes(df, images_dir, output_dir, target_count=30):
    """
    Identifies rare classes and generates synthetic images to reach target_count.
    Skips generation if sufficient synthetic data already exists in output_dir.
    
    Args:
        df (pd.DataFrame): Original training dataframe.
        images_dir (str/Path): Directory containing original images.
        output_dir (str/Path): Directory to save synthetic images and metadata.
        target_count (int): Minimum samples desired per class.
        
    Returns:
        pd.DataFrame: A DataFrame containing ALL synthetic data (previously existing + newly generated).
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # 1. Load existing synthetic metadata if it exists
    metadata_path = output_path / "synthetic_metadata.csv"
    if metadata_path.exists():
        try:
            existing_syn_df = pd.read_csv(metadata_path)
            print(f"Loaded metadata for {len(existing_syn_df)} existing synthetic samples.")
        except Exception as e:
            print(f"Warning: Could not read existing metadata ({e}). Starting fresh.")
            existing_syn_df = pd.DataFrame()
    else:
        existing_syn_df = pd.DataFrame()

    # Identify rare classes based on ORIGINAL data
    counts = df['condition'].value_counts()
    rare_classes = counts[counts < target_count].index.tolist()
    
    new_rows = []
    
    print(f"Checking {len(rare_classes)} rare classes (Target: {target_count})...")

    for condition in rare_classes:
        # Count original samples
        original_count = len(df[df['condition'] == condition])
        
        # Count existing synthetic samples (if any)
        if not existing_syn_df.empty:
            syn_count = len(existing_syn_df[existing_syn_df['condition'] == condition])
        else:
            syn_count = 0
            
        # Calculate how many MORE we need
        total_current = original_count + syn_count
        needed = target_count - total_current
        
        if needed <= 0:
            # print(f"  ‚úì {condition}: Has {total_current} (Skipping)")
            continue
            
        print(f"  - {condition}: Generating {needed} new images (Original: {original_count}, Existing Syn: {syn_count})...")
        
        # Get existing samples to augment
        existing_samples = df[df['condition'] == condition]
        if existing_samples.empty: continue

        generated_count = 0
        while generated_count < needed:
            # Pick random source image
            source_row = existing_samples.sample(1).iloc[0]
            
            img_path = Path(images_dir) / source_row['image_path']
            try:
                original_img = Image.open(img_path).convert("RGB")
            except Exception as e:
                continue

            # Generate synthetic version
            # Assuming generate_synthetic_data_pil is defined as before
            synthetic_imgs = generate_synthetic_data_pil(original_img, num_samples=1)
            syn_img = synthetic_imgs[0]
            
            # --- METADATA ---
            new_id = str(uuid.uuid4())
            
            # Sanitize filename
            safe_condition_name = condition.replace(' ', '_').replace('/', '_').replace('\\', '_')
            new_filename = f"syn_{safe_condition_name}_{new_id[:8]}.png"
            save_path = output_path / new_filename
            
            # Save Image
            syn_img.save(save_path)
            
            # Create Metadata Row
            new_row = source_row.copy()
            new_row['image_id'] = new_id
            new_row['image_path'] = new_filename # Just filename, relative to output_dir
            new_row['split'] = 'train'
            
            # Add helper flag (optional)
            new_row['is_synthetic'] = True
            
            new_rows.append(new_row)
            generated_count += 1

    # 2. Combine and Save
    if new_rows:
        new_df = pd.DataFrame(new_rows)
        if not existing_syn_df.empty:
            # Align columns before concat to prevent errors
            # (Optional safety step if schemas drift)
            final_syn_df = pd.concat([existing_syn_df, new_df], ignore_index=True)
        else:
            final_syn_df = new_df
            
        # Save updated metadata
        final_syn_df.to_csv(metadata_path, index=False)
        print(f"Saved updated metadata with {len(new_rows)} new samples to {metadata_path}")
        return final_syn_df
    else:
        print("No new images generated.")
        return existing_syn_df

In [5]:
# ==============================================================================
#  IMPROVED DATA LOADING SECTION (Replaces original loading logic)
# ==============================================================================
print("\n" + "="*80)
print("Loading and Processing Data")
print("="*80)

# 1. Load raw CSV
df = pd.read_csv(config.metadata_file)
print(f"Raw data loaded: {len(df)} samples")

# 2. Convert to list of dictionaries for processing
all_data = df.to_dict('records')

# 3. Perform Stratified Split (Fixes missing classes in Val/Test)
train_raw, val_data, test_data = stratified_split_data(
    all_data, 
    val_ratio=0.15, 
    test_ratio=0.15
)

# 4. Balance Training Data (Fixes class imbalance)
# Target 30 samples ensures the model sees rare classes ~1 per batch on average
train_balanced = balance_training_data(train_raw, target_min_samples=30)

# 5. Convert back to DataFrames for compatibility
train_df = pd.DataFrame(train_balanced)
print(f"Training set size before adding synthetic data: {len(train_df)}")
synthetic_dir = "./scin_synthetic"
synthetic_df = augment_rare_classes(
    train_df, 
    config.images_dir, 
    synthetic_dir, 
    target_count=30
)

# Merge
train_df = pd.concat([train_df, synthetic_df], ignore_index=True)
print(f"New training set size after adding synthetic data: {len(train_df)}")

val_df = pd.DataFrame(val_data)
test_df = pd.DataFrame(test_data)

In [6]:
import pandas as pd
import numpy as np
from collections import Counter

def print_comprehensive_report(df, train_df, val_df, test_df):
    """
    Print a comprehensive data analysis report for SigLip fine-tuning
    """
    
    print("="*80)
    print(" DATASET ANALYSIS FOR SIGLIP FINE-TUNING")
    print("="*80)
    
    # Overall statistics
    print("\n1. OVERALL DATASET STATISTICS")
    print("-" * 80)
    print(f"Total samples: {len(df):,}")
    print(f"  - Training:   {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
    print(f"  - Validation: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
    print(f"  - Test:       {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")
    print(f"\nUnique conditions: {df['condition'].nunique()}")
    print(f"Unique coarse categories: {df['coarse_category'].nunique()}")
    
    # Label distribution by split
    print("\n2. LABEL DISTRIBUTION BY SPLIT")
    print("-" * 80)
    
    # Get all unique conditions
    all_conditions = df['condition'].unique()
    
    # Create distribution table
    distribution_data = []
    for condition in sorted(all_conditions):
        train_count = len(train_df[train_df['condition'] == condition])
        val_count = len(val_df[val_df['condition'] == condition])
        test_count = len(test_df[test_df['condition'] == condition])
        total = train_count + val_count + test_count
        
        distribution_data.append({
            'Condition': condition,
            'Train': train_count,
            'Val': val_count,
            'Test': test_count,
            'Total': total
        })
    
    dist_df = pd.DataFrame(distribution_data).sort_values('Total', ascending=False)
    
    # Print top 15 most common conditions
    print("\nTop 15 Most Common Conditions:")
    print(f"{'Condition':<40} {'Train':>8} {'Val':>8} {'Test':>8} {'Total':>8}")
    print("-" * 80)
    for _, row in dist_df.head(15).iterrows():
        print(f"{row['Condition']:<40} {row['Train']:>8} {row['Val']:>8} {row['Test']:>8} {row['Total']:>8}")
    
    # Print conditions with imbalanced splits (potential issues)
    print("\n3. POTENTIAL TRAINING ISSUES")
    print("-" * 80)
    
    # Classes with very few samples
    rare_classes = dist_df[dist_df['Total'] < 10]
    print(f"\n‚ö†Ô∏è  Conditions with <10 total samples: {len(rare_classes)}")
    if len(rare_classes) > 0:
        print("(These may cause overfitting or poor generalization)")
        for _, row in rare_classes.iterrows():
            print(f"  - {row['Condition']:<40} (Total: {row['Total']})")
    
    # Classes missing from validation or test
    missing_val = dist_df[dist_df['Val'] == 0]
    missing_test = dist_df[dist_df['Test'] == 0]
    
    print(f"\n‚ö†Ô∏è  Conditions missing from validation set: {len(missing_val)}")
    if len(missing_val) > 0 and len(missing_val) <= 10:
        for _, row in missing_val.iterrows():
            print(f"  - {row['Condition']}")
    
    print(f"\n‚ö†Ô∏è  Conditions missing from test set: {len(missing_test)}")
    if len(missing_test) > 0 and len(missing_test) <= 10:
        for _, row in missing_test.iterrows():
            print(f"  - {row['Condition']}")
    
    # Multi-label analysis
    print("\n4. MULTI-LABEL ANALYSIS")
    print("-" * 80)
    
    def count_labels(df_subset):
        """Count images by number of labels, safely evaluating string representations."""
        import ast
        label_counts = []
        
        for all_conds in df_subset['all_conditions']:
            num_labels = 1 # Default to 1 label if we encounter valid data but no list structure
            
            # 1. Check if the value is already a list/tuple (Path A)
            if isinstance(all_conds, (list, tuple)):
                num_labels = len(all_conds)
                
            # 2. Check if the value is a string (Path B - The source of the bug)
            elif isinstance(all_conds, str) and all_conds.strip():
                try:
                    # Safely convert the string representation of the list
                    conds = ast.literal_eval(all_conds)
                    
                    if isinstance(conds, (list, tuple)):
                        num_labels = len(conds)
                        
                except:
                    # If ast.literal_eval fails (e.g., malformed string), 
                    # we maintain the default assumption of 1 label.
                    pass
            
            # 3. Final Check: If filtering failed somewhere and we hit 0, revert to 1 (Adhering to previous constraints)
            if num_labels == 0:
                num_labels = 1
                
            label_counts.append(num_labels)
                
        return Counter(label_counts)
    
    train_label_dist = count_labels(train_df)
    val_label_dist = count_labels(val_df)
    test_label_dist = count_labels(test_df)
    
    print(f"{'Labels per Image':<20} {'Train':>10} {'Val':>10} {'Test':>10}")
    print("-" * 80)
    all_label_counts = sorted(set(list(train_label_dist.keys()) + list(val_label_dist.keys()) + list(test_label_dist.keys())))
    for num_labels in all_label_counts:
        print(f"{num_labels} label(s): {train_label_dist.get(num_labels, 0):>10} {val_label_dist.get(num_labels, 0):>10} {test_label_dist.get(num_labels, 0):>10}")
    
    # Calculate multi-label percentage
    multi_label_train = sum(count for labels, count in train_label_dist.items() if labels > 1)
    multi_label_pct = multi_label_train / len(train_df) * 100 if len(train_df) > 0 else 0
    print(f"\nMulti-label samples in training: {multi_label_train:,} ({multi_label_pct:.1f}%)")
    
    # Coarse category distribution
    print("\n5. COARSE CATEGORY DISTRIBUTION")
    print("-" * 80)
    print(f"{'Category':<40} {'Train':>8} {'Val':>8} {'Test':>8} {'Total':>8}")
    print("-" * 80)
    
    for category in sorted(df['coarse_category'].unique()):
        train_count = len(train_df[train_df['coarse_category'] == category])
        val_count = len(val_df[val_df['coarse_category'] == category])
        test_count = len(test_df[test_df['coarse_category'] == category])
        total = train_count + val_count + test_count
        print(f"{category:<40} {train_count:>8} {val_count:>8} {test_count:>8} {total:>8}")
    
    # Class imbalance metrics
    print("\n6. CLASS IMBALANCE METRICS")
    print("-" * 80)
    
    train_condition_counts = train_df['condition'].value_counts()
    max_samples = train_condition_counts.max()
    min_samples = train_condition_counts.min()
    imbalance_ratio = max_samples / min_samples if min_samples > 0 else float('inf')
    
    print(f"Training set class imbalance:")
    print(f"  - Most common class:  {train_condition_counts.index[0]:<40} ({max_samples} samples)")
    print(f"  - Least common class: {train_condition_counts.index[-1]:<40} ({min_samples} samples)")
    print(f"  - Imbalance ratio:    {imbalance_ratio:.1f}:1")
    print(f"\nMedian samples per class: {train_condition_counts.median():.0f}")
    print(f"Mean samples per class:   {train_condition_counts.mean():.1f}")
    
    # Recommendations
    print("\n7. RECOMMENDATIONS FOR SIGLIP FINE-TUNING")
    print("="*80)
    
    print("\nüìä Data Considerations:")
    if imbalance_ratio > 100:
        print("  ‚ö†Ô∏è  SEVERE class imbalance detected (>100:1)")
        print("     ‚Üí Consider: class weights, focal loss, or oversampling rare classes")
    elif imbalance_ratio > 20:
        print("  ‚ö†Ô∏è  Significant class imbalance (>20:1)")
        print("     ‚Üí Consider: balanced sampling or class weights")
    
    if multi_label_pct > 20:
        print(f"\n  ‚ÑπÔ∏è  High multi-label percentage ({multi_label_pct:.1f}%)")
        print("     ‚Üí Use SIGMOID loss (not contrastive) for multi-label classification")
    else:
        print(f"\n  ‚ÑπÔ∏è  Low multi-label percentage ({multi_label_pct:.1f}%)")
        print("     ‚Üí Either SIGMOID or CONTRASTIVE loss can work")
    
    if len(rare_classes) > 0:
        print(f"\n  ‚ö†Ô∏è  {len(rare_classes)} classes with <10 samples")
        print("     ‚Üí These may not learn well; consider grouping into coarse categories")
    
    print("\n" + "="*80)

# Call the report function
print_comprehensive_report(df, train_df, val_df, test_df)

In [7]:
import torch
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
import json # Ensure this is imported at the top of your file

class CreateDataset(Dataset):
    """
    Dataset for SigLIP fine-tuning, incorporating ICD indices and verbose descriptions.
    """

    def __init__(self, df: pd.DataFrame, images_dir: Path, processor, mode='train', 
                 visual_descriptions_path=None, icd_idx_map_path=None): # Accept map path
        
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.processor = processor
        self.mode = mode
        
        # --- 1. Load Visual Descriptions (Text Signal) ---
        self.visual_descriptions = {}
        if visual_descriptions_path and Path(visual_descriptions_path).exists():
            try:
                with open(visual_descriptions_path, 'r') as f:
                    self.visual_descriptions = json.load(f)
                print(f"[{mode}] Loaded {len(self.visual_descriptions)} visual descriptions.")
            except:
                pass

        # --- 2. Load ICD Index Map (Geometric Signal) ---
        self.condition_to_idx = {}
        if icd_idx_map_path and Path(icd_idx_map_path).exists():
            try:
                with open(icd_idx_map_path, 'r') as f:
                    self.condition_to_idx = json.load(f)
                print(f"[{mode}] Loaded {len(self.condition_to_idx)} ICD index mappings.")
            except:
                pass
    
    def _get_text_prompt(self, row):
        """Helper to generate the text prompt based on mode and description availability."""
        condition = row['condition']
        
        # Priority 1: AI-generated visual description
        visual_desc = self.visual_descriptions.get(condition, "")
        
        # Priority 2: Fallback to dataset's raw description
        raw_desc = str(row.get('description', '')).strip()
        if raw_desc.lower() == 'nan' or not raw_desc:
            raw_desc = ""

        if self.mode == 'train':
            # Training: Prioritize the strongest visual feature for the model to learn from
            if visual_desc:
                text = f"A high-resolution dermatological image of {condition}: {visual_desc}"
            elif raw_desc:
                 # Fallback template for generic raw descriptions
                text = np.random.choice([
                    f"Patient presenting with {condition}: {raw_desc}",
                    f"Clinical image of {condition}, described as: {raw_desc}"
                ])
            else:
                text = f"A dermatological photo of {condition}"
        else:
            # Testing/Evaluation Mode: Consistent format using the best available description
            if visual_desc:
                text = visual_desc
            elif raw_desc:
                text = f"{condition}: {raw_desc}"
            else:
                text = f"A photo of {condition}"

        return text.strip()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # --- Image/Text Loading (Omitted for brevity) ---
        image_path = self.images_dir / row['image_path']
        try:
            image = Image.open(image_path).convert('RGB')
        except:
            image = Image.new('RGB', (224, 224), color='black')

        text = self._get_text_prompt(row)
        
        # 1. Prepare and clean the condition name
        # FIX: Ensure all whitespace is removed from the beginning/end
        condition_name_raw = str(row['condition'])
        lookup_key = condition_name_raw.strip()
        
        # --- Lookup the numerical index ---
        # The key is now guaranteed to be clean.
        cond_idx = self.condition_to_idx.get(lookup_key, -1) 

        # --- DEBUG CHECK (For verification only) ---
        # The fact that Min/Max is -1, -1 means this check is still failing
        if cond_idx == -1 and self.condition_to_idx:
            print(f"FATAL LOOKUP ERROR: Looked for '{lookup_key}', but failed.")
            print(f"DEBUG: First Map Key: '{list(self.condition_to_idx.keys())[0]}'")

        # --- Processor Call ---
        inputs = self.processor(
            text=[text],
            images=image,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        )
        
        # 1. Remove batch dimension from all processor outputs
        for key in inputs:
            inputs[key] = inputs[key].squeeze(0)

        # --- CRITICAL FIX: ADD ICD ARTIFACT TO INPUTS DICTIONARY ---
        # The ICD Loss requires this index tensor.
        inputs["condition_indices"] = torch.tensor(cond_idx, dtype=torch.long)
        
        # 2. Return the complete dictionary (which now includes condition_indices)
        return inputs

In [None]:
class SigLipModel(nn.Module):
    """SigLIP model for contrastive learning."""

    def __init__(self, model_name: str, temperature: float = 0.07):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.temperature = temperature
        self.logit_bias = nn.Parameter(torch.zeros(1))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
        self.proj = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(768, 768) 
        )

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        image_feat = outputs.vision_model_output.pooler_output
        image_embeds = F.normalize(self.proj(image_feat), dim=-1)

        image_embeds = outputs.vision_model_output.pooler_output
        text_embeds = outputs.text_model_output.pooler_output

        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.t()
        logits_per_text = logits_per_image.t()

        batch_size = image_embeds.shape[0]
        
        # --- Switchable Loss Calculation ---
        if config.loss_type == "contrastive":
            labels = torch.arange(batch_size, device=image_embeds.device)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2
        elif config.loss_type == "sigmoid":
            labels = torch.eye(batch_size, device=image_embeds.device)
            loss_i = F.binary_cross_entropy_with_logits(logits_per_image, labels)
            loss_t = F.binary_cross_entropy_with_logits(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2.0
        else:
            raise ValueError(f"Unknown loss_type: {config.loss_type}")

        return {
            'loss': loss,
            'logits_per_image': logits_per_image,
            'logits_per_text': logits_per_text,
            'image_embeds': image_embeds,
            'text_embeds': text_embeds,
        }

In [9]:
print("\n" + "="*80)
print("Loading Model and Creating Datasets")
print("="*80)

# --- 1. Load Processor ---
# The processor handles image resizing/normalization and text tokenization.
processor = AutoProcessor.from_pretrained(config.model_name)

# --- 2. Load Model & Weights ---
# AutoModel detects that 'google/siglip-base-patch16-224' is a SiglipModel 
# and loads the corresponding pre-trained weights.
base_model = SigLipModel(
    model_name=config.model_name,
    temperature=config.temperature
)

# --- 3. Cast and Move to Device ---
# For stability, we explicitly cast the entire model to float32.
# This prevents the NaN/Inf errors often seen with half-precision (fp16) during loss calculation.
model = base_model.to(config.device).float()

print(f"‚úì Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

# Create datasets
# Note: Ensure your df has 'condition' and 'description' columns
train_dataset = CreateDataset(train_df, config.images_dir, processor, mode='train')
val_dataset = CreateDataset(val_df, config.images_dir, processor, mode='val')
test_dataset = CreateDataset(test_df, config.images_dir, processor, mode='test')

print(f"‚úì Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

# Show sample texts
print("\nSample fine-grained texts:")
for i in range(3):
    sample_text = train_df.iloc[i]['description']
    condition = train_df.iloc[i]['condition']
    print(f"  [{condition}] {sample_text[:80]}...")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [10]:
"""
Production-quality SigLipTrainer

Usage:
    trainer = SigLipTrainer(..., use_temp=False, log_interval=100)
    trainer.freeze_logit_scale()   # recommended
    trainer.train()

Notes:
- Expects model.forward(...) to return a dict or ModelOutput containing:
    - 'image_embeds' (B, D)
    - 'text_embeds'  (B, D)
    - optionally 'logits_per_image' (B, B) (not required; computed from embeddings here)
- Optional dataset must put 'condition_indices' into inputs when you want ICD to be used.
- Default behavior: pure SigLIP (no temperature). To experiment with temperature, set use_temp=True.
"""

import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Union, Optional
from transformers import Trainer


# Default ramp sizes (tune for your dataset)
ICD_RAMP_STEPS = 4000
WASS_RAMP_STEPS = 4000
HARDNEG_RAMP_STEPS = 1000

class SigLipTrainer(Trainer):
    def __init__(
        self,
        *args,
        use_temp: bool = False,
        log_interval: int = 100,
        icd_matrix_path: Optional[str] = config.ICD_DISSIMILARITY_MATRIX,
        condition_idx_map: Optional[str] = config.CONDITION_IDX_MAP,
        alpha_max_influence: float = 0.05,
        total_icd_steps: int = ICD_RAMP_STEPS,
        total_wass_steps: int = WASS_RAMP_STEPS,
        total_hardneg_steps: int = HARDNEG_RAMP_STEPS,
        **kwargs,
    ):
        """
        Args:
            use_temp: if True, use model.logit_scale (CLIP-like). Default False (pure SigLIP).
            log_interval: how often (in steps) to print debug information.
            icd_matrix_path: optional path to ICD dissimilarity matrix (torch.load-compatible).
            condition_idx_map: optional json path mapping condition->index (if needed).
            alpha_max_influence: max fraction of total loss contributed by ICD (0..1).
            total_*_steps: durations of ramp phases for alpha, beta, gamma.
        """
        super().__init__(*args, **kwargs)

        self.use_temp = use_temp
        self.log_interval = int(log_interval)

        # Load ICD matrix / mapping if provided
        self.icd_matrix = None
        self.condition_to_idx = None
        
        # Get the target device (e.g., 'cuda:0')
        target_device = self.args.device
        
        try:
            if icd_matrix_path is not None:
                # 1. Load the matrix (which is currently on CPU)
                self.icd_matrix = torch.load(icd_matrix_path)
                
                # 2. Move the ENTIRE matrix to the GPU/Target Device
                self.icd_matrix = self.icd_matrix.to(target_device)
                
            if condition_idx_map is not None:
                with open(condition_idx_map, "r") as f:
                    raw_map = json.load(f)
                # Strip all whitespace from the keys before assigning them to self.condition_to_idx.
                self.condition_to_idx = {k.strip(): v for k, v in raw_map.items()}
                    
            if not self.condition_to_idx:
                print(f"FATAL ERROR: ICD Map loaded but is EMPTY: {icd_map_path}")
            else:
                print(f"‚úÖ ICD Map and Matrix loaded successfully.")
            
        except FileNotFoundError:
            print(f"FATAL ERROR: ICD Index Map NOT FOUND at {icd_map_path.resolve()}.")
            print("ACTION REQUIRED: Ensure 'icd_geometry' folder exists and files were generated.")
            
        except json.JSONDecodeError:
            print(f"FATAL ERROR: ICD Index Map is CORRUPTED: {icd_map_path}.")        
                

        # Modulation coefficients (will be scheduled in training_step)
        self.alpha = 0.0
        self.beta = 0.0
        self.gamma = 0.0

        self.total_icd_steps = 0
        self.total_wass_steps = 0
        self.total_hardneg_steps = 0
        
        # Store for scheduling
        # self.total_icd_steps = int(total_icd_steps)
        # self.total_wass_steps = int(total_wass_steps)
        # self.total_hardneg_steps = int(total_hardneg_steps)
        
        self.alpha_max_influence = float(alpha_max_influence)

        # Safety: make sure alpha_max_influence in [0,1]
        if not (0.0 <= self.alpha_max_influence <= 1.0):
            raise ValueError("alpha_max_influence must be between 0 and 1")

        # For debugging/logging
        self._last_logged_step = -1

    # -----------------------
    # Utility kernels / small helpers
    # -----------------------
    @staticmethod
    def _gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
        """RBF kernel matrix between x and y. x: (n, d), y: (m, d) -> (n, m)."""
        xx = (x ** 2).sum(dim=1, keepdim=True)  # (n,1)
        yy = (y ** 2).sum(dim=1, keepdim=True)  # (m,1)
        dist = xx + yy.t() - 2.0 * (x @ y.t())
        return torch.exp(-dist / (2.0 * sigma ** 2))

    # -----------------------
    # SigLIP loss (pairwise sigmoid BCE)
    # -----------------------
    def siglip_loss(
        self,
        image_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        scale: float = 1.0,
    ):
        """
        Compute SigLIP pairwise sigmoid loss and return (loss_scalar, logits_per_image, logits_per_text).

        scale: scalar multiplier applied to cosine similarities. For pure SigLIP use scale=1.0.
        If using temperature, pass model.logit_scale.exp().
        """
        device = image_embeds.device

        # Normalize to get cosine similarities
        img = F.normalize(image_embeds, dim=-1)
        txt = F.normalize(text_embeds, dim=-1)

        logits_per_image = scale * (img @ txt.t())  # (B, B)
        logits_per_text = logits_per_image.t()

        B = logits_per_image.size(0)
        if B <= 1:
            # Return a dummy loss with gradient to avoid HF errors
            dummy = torch.tensor(0.0, device=device, requires_grad=True)
            return dummy, logits_per_image, logits_per_text

        labels = torch.eye(B, device=device)
        # Use BCEWithLogitsLoss (stable) averaged over all entries
        loss_fn = nn.BCEWithLogitsLoss()
        loss_i = loss_fn(logits_per_image, labels)
        loss_t = loss_fn(logits_per_text, labels)
        loss = 0.5 * (loss_i + loss_t)
        return loss, logits_per_image, logits_per_text

    # -----------------------
    # ICD structural loss (MSE between pairwise distances)
    # -----------------------
    def icd_loss(self, image_embeds: torch.Tensor, text_embeds: torch.Tensor, indices: torch.LongTensor):
        """
        Compute ICD MSE between predicted pairwise distances and ICD D-matrix for this batch.

        - predicted_distance = 1 - cosine_similarity (range [0,2])
        - both predicted_distance and D_batch are normalized to [0,1] before computing MSE
        - diagonal entries are ignored
        """
        device = image_embeds.device
        B = image_embeds.size(0)
        if B <= 1:
            return torch.tensor(0.0, device=device)

        # predicted pairwise similarity -> distance
        img = F.normalize(image_embeds, dim=-1)
        txt = F.normalize(text_embeds, dim=-1)
        pred_sim = img @ txt.t()  # (B,B)
        pred_dist = 1.0 - pred_sim  # in [0,2]

        # prepare D_batch from ICD matrix
        if self.icd_matrix is None:
            raise RuntimeError("ICD matrix was not provided but icd_loss was called")

        indices = indices.long().to(device)
        if indices.numel() != B:
            # mismatched mapping; skip with a warning
            if (self.state.global_step % max(1, self.log_interval)) == 0:
                print(f"WARNING: icd_loss indices length ({indices.numel()}) != batch_size ({B}); skipping ICD term")
            return torch.tensor(0.0, device=device)

        # Index and cast to float
        D_batch = self.icd_matrix[indices][:, indices].to(device).float()

        # Normalize both to [0,1]
        pd_min, pd_max = float(pred_dist.min().detach()), float(pred_dist.max().detach())
        pred_norm = (pred_dist - pd_min) / (pd_max - pd_min + 1e-8)

        D_min, D_max = float(D_batch.min().detach()), float(D_batch.max().detach())
        D_norm = (D_batch - D_min) / (D_max - D_min + 1e-8)

        # mask out diagonal
        mask = ~torch.eye(B, dtype=torch.bool, device=device)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=device)

        return F.mse_loss(pred_norm[mask], D_norm[mask])

    # -----------------------
    # MMD (Wasserstein proxy) loss
    # -----------------------
    def mmd_loss(self, image_embeds: torch.Tensor, text_embeds: torch.Tensor, sigma: float = 1.0):
        device = image_embeds.device
        B = image_embeds.size(0)
        if B <= 1:
            return torch.tensor(0.0, device=device)
        Kxx = self._gaussian_kernel(image_embeds, image_embeds, sigma)
        Kyy = self._gaussian_kernel(text_embeds, text_embeds, sigma)
        Kxy = self._gaussian_kernel(image_embeds, text_embeds, sigma)
        return Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean()

    # -----------------------
    # Hard negative margin-based loss
    # -----------------------
    def hard_negative_loss(self, image_embeds: torch.Tensor, text_embeds: torch.Tensor, margin: float = 0.2):
        device = image_embeds.device
        B = image_embeds.size(0)
        if B <= 1:
            return torch.tensor(0.0, device=device)
        sims = image_embeds @ text_embeds.t()
        pos_sim = sims.diag()  # (B,)
        sims_no_pos = sims.clone()
        sims_no_pos.fill_diagonal_(-1e9)
        hard_neg_sim, _ = sims_no_pos.max(dim=1)  # (B,)
        loss = F.relu(margin + hard_neg_sim - pos_sim).mean()
        return loss

    # -----------------------
    # compute_loss (core)
    # -----------------------
    def compute_loss(self, model, inputs: Dict[str, Any], return_outputs: bool = False, num_items_in_batch: int = None):
        """
        Compute full layered loss:
            L_total = (1 - alpha) * L_siglip + alpha * L_icd + beta * L_mmd + gamma * L_hardneg

        This function is robust to ModelOutput/dict outputs from model.forward.
        """
        device = self.args.device or (torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        # pop condition indices if present (dataset must place them there)
        condition_indices = inputs.pop("condition_indices", None)
        
        # if condition_indices is not None:
            # print(f"DEBUG: Indices Shape: {condition_indices.shape} | Indices Dtype: {condition_indices.dtype}")
            # print(f"DEBUG: Indices Min/Max: {condition_indices.min()}, {condition_indices.max()}")

        outputs = model(**inputs)

        # Extract embeddings robustly
        if isinstance(outputs, dict):
            image_embeds = outputs.get("image_embeds", None)
            text_embeds = outputs.get("text_embeds", None)
            logits_provided = outputs.get("logits_per_image", None)
        else:
            image_embeds = getattr(outputs, "image_embeds", None)
            text_embeds = getattr(outputs, "text_embeds", None)
            logits_provided = getattr(outputs, "logits_per_image", None)

        if image_embeds is None or text_embeds is None:
            raise RuntimeError("Model forward must return 'image_embeds' and 'text_embeds'")

        B = image_embeds.size(0)
        if B <= 1:
            dummy = torch.tensor(0.0, device=image_embeds.device, requires_grad=True)
            return (dummy, outputs) if return_outputs else dummy

        # optionally apply small dropout on embeddings to regularize
        image_embeds = F.dropout(image_embeds, p=0.1, training=self.model.training)
        text_embeds = F.dropout(text_embeds, p=0.1, training=self.model.training)

        # Determine scale for SigLIP: either 1.0 (pure) or provided logit_scale
        if self.use_temp and hasattr(model, "logit_scale"):
            # Use scalar value (clamped to avoid massive BCE)
            scale = float(model.logit_scale.exp().clamp(max=100.0))
        else:
            scale = 1.0

        # -------------------------
        # SigLIP baseline
        # -------------------------
        loss_siglip, logits_per_image, logits_per_text = self.siglip_loss(image_embeds, text_embeds, scale)

        # -------------------------
        # ICD structural loss (alpha)
        # -------------------------
        loss_icd = torch.tensor(0.0, device=image_embeds.device)
        if (self.alpha > 0.0) and (condition_indices is not None) and (self.icd_matrix is not None):
            # Condition indices may be list or tensor
            if isinstance(condition_indices, torch.Tensor):
                indices = condition_indices
            else:
                indices = torch.tensor(condition_indices, device=image_embeds.device)
            if indices.numel() == B:
                loss_icd = self.icd_loss(image_embeds, text_embeds, indices)
            else:
                # warn occasionally
                if (self.state.global_step % max(1, self.log_interval)) == 0:
                    print(f"WARNING: ICD indices length mismatch ({indices.numel()} != {B}) ‚Äî skipping ICD term")

        # -------------------------
        # MMD/Wasserstein proxy (beta)
        # -------------------------
        loss_mmd = torch.tensor(0.0, device=image_embeds.device)
        if getattr(self, "beta", 0.0) > 0.0:
            loss_mmd = self.mmd_loss(image_embeds, text_embeds, sigma=1.0)

        # -------------------------
        # Hard negative (gamma)
        # -------------------------
        loss_hardneg = torch.tensor(0.0, device=image_embeds.device)
        if getattr(self, "gamma", 0.0) > 0.0:
            loss_hardneg = self.hard_negative_loss(image_embeds, text_embeds, margin=0.2)

        # -------------------------
        # Aggregate layered loss
        # -------------------------
        alpha = float(self.alpha)
        beta = float(getattr(self, "beta", 0.0))
        gamma = float(getattr(self, "gamma", 0.0))

        loss_total = (1.0 - alpha) * loss_siglip
        if alpha > 0.0:
            loss_total = loss_total + alpha * loss_icd
        if beta > 0.0:
            loss_total = loss_total + beta * loss_mmd
        if gamma > 0.0:
            loss_total = loss_total + gamma * loss_hardneg

        # Debug logging for components
        step = int(self.state.global_step)
        if step % max(1, self.log_interval) == 0 and step != self._last_logged_step:
            def _safe_float(t: torch.Tensor) -> Optional[float]:
                try:
                    return float(t.detach().cpu())
                except Exception:
                    return None
            print(
                f"[Step {step}] L_total={_safe_float(loss_total):.6f} | "
                f"L_siglip={_safe_float(loss_siglip):.6f} | L_icd={_safe_float(loss_icd):.6f} | "
                f"L_mmd={_safe_float(loss_mmd):.6f} | L_hardneg={_safe_float(loss_hardneg):.6f} | "
                f"alpha={alpha:.4f}, beta={beta:.4f}, gamma={gamma:.4f}"
            )
            self._last_logged_step = step

        return (loss_total, outputs) if return_outputs else loss_total

    # -----------------------
    # prediction_step (keeps consistent logic)
    # -----------------------
    def prediction_step(self, model, inputs: Dict[str, Any], prediction_loss_only: bool, ignore_keys=None):
        """
        Compute logits for evaluation and (optionally) the BCE loss on validation.
        """
        with torch.no_grad():
            outputs = model(**inputs)
            if isinstance(outputs, dict):
                logits = outputs.get("logits_per_image", None)
            else:
                logits = getattr(outputs, "logits_per_image", None)

            if logits is None:
                # compute from embeddings if model didn't provide logits
                image_embeds = outputs.get("image_embeds", getattr(outputs, "image_embeds", None))
                text_embeds = outputs.get("text_embeds", getattr(outputs, "text_embeds", None))
                if image_embeds is None or text_embeds is None:
                    return (None, None, None)
                scale = float(getattr(model, "logit_scale", torch.tensor(1.0)).exp()) if self.use_temp else 1.0
                _, logits, _ = self.siglip_loss(image_embeds, text_embeds, scale)

            # safe check
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                return (None, None, None)

            loss = None
            if not prediction_loss_only:
                B = logits.size(0)
                if B > 1:
                    labels = torch.eye(B, device=logits.device)
                    loss = F.binary_cross_entropy_with_logits(logits, labels)

            preds = logits.cpu().contiguous()
            return (loss, preds, None)

    # -----------------------
    # training_step override: update schedule BEFORE doing the step
    # -----------------------
    def training_step(self, model: nn.Module, inputs: Dict[str, Any], dataloader_or_placeholder=None):
        """
        Update alpha/beta/gamma schedule first, then call the parent Trainer.training_step which
        will call compute_loss (and thus use the updated coefficients).
        """
        global_step = int(self.state.global_step)

        icd_end = max(1, self.total_icd_steps)
        wass_end = icd_end + max(1, self.total_wass_steps)
        hard_end = wass_end + max(1, self.total_hardneg_steps)

        if global_step < icd_end:
            alpha_raw = float(global_step) / float(icd_end)
            self.alpha = min(alpha_raw, self.alpha_max_influence)
            self.beta = 0.0
            self.gamma = 0.0
        elif global_step < wass_end:
            self.alpha = self.alpha_max_influence
            steps_in = global_step - icd_end
            self.beta = float(steps_in) / float(max(1, self.total_wass_steps))
            self.gamma = 0.0
        elif global_step < hard_end:
            self.alpha = self.alpha_max_influence
            self.beta = 1.0
            steps_in = global_step - wass_end
            self.gamma = float(steps_in) / float(max(1, self.total_hardneg_steps))
        else:
            self.alpha = self.alpha_max_influence
            self.beta = 1.0
            self.gamma = 1.0

        if global_step % max(1, self.log_interval) == 0:
            print(f"[SCHEDULE] step={global_step} | alpha={self.alpha:.4f} beta={self.beta:.4f} gamma={self.gamma:.4f}")
            
        # If the total duration for a ramp was 0, the coefficient MUST be 0, 
        # regardless of what the scheduler calculated above.

        if self.total_icd_steps == 0:
            self.alpha = 0.0
        
        if self.total_wass_steps == 0:
            self.beta = 0.0
            
        if self.total_hardneg_steps == 0:
            self.gamma = 0.0

        # Now forward/backprop is performed by the parent Trainer which will call compute_loss.
        return super().training_step(model, inputs, dataloader_or_placeholder)

    # -----------------------
    # Utilities: freeze logit scale, and a tune_report
    # -----------------------
    def freeze_logit_scale(self, model=None, value: Optional[float] = None):
        """Freeze model.logit_scale if present. Optionally set to a specific log-value (value is logit_scale.raw, e.g. np.log(1/0.07))."""
        if model is None:
            model = self.model
        if hasattr(model, "logit_scale"):
            model.logit_scale.requires_grad = False
            if value is not None:
                with torch.no_grad():
                    model.logit_scale.data = torch.tensor(float(value)).to(next(model.parameters()).device)

    def tune_report(self) -> str:
        """Structured step-by-step tuning plan to run experiments in a safe order."""
        lines = [
            "SigLipTrainer Tuning Plan:",
            "1) Baseline: use_temp=False, alpha=beta=gamma=0. Train backbone LR=1e-5, head LR=5e-5, batch>=64.",
            "2) Freeze model.logit_scale: trainer.freeze_logit_scale(). This prevents BCE explosion.",
            "3) If small dataset: freeze vision encoder for first N epochs (5-10), train only projection/head.",
            "4) Add ICD (alpha) slowly: alpha_max_influence=0.05, long total_icd_steps (e.g., 10000).",
            "5) Add MMD (beta) tiny: beta=0.01. Monitor L_mmd and val retrieval.",
            "6) Add HardNeg (gamma) last (0.01->0.1). Use margin=0.2.",
            "7) Use aggressive augmentations for images and token dropout for text on small-data.",
            "8) Use gradient clipping (1.0), weight decay (1e-2) and linear warmup (500-2000 steps).",
        ]
        return "\n".join(lines)

In [11]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import io
from PIL import Image

def plot_confusion_matrix_fig(predictions, class_names):
    """Generates a Matplotlib Figure of the confusion matrix."""
    
    # 1. Generate True Labels (Assuming diagonal alignment based on prediction size)
    # The true labels are always [0, 1, 2, ..., N-1] repeated across all batches
    N_total_samples = len(predictions)
    true_labels = np.arange(N_total_samples) % predictions.shape[1] 

    # 2. Convert Predictions (Find Top-1 class index for each sample)
    # The predictions passed here are the raw logits (N_total_samples, N_candidates)
    predicted_classes = np.argmax(predictions, axis=1)

    # 3. Handle Multiclass vs Multilabel Confusion Matrix (Crucial distinction)
    # We are performing a contrastive retrieval task, which is a multi-class problem
    # where the 'classes' are the indices 0 to BATCH_SIZE - 1. 
    # For reporting, we must simplify this. Let's rely on simple binary correctness.
    # --- Since we cannot easily reconstruct the *true* class name here, 
    #     we will use the "Correct Index" (0) vs "Incorrect Index" (1..N-1) logic.

    # NOTE: The Trainer's compute_metrics function already handles this logic.
    # For a *report*, the simplest visualization is usually the Top-1 accuracy over the whole set.
    # However, since we want a full confusion matrix, we must use the indices 0..N-1.
    
    # We must generate the true class names (the 66 condition names) for the axis labels.
    # Since we don't have the final class list, we'll use numeric indices.

    cm = confusion_matrix(true_labels, predicted_classes, labels=np.arange(predictions.shape[1]))

    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", cbar=False, ax=ax)
    
    ax.set_title("Top-1 Retrieval Confusion Matrix (Indices 0-N)")
    ax.set_ylabel('True Index (Expected Rank)')
    ax.set_xlabel('Predicted Index (Model Rank)')
    plt.close(fig)
    return fig

In [12]:
from transformers import TrainerCallback
import torch
import numpy as np

# Assuming plot_confusion_matrix_fig is defined above
# Assuming explicit_writer is defined globally

class ConfusionMatrixCallback(TrainerCallback):
    """
    Logs the confusion matrix to TensorBoard after each evaluation.
    Requires the SummaryWriter instance upon initialization.
    """
    def __init__(self, writer):
        # Store the provided SummaryWriter instance
        self.writer = writer
        
        # You may want to initialize other internal variables here
        self.global_step = 0
        
        # NOTE: You must also ensure the logic in on_evaluate 
        # uses self.writer instead of a global tb_writer.
        
        # Call the parent TrainerCallback init (optional, but good practice)
        super().__init__()
    
    def on_evaluate(self, args, state, control, metrics, **kwargs):
        # 1. Check for predictions in Trainer state (necessary for manual retrieval)
        if hasattr(state, 'predictions') and state.predictions is not None:
            # state.predictions holds the result from the prediction_step
            # Format: (loss_values, predictions_array, labels_array)
            predictions_raw = state.predictions[1] 

            if isinstance(predictions_raw, np.ndarray):
                # The predictions are already collected as a NumPy array (logits)
                final_predictions = predictions_raw
            elif isinstance(predictions_raw, tuple):
                # Trainer sometimes packs predictions as a tuple (e.g., if multiple heads)
                # We expect the first element to be the logits array
                final_predictions = predictions_raw[0]
            else:
                return # Cannot process, data format is incorrect

            # 2. Generate Matplotlib Figure
            try:
                # We need the class names, but since we don't have them, we use the raw predictions
                fig = plot_confusion_matrix_fig(final_predictions, None) 
                print(f"| DEBUG: After plotting confusion matrix figure.")
                
                # 3. Log the figure to TensorBoard
                global explicit_writer
                if explicit_writer:
                    print(f"| DEBUG: Confusion Matrix - Before add_figure step {state.global_step}.")
                    explicit_writer.add_figure(
                        f"{metrics.get('epoch', 'final'):.2f}/Confusion_Matrix",
                        fig,
                        global_step=state.global_step
                    )
            except Exception as e:
                print(f"[CM Callback Error] Failed to generate/log confusion matrix: {e}")

In [13]:
import numpy as np
from transformers import TrainerCallback
from transformers.integrations import TensorBoardCallback

class CovarianceLoggingCallback(TrainerCallback):
    """
    Logs the covariance of model predictions/features to TensorBoard.
    """
    def __init__(self, writer):
        # Store the provided SummaryWriter instance
        self.writer = writer
        
        # You may want to initialize other internal variables here
        self.global_step = 0
        
        # NOTE: You must also ensure the logic in on_evaluate 
        # uses self.writer instead of a global tb_writer.
        
        # Call the parent TrainerCallback init (optional, but good practice)
        super().__init__()
        
    def on_evaluate(self, args, state, control, model, **kwargs):
        # Only proceed if TensorBoard logging is active
        if args.report_to and "tensorboard" in args.report_to:
            tb_writer = None
            # Find the active TensorBoard writer instance
            for callback in kwargs.get("callbacks", []):
                if isinstance(callback, TensorBoardCallback):
                    tb_writer = callback.tb_writer
                    break
            
            if tb_writer:
                # --- Compute Covariance Here ---
                # This is highly context-dependent. A simplified example
                # using random data. Replace with your actual calculation 
                # (e.g., from validation dataset logits/features).
                
                # Example: Compute covariance of a dummy feature vector
                # In a real scenario, you'd extract logits or features 
                # during evaluation somehow.
                # For this example, we log a sample matrix:
                
                sample_data = np.random.randn(100, 5) # 100 samples, 5 features
                cov_matrix = np.cov(sample_data, rowvar=False) # Features in columns
                cov_mean = np.mean(cov_matrix)
                cov_std = np.std(cov_matrix)

                print(f"| DEBUG: Logging Covariance for step {state.global_step}. Shape: {cov_matrix.shape}")
                tb_writer.add_histogram(
                    tag="Evaluation/Covariance_Matrix_Values_Distribution",
                    values=cov_matrix.flatten(), # Log all elements as a distribution
                    global_step=state.global_step
                )
                
                # You might also want to log the max/min of the covariance
                tb_writer.add_scalar("Evaluation/Covariance_Max_Value", np.max(cov_matrix), state.global_step)
                tb_writer.add_scalar("Evaluation/Covariance_Mean_Value", cov_mean, state.global_step)
                tb_writer.add_scalar("Evaluation/Covariance_STD_Deviation", cov_std, state.global_step)
                tb_writer.flush()


#### 1. Loss Analysis Diagnosis Table

| Training Loss | Validation Loss | Diagnosis | Action Required |
|---------------|-----------------|-----------|-----------------|
| High | High | **Underfitting** | Increase model complexity, train longer |
| Low | High (Increasing) | **Overfitting** | Add regularization (Dropout/L2), get more data, Early Stopping |
| Low | Low | **Good Fit** | Save model weights |
| High | Low | **Data Issue** | Check regularization, data augmentation, or data split |

#### 2. Loss Value Interpretation Table

| Value Range | Meaning | Status |
|-------------|---------|--------|
| > 0.60 | **High** | Underfitting / Starting: Model is guessing |
| 0.40 - 0.60 | **Medium** | Learning: Model is picking up patterns |
| < 0.30 | **Low** | Good Fit: Model is confident in its predictions |

#### 3. Additional Metrics for Multi-Label Classification

| Metric | Good Range | Interpretation |
|--------|------------|----------------|
| **Per-class Accuracy** | > 0.70 | Percentage of correct predictions per condition |
| **Macro F1-Score** | > 0.60 | Average F1 across all classes (handles imbalance) |
| **Hamming Loss** | < 0.20 | Fraction of wrong labels (lower is better) |
| **Subset Accuracy** | > 0.40 | Exact match of all labels (strict metric) |

#### 4. SigLip Fine-tuning Specific Guidelines

| Training Stage | Expected Train Loss | Expected Val Loss | What to Look For |
|----------------|-------------------|-------------------|------------------|
| **Early (Epoch 1-3)** | 0.80 - 1.20 | 0.70 - 1.10 | Both should decrease rapidly |
| **Mid (Epoch 4-8)** | 0.40 - 0.70 | 0.45 - 0.75 | Val loss should track train loss closely |
| **Late (Epoch 9-15)** | 0.20 - 0.45 | 0.30 - 0.55 | Gap should be small (<0.10) |
| **Converged** | < 0.30 | < 0.40 | Both stable, gap <0.10 means good generalization |

#### 5. Warning Signs During Training

| Observation | Problem | Solution |
|-------------|---------|----------|
| Val loss increases after epoch 5 | **Overfitting** | Enable early stopping, increase dropout |
| Both losses stuck > 0.70 | **Underfitting** | Decrease LoRA rank, increase learning rate |
| Val loss oscillates wildly | **Unstable training** | Reduce learning rate, increase batch size |
| Train loss = 0.0, Val loss > 1.0 | **Severe overfitting** | Add strong regularization, reduce model capacity |
| Gap between train/val > 0.20 | **Poor generalization** | More data augmentation, check data leakage |

#### 6. Recommended Actions Based on Your Dataset

Given your characteristics:
- **79.4% multi-label samples** ‚Üí Focus on Hamming Loss and Macro F1
- **758:1 class imbalance** ‚Üí Monitor per-class metrics, not just overall accuracy
- **9 rare classes (<10 samples)** ‚Üí These will have high loss; group to coarse categories

| Target Metrics (Epoch 15) | Realistic Goal | Stretch Goal |
|---------------------------|----------------|--------------|
| **Validation Loss** | < 0.40 | < 0.30 |
| **Macro F1-Score** | > 0.55 | > 0.65 |
| **Top-3 Accuracy** | > 0.80 | > 0.90 |
| **Hamming Loss** | < 0.25 | < 0.15 |

In [14]:
import torch
from torch.utils.tensorboard import SummaryWriter 
from transformers import default_data_collator

def custom_layered_loss_data_collator(features):
    """
    Custom collator to handle the custom 'condition_indices' tensor alongside 
    standard Hugging Face processor outputs.
    """
    
    # 1. Extract the custom indices (Tensor)
    condition_indices = [f['condition_indices'] for f in features]
    
    # 2. Remove the custom key so the default collator handles the rest
    for f in features:
        del f['condition_indices']
        
    # 3. Use the default collator for the standard keys 
    #    (This correctly handles padding/stacking for pixel_values, input_ids, etc.)
    batch = default_data_collator(features)
    
    # 4. Add the custom indices back as a single stacked tensor
    # Note: torch.stack is required here because they are single tensors from __getitem__
    batch['condition_indices'] = torch.stack(condition_indices)
    
    return batch

def debug_icd_map_integrity(trainer_instance: Any):
    """Prints the loaded ICD index map keys and their raw string representation."""
    
    # Access the map stored in the trainer instance
    icd_map: Dict[str, int] = trainer_instance.condition_to_idx 
    
    if not icd_map:
        print("\n--- ICD MAP INTEGRITY CHECK ---")
        print("‚ùå ERROR: self.condition_to_idx is empty. Check file loading in __init__.")
        return

    print("\n--- ICD MAP INTEGRITY CHECK ---")
    print(f"‚úÖ Map Size: {len(icd_map)} conditions.")
    print("---------------------------------")
    print("Sample Keys (First 25 for detailed verification):")
    
    # Iterate through the first 25 keys
    for i, (condition, index) in enumerate(list(icd_map.items())[:25]):
        # The repr() function shows invisible characters like spaces (' ') or newlines ('\n')
        print(f"  [{i:2d} | Index {index}] Key: '{condition}' (Raw: {repr(condition)})")
    
    print("\n--- Failing Input Check ---")
    print(f"Input Check: 'Eczema' in map? {('Eczema' in icd_map)}")

def compute_metrics(eval_pred):
    # predictions (logits) are tuple: (logits_per_image, logits_per_text)
    logits = eval_pred.predictions 
    # Robustly handle tuple unpacking
    if isinstance(logits, tuple):
        logits = logits[0]
    
    # For evaluation, we treat the diagonal as the ground truth class
    # (retrieval task: given image i, find text i)
    batch_size = logits.shape[0]
    labels = np.arange(batch_size)
    
    preds = np.argmax(logits, axis=1)
    
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Define your custom hyperparameters for tracking (outside TrainingArguments if needed)
CUSTOM_HP_DROPOUT = 0.1
CUSTOM_HP_LAYERS = 3

# Training arguments
training_args = TrainingArguments(
    output_dir=str(config.output_dir),
    num_train_epochs=config.num_epochs,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    learning_rate=config.learning_rate,
    weight_decay=config.weight_decay,
    warmup_steps=config.warmup_steps,
    lr_scheduler_type="cosine", # Controls lerning rate behavior.
    fp16=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    report_to=["tensorboard"], # Tell the Trainer to send all logs to TensorBoard
    logging_dir=config.log_dir,
    logging_strategy="steps",
    run_name=f"run_lr3e4_dp{CUSTOM_HP_DROPOUT}", # Great for fine-tuning experiment tracking
    logging_steps=50,
    save_total_limit=2, #Was 3
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    prediction_loss_only=False, # CRITICAL for compute_metrics to run!
)

# Apply LoRA
lora_config = LoraConfig(
        r=config.lora_alpha,
        lora_alpha=config.lora_alpha,
        target_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
        lora_dropout=0.1,
        bias="none",
    )

# Apply LoRA if specified
if config.finetune_mode == "LoRA":
    print("\nApplying LoRA fine-tuning...")
    model = get_peft_model(model, lora_config)  

explicit_writer = SummaryWriter(log_dir=config.log_dir)
# 1. Create the callback instance
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)
# Instantiate the custom callback
cov_callback = CovarianceLoggingCallback(writer=explicit_writer)
conf_matrix_callback = ConfusionMatrixCallback(writer=explicit_writer)

explicit_writer = SummaryWriter(log_dir=config.log_dir)

trainer = SigLipTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,  # Your standard metrics function
    callbacks=[early_stopping_callback, cov_callback, conf_matrix_callback], # Your custom covariance logger
    data_collator=custom_layered_loss_data_collator, # Custom collator to handle condition indices
)

print("\n" + "="*80)
print("ICD MAP INTEGRITY CHECK")
print("="*80)
debug_icd_map_integrity(trainer)



print("\n" + "="*80)
print("STARTING FINE-GRAINED TRAINING")
print("="*80)
print("Training with fine-grained condition labels...")
print("Focus on discriminating between 66 specific conditions")
print()

train_result = trainer.train()
print(f"\n‚úì Training complete! Final loss: {train_result.training_loss:.4f}")

# Evaluate
test_results = trainer.evaluate(test_dataset)
print(f"Test Loss: {test_results['eval_loss']:.4f}") # (This runs successfully)

# Required for Tensor Board HParams logging
if explicit_writer:
    # Log final metrics with the 'hparam/' prefix for the HParams tab.
    explicit_writer.add_hparams(
        hparam_dict={
            'custom_dropout_rate': CUSTOM_HP_DROPOUT,
            'another_custom_hp': CUSTOM_HP_LAYERS,
            'learning_rate': config.learning_rate,
            'batch_size': config.batch_size,
        },
        metric_dict={
            # Now these will be populated with real, non-zero values
            'hparam/loss': test_results.get('eval_loss', 0.0),
            'hparam/accuracy': test_results.get('eval_accuracy', 0.0),
        }
    )
    explicit_writer.flush()

print(f"‚úì Test loss: {test_results['eval_loss']:.4f}")

# Save
final_path = config.output_dir / 'final_model'
trainer.save_model(str(final_path))
processor.save_pretrained(str(final_path))
print(f"‚úì Model saved to: {final_path}")

Epoch,Training Loss,Validation Loss
1,0.5127,0.341094
2,0.3822,0.27866
3,0.3714,0.270648
4,0.3705,0.267424
5,0.3701,0.265739
6,0.3699,0.264728
7,0.3698,0.264016
8,0.3696,0.263557
9,0.3696,0.2633
10,0.3696,0.263117


In [15]:
def test_retrieval(model, processor, test_df, images_dir, n_samples=20, debug_limit=5):
    """
    Test fine-grained retrieval capabilities with detailed debugging.
    
    Args:
        debug_limit (int): Number of samples to print detailed logs for.
    """
    model.eval()

    # Ensure we don't sample more than available
    n_samples = min(n_samples, len(test_df))
    sample_indices = np.random.choice(len(test_df), n_samples, replace=False)

    accuracies = []
    condition_accuracies = {}
    debug_counter = 0

    print("\n" + "="*80)
    print(f"New Testing retrieval on {n_samples} samples...")
    print(f"Detailed debug info will be printed for the first {debug_limit} samples.\n")
    print("\n" + "="*80)
    
    for idx in tqdm(sample_indices, desc="Testing retrieval"):
        row = test_df.iloc[idx]

        # 1. Load Image
        image_path = images_dir / row['image_path']
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            continue

        # 2. Prepare Candidates
        desc = str(row['description'])
        if desc.lower() == 'nan' or not desc:
            correct_text = f"A dermatological image of {row['condition']}"
        else:
            correct_text = f"A dermatological image of {row['condition']}: {desc}"
            
        correct_condition = row['condition']

        # Get distractors
        other_conditions = test_df[test_df['condition'] != correct_condition]['condition'].unique()
        
        n_distractors = min(9, len(other_conditions))
        if n_distractors < 1: continue

        distractor_conditions = np.random.choice(other_conditions, n_distractors, replace=False)

        distractor_texts = []
        for dist_cond in distractor_conditions:
            dist_rows = test_df[test_df['condition'] == dist_cond]
            if len(dist_rows) > 0:
                dist_row = dist_rows.sample(1).iloc[0]
                # Clean description logic for distractors too
                d_desc = str(dist_row['description'])
                if d_desc.lower() == 'nan' or not d_desc:
                     d_text = f"A dermatological image of {dist_row['condition']}"
                else:
                     d_text = f"A dermatological image of {dist_row['condition']}: {d_desc}"
                distractor_texts.append(d_text)

        # Index 0 is ALWAYS the correct answer in this list
        all_texts = [correct_text] + distractor_texts

        # 3. Process Inputs
        inputs = processor(
            text=all_texts,
            images=[image] * len(all_texts),
            padding=True, 
            truncation=True, 
            max_length=64, 
            return_tensors="pt"
        ).to(config.device)

        # 4. Get Predictions
        with torch.no_grad():
            outputs = model(**inputs)
            
            if isinstance(outputs, dict):
                logits_per_image = outputs['logits_per_image']
            else:
                logits_per_image = outputs.logits_per_image
            
            # Get logits for the single image against all texts
            logits = logits_per_image[0]
            
            # Apply Sigmoid to get readable probabilities (0.0 to 1.0)
            probs = torch.sigmoid(logits)

        # 5. Check accuracy
        predicted_idx = probs.argmax().item()
        correct = (predicted_idx == 0)
        accuracies.append(correct)

        # Track per-condition accuracy
        if correct_condition not in condition_accuracies:
            condition_accuracies[correct_condition] = []
        condition_accuracies[correct_condition].append(correct)

        # --- DEBUGGING OUTPUT ---
        if debug_counter < debug_limit:
            print(f"\n{'='*80}")
            print(f"DEBUG SAMPLE {debug_counter + 1} | Condition: {correct_condition}")
            print(f"Image Path: {image_path}")
            print(f"Result: {'‚úÖ CORRECT' if correct else '‚ùå INCORRECT'}")
            print(f"{'-'*40}")
            
            # Get Top 5 predictions for this image
            top_k = min(5, len(all_texts))
            top_probs, top_indices = torch.topk(probs, top_k)
            
            print(f"Model's Top {top_k} Guesses:")
            for score, idx in zip(top_probs, top_indices):
                idx = idx.item()
                score = score.item()
                
                is_ground_truth = (idx == 0)
                marker = "üëà TRUE LABEL" if is_ground_truth else ""
                
                # Truncate text for cleaner printing
                display_text = all_texts[idx][:100] + "..." if len(all_texts[idx]) > 100 else all_texts[idx]
                
                print(f"  [{score:.4f}] {display_text} {marker}")

            # If correct answer wasn't in top 5, print it explicitly
            if 0 not in top_indices.tolist():
                true_score = probs[0].item()
                print(f"  ...\n  [{true_score:.4f}] {correct_text} üëà TRUE LABEL (Ranked > {top_k})")
                
            print(f"{'='*80}\n")
            debug_counter += 1

    if not accuracies:
        return 0.0

    accuracy = np.mean(accuracies)
    print(f"\n‚úì Fine-grained retrieval accuracy: {accuracy:.2%}")
    
    return accuracy

In [16]:
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from PIL import Image

# NOTE: This function assumes config, processor, model, and test_df are in scope.

def test_retrieval(model, processor, test_df, images_dir, n_samples=20, k_accuracy=5, debug_limit=5):
    """
    Test fine-grained retrieval capabilities with detailed logging and True Rank calculation.
    
    Args:
        k_accuracy (int): The 'K' value for Top-K accuracy (the final metric).
        debug_limit (int): Number of samples to print detailed logs for (shows Top 5).
    """
    model.eval()

    n_samples = min(n_samples, len(test_df))
    sample_indices = np.random.choice(len(test_df), n_samples, replace=False)

    accuracies = []
    condition_accuracies = {}
    debug_counter = 0

    print(f"Testing retrieval on {n_samples} samples for R@{k_accuracy} accuracy...")

    for idx in tqdm(sample_indices, desc="Testing retrieval"):
        row = test_df.iloc[idx]
        image_path = Path(images_dir) / row['image_path']
        try:
            image = Image.open(image_path).convert('RGB')
        except:
            continue

        # Prepare Candidates (Assuming previous logic)
        desc = str(row['description'])
        if desc.lower() == 'nan' or not desc:
            correct_text = f"A dermatological image of {row['condition']}"
        else:
            correct_text = f"A dermatological image of {row['condition']}: {desc}"
            
        correct_condition = row['condition']

        # Get distractors (logic omitted for brevity)
        other_conditions = test_df[test_df['condition'] != correct_condition]['condition'].unique()
        n_distractors = min(9, len(other_conditions))
        if n_distractors < 1: continue

        distractor_conditions = np.random.choice(other_conditions, n_distractors, replace=False)
        distractor_texts = []
        for dist_cond in distractor_conditions:
            dist_rows = test_df[test_df['condition'] == dist_cond]
            if len(dist_rows) > 0:
                dist_row = dist_rows.sample(1).iloc[0]
                d_desc = str(dist_row['description'])
                d_text = f"A dermatological image of {dist_row['condition']}: {d_desc}" if not d_desc.lower() == 'nan' else f"A dermatological image of {dist_row['condition']}"
                distractor_texts.append(d_text)

        all_texts = [correct_text] + distractor_texts

        # 3. Process Inputs
        inputs = processor(
            text=all_texts,
            images=[image] * len(all_texts),
            padding=True, truncation=True, max_length=64, return_tensors="pt"
        ).to(config.device)

        # 4. Get Predictions
        with torch.no_grad():
            outputs = model(**inputs)
            if isinstance(outputs, dict):
                logits_per_image = outputs['logits_per_image']
            else:
                logits_per_image = outputs.logits_per_image
            
            similarities = logits_per_image[0]
            probs = torch.sigmoid(similarities)

        # 5. Check accuracy (R@K) and Calculate Rank
        K = min(k_accuracy, len(all_texts)) 
        
        # Calculate Rank
        true_score = probs[0].item()
        # Rank is 1 + the number of scores that are strictly greater than the true score
        rank = (probs > true_score).sum().item() + 1 
        
        top_k_indices = torch.topk(probs, k=K).indices
        correct = (0 in top_k_indices.tolist()) 
        
        accuracies.append(correct)

        # 6. --- DEBUGGING OUTPUT ---
        if debug_counter < debug_limit:
            top_k_debug = min(5, len(all_texts))
            top_probs, top_indices = torch.topk(probs, top_k_debug)
            
            print(f"\n{'='*80}")
            print(f"DEBUG SAMPLE {debug_counter + 1} | Condition: {correct_condition}")
            print(f"Image Path: {image_path}")
            print(f"Result (R@{K}): {'‚úÖ CORRECT' if correct else '‚ùå INCORRECT'}") 
            print(f"TRUE LABEL RANK: {rank}th") # <--- NEW RANK OUTPUT
            print(f"{'-'*40}")
            
            print(f"Model's Top {top_k_debug} Guesses:")
            for score, idx in zip(top_probs, top_indices):
                idx = idx.item()
                score = score.item()
                
                is_ground_truth = (idx == 0)
                marker = "üëà TRUE LABEL" if is_ground_truth else ""
                
                display_text = all_texts[idx][:100] + "..." if len(all_texts[idx]) > 100 else all_texts[idx]
                
                print(f"  [{score:.4f}] {display_text} {marker}")

            if 0 not in top_indices.tolist():
                print(f"  ...\n  [{true_score:.4f}] {correct_text} üëà TRUE LABEL (Actual Rank: {rank}th)") # <--- NEW RANK OUTPUT
                
            print(f"{'='*80}\n")
            debug_counter += 1

    if not accuracies:
        return 0.0

    accuracy = np.mean(accuracies)
    print(f"\n‚úì Fine-grained retrieval accuracy (R@{k_accuracy}): {accuracy:.2%}")
    
    return accuracy

In [17]:
# Test retrieval
print("\n" + "="*80)
print("Testing Fine-Grained Retrieval")
print("="*80)

retrieval_acc = test_retrieval(model, processor, test_df, config.images_dir)

# Save summary
summary = {
    'approach': 'fine_grained',
    'num_conditions': int(df['condition'].nunique()),
    'train_samples': len(train_df),
    'final_loss': float(train_result.training_loss),
    'test_loss': float(test_results['eval_loss']),
    'retrieval_accuracy': float(retrieval_acc),
}

with open(config.output_dir / 'training_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\n" + "="*80)
print("FINE-GRAINED TRAINING COMPLETE")
print("="*80)
print(f"\nFinal metrics:")
print(f"  Test loss: {test_results['eval_loss']:.4f}")
print(f"  Retrieval accuracy: {retrieval_acc:.2%}")

Testing retrieval: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:02<00:00,  9.09it/s]
