### **Sigmoid vs Contrastive Loss**
Recommendation: Use SIGMOID loss for your use case. Here's why:

AspectSigmoid LossContrastive Loss
Best for Multi-label (multiple conditions per image)Single-label classification
Your data‚úÖ 20%+ multi-label samples‚ùå Treats each sample as single class
Output Independent probabilities per classProbability distribution (sums to 1)
TrainingEach label trained independentlyContrasts positive vs negative pairs
Class imbalance‚úÖ Handles better with per-class thresholds‚ùå Can struggle with imbalanced data

In [None]:
#!/usr/bin/env python
"""
Train SigLIP with FINE-GRAINED contrastive learning.
Uses the 66 specific condition labels with their descriptions for more precise matching.

This approach should provide:
- Better discrimination between specific conditions
- More precise retrieval accuracy
- Direct condition-level matching without hierarchical complexity
"""
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from pathlib import Path
from PIL import Image
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from tqdm import tqdm
import json
from peft import LoraConfig, get_peft_model
from datetime import datetime
from collections import defaultdict, Counter
import random

# Get the current timestamp
now = datetime.now()

timestamp_str = now.strftime("%Y%m%d%H%M")

from transformers import (
    AutoModel,
    AutoProcessor,
    EarlyStoppingCallback,
    TrainingArguments,
    Trainer,
)
from torch.utils.data import Dataset

@dataclass
class Config:
    # Model
    model_name: str = 'google/siglip-base-patch16-224'
    
    # Training
    batch_size: int = 32
    num_epochs: int = 15
    learning_rate: float = 1e-6
    weight_decay: float = 0.05  # Increased this (was 1e-4). High decay prevents overfitting.
    warmup_steps: int = 100
    
    #Lora parameters
    lora_rank: int = 32
    lora_alpha: int = 32

    # Device
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    fp16: bool = torch.cuda.is_available()
    temperature: float = 0.7
    
    # Set this to "contrastive" or "sigmoid"
    loss_type: str = "sigmoid" # or "contrastive"
    
    # finetune_mode. # Options: "FULL", "LoRA"
    finetune_mode: str = "LoRA"  
    
    # Data paths
    data_dir: Path = Path('/home/agentic-health/data')
    visual_descriptions_path: Path ='./data/condition_visual_descriptions.json'
    metadata_file: Path = './data/coarse_labeled_metadata_with_labels.csv'
    images_dir: Path = data_dir / 'scin/images'
    
    @property
    def output_dir(self) -> Path:
        return Path(f'./siglip_{self.loss_type.lower()}_{self.loss_type}_{timestamp_str}')

config = Config()
config.output_dir.mkdir(parents=True, exist_ok=True)

class DualLogger:
    def __init__(self, filepath):
        self.terminal = sys.stdout
        self.log = open(filepath, "w") # Overwrite file each run

    def write(self, message):
        self.terminal.write(message) # Write to screen
        self.log.write(message)      # Write to file
        self.log.flush()             # Ensure it saves immediately

    def flush(self):
        # Needed for compatibility with some libraries
        self.terminal.flush()
        self.log.flush()
        
log_file_path = config.output_dir / "training_log.txt"
sys.stdout = DualLogger(log_file_path)

print("="*80)
print(f"SIGLIP FINE-GRAINED {config.loss_type} LEARNING USLING {config.finetune_mode}")
print("="*80)
print(f"\nConfiguration:")
print(f"  Model: {config.model_name}")
print(f"  Output: {config.output_dir}")
print(f"  Device: {config.device}")
print(f"  Batch size: {config.batch_size}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Epochs: {config.num_epochs}")
print(f"  Lora_rank: {config.lora_rank}")   
print(f"  Output directory: {config.output_dir}")

# Load data
print("\n" + "="*80)
print("Loading Data")
print("="*80)

df = pd.read_csv(config.metadata_file)
print(f"Total samples: {len(df)}")
print(f"Fine-grained conditions: {df['condition'].nunique()}")
print(f"Coarse categories: {df['coarse_category'].nunique()}")


In [228]:
def get_labels_from_item(item):
    """
        Extract label from a dataset item (dictionary row).
        We prefer the 'description' column if available, otherwise template the condition
    """   
    if pd.notna(item['description']):
        text = item['description']
    else:
        # Fallback template if description is missing
        text = f"A dermatological photo showing {item['condition']}"
    
    return text

def stratified_split_data(all_data, val_ratio=0.15, test_ratio=0.15):
    """
    Splits data while ensuring rare classes are distributed across all sets.
    Strategy: Assigns rare classes first to ensure Val/Test coverage.
    """
    train_data, val_data, test_data = [], [], []
    
    # 1. Map each class to its samples
    class_to_samples = defaultdict(list)
    sample_to_id = {} # Track assigned samples
    
    for idx, item in enumerate(all_data):
        labels = get_labels_from_item(item)
        sample_to_id[idx] = False # False = not assigned
        for label in labels:
            class_to_samples[label].append(idx)
            
    # 2. Sort classes by rarity (rarest first)
    sorted_classes = sorted(class_to_samples.keys(), key=lambda k: len(class_to_samples[k]))
    
    for cls in sorted_classes:
        indices = class_to_samples[cls]
        # Filter for unassigned indices
        available_indices = [i for i in indices if not sample_to_id[i]]
        random.shuffle(available_indices)
        
        n = len(available_indices)
        # Ensure at least 1 sample in Val/Test if possible
        n_val = max(1, int(n * val_ratio)) if n > 1 else 0
        n_test = max(1, int(n * test_ratio)) if n > 2 else 0
        
        # Allocate indices
        val_idxs = available_indices[:n_val]
        test_idxs = available_indices[n_val:n_val+n_test]
        train_idxs = available_indices[n_val+n_test:]
        
        for i in val_idxs:
            if not sample_to_id[i]:
                val_data.append(all_data[i])
                sample_to_id[i] = True
        for i in test_idxs:
            if not sample_to_id[i]:
                test_data.append(all_data[i])
                sample_to_id[i] = True
        for i in train_idxs:
            if not sample_to_id[i]:
                train_data.append(all_data[i])
                sample_to_id[i] = True

    # 3. Assign any leftovers
    leftovers = [all_data[i] for i in range(len(all_data)) if not sample_to_id[i]]
    for item in leftovers:
        r = random.random()
        if r < val_ratio: val_data.append(item)
        elif r < val_ratio + test_ratio: test_data.append(item)
        else: train_data.append(item)

    return train_data, val_data, test_data

def balance_training_data(train_data, target_min_samples=30):
    """
    Oversamples rare classes in the training set to reach a minimum count.
    """
    print(f"\nBalancing training data (target min: {target_min_samples})...")
    class_counts = Counter()
    class_samples = defaultdict(list)
    
    for item in train_data:
        labels = get_labels_from_item(item)
        for label in labels:
            class_counts[label] += 1
            class_samples[label].append(item)
            
    balanced_data = list(train_data)
    added_count = 0
    
    for cls, count in class_counts.items():
        if count < target_min_samples:
            needed = target_min_samples - count
            available = class_samples[cls]
            if not available: continue
            
            # Resample with replacement
            extras = random.choices(available, k=needed)
            balanced_data.extend(extras)
            added_count += needed

    random.shuffle(balanced_data)
    print(f"  Added {added_count} samples via oversampling.")
    return balanced_data

In [229]:
import numpy as np
import random
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
import uuid
import pandas as pd
from pathlib import Path
import os

def generate_synthetic_data_pil(image, num_samples=5, use_edge_enhance=False):
    """
    Creates synthetic medical images using Pillow (PIL).
    
    Args:
        image (PIL.Image): Input image.
        num_samples (int): Number of synthetic versions to create.
        use_edge_enhance (bool): If True, applies edge enhancement (like Sobel).
        
    Returns:
        list: A list of augmented PIL images.
    """
    synthetic_images = []
    w, h = image.size

    for _ in range(num_samples):
        aug_img = image.copy()
        
        # --- A. GEOMETRIC TRANSFORMS ---
        
        # 1. Random Flip
        if random.random() > 0.5:
            # Mirror (Horizontal)
            if random.random() > 0.5:
                aug_img = ImageOps.mirror(aug_img)
            # Flip (Vertical) - Valid for skin lesions
            else:
                aug_img = ImageOps.flip(aug_img)

        # 2. Random Rotation (-30 to 30 degrees)
        if random.random() > 0.5:
            angle = random.uniform(-30, 30)
            aug_img = aug_img.rotate(angle, resample=Image.BICUBIC, expand=False)

        # 3. Random Resized Crop (Zoom effect)
        if random.random() > 0.3:
            scale = random.uniform(0.75, 0.95)
            new_w, new_h = int(w * scale), int(h * scale)
            
            left = random.randint(0, w - new_w)
            top = random.randint(0, h - new_h)
            
            aug_img = aug_img.crop((left, top, left + new_w, top + new_h))
            aug_img = aug_img.resize((w, h), resample=Image.BICUBIC)

        # --- B. PIXEL TRANSFORMS ---

        # 4. Gaussian Blur
        if random.random() > 0.7:
            radius = random.uniform(1, 2)
            aug_img = aug_img.filter(ImageFilter.GaussianBlur(radius))

        # 5. Gaussian Noise
        if random.random() > 0.7:
            # Convert to numpy to add noise
            img_arr = np.array(aug_img).astype(float)
            noise = np.random.normal(0, 15, img_arr.shape)
            img_arr = np.clip(img_arr + noise, 0, 255).astype('uint8')
            aug_img = Image.fromarray(img_arr)

        # 6. Color Jitter (Brightness/Contrast/Saturation)
        if random.random() > 0.5:
            # Brightness
            enhancer = ImageEnhance.Brightness(aug_img)
            aug_img = enhancer.enhance(random.uniform(0.8, 1.2))
            
            # Contrast
            enhancer = ImageEnhance.Contrast(aug_img)
            aug_img = enhancer.enhance(random.uniform(0.9, 1.1))
            
            # Saturation
            enhancer = ImageEnhance.Color(aug_img)
            aug_img = enhancer.enhance(random.uniform(0.9, 1.1))

        # --- C. STRUCTURAL TRANSFORMS ---
        
        # 7. Edge Enhancement (Alternative to Sobel)
        if use_edge_enhance and random.random() > 0.8:
            # Convert to grayscale and find edges
            gray = aug_img.convert("L").filter(ImageFilter.FIND_EDGES)
            # Convert back to RGB
            edge_img = gray.convert("RGB")
            # Blend with original (50%)
            aug_img = Image.blend(aug_img, edge_img, alpha=0.5)

        synthetic_images.append(aug_img)

    return synthetic_images

In [230]:
import pandas as pd
from pathlib import Path
import uuid
from PIL import Image

def augment_rare_classes(df, images_dir, output_dir, target_count=30):
    """
    Identifies rare classes and generates synthetic images to reach target_count.
    Skips generation if sufficient synthetic data already exists in output_dir.
    
    Args:
        df (pd.DataFrame): Original training dataframe.
        images_dir (str/Path): Directory containing original images.
        output_dir (str/Path): Directory to save synthetic images and metadata.
        target_count (int): Minimum samples desired per class.
        
    Returns:
        pd.DataFrame: A DataFrame containing ALL synthetic data (previously existing + newly generated).
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # 1. Load existing synthetic metadata if it exists
    metadata_path = output_path / "synthetic_metadata.csv"
    if metadata_path.exists():
        try:
            existing_syn_df = pd.read_csv(metadata_path)
            print(f"Loaded metadata for {len(existing_syn_df)} existing synthetic samples.")
        except Exception as e:
            print(f"Warning: Could not read existing metadata ({e}). Starting fresh.")
            existing_syn_df = pd.DataFrame()
    else:
        existing_syn_df = pd.DataFrame()

    # Identify rare classes based on ORIGINAL data
    counts = df['condition'].value_counts()
    rare_classes = counts[counts < target_count].index.tolist()
    
    new_rows = []
    
    print(f"Checking {len(rare_classes)} rare classes (Target: {target_count})...")

    for condition in rare_classes:
        # Count original samples
        original_count = len(df[df['condition'] == condition])
        
        # Count existing synthetic samples (if any)
        if not existing_syn_df.empty:
            syn_count = len(existing_syn_df[existing_syn_df['condition'] == condition])
        else:
            syn_count = 0
            
        # Calculate how many MORE we need
        total_current = original_count + syn_count
        needed = target_count - total_current
        
        if needed <= 0:
            # print(f"  ‚úì {condition}: Has {total_current} (Skipping)")
            continue
            
        print(f"  - {condition}: Generating {needed} new images (Original: {original_count}, Existing Syn: {syn_count})...")
        
        # Get existing samples to augment
        existing_samples = df[df['condition'] == condition]
        if existing_samples.empty: continue

        generated_count = 0
        while generated_count < needed:
            # Pick random source image
            source_row = existing_samples.sample(1).iloc[0]
            
            img_path = Path(images_dir) / source_row['image_path']
            try:
                original_img = Image.open(img_path).convert("RGB")
            except Exception as e:
                continue

            # Generate synthetic version
            # Assuming generate_synthetic_data_pil is defined as before
            synthetic_imgs = generate_synthetic_data_pil(original_img, num_samples=1)
            syn_img = synthetic_imgs[0]
            
            # --- METADATA ---
            new_id = str(uuid.uuid4())
            
            # Sanitize filename
            safe_condition_name = condition.replace(' ', '_').replace('/', '_').replace('\\', '_')
            new_filename = f"syn_{safe_condition_name}_{new_id[:8]}.png"
            save_path = output_path / new_filename
            
            # Save Image
            syn_img.save(save_path)
            
            # Create Metadata Row
            new_row = source_row.copy()
            new_row['image_id'] = new_id
            new_row['image_path'] = new_filename # Just filename, relative to output_dir
            new_row['split'] = 'train'
            
            # Add helper flag (optional)
            new_row['is_synthetic'] = True
            
            new_rows.append(new_row)
            generated_count += 1

    # 2. Combine and Save
    if new_rows:
        new_df = pd.DataFrame(new_rows)
        if not existing_syn_df.empty:
            # Align columns before concat to prevent errors
            # (Optional safety step if schemas drift)
            final_syn_df = pd.concat([existing_syn_df, new_df], ignore_index=True)
        else:
            final_syn_df = new_df
            
        # Save updated metadata
        final_syn_df.to_csv(metadata_path, index=False)
        print(f"Saved updated metadata with {len(new_rows)} new samples to {metadata_path}")
        return final_syn_df
    else:
        print("No new images generated.")
        return existing_syn_df

In [231]:
# ==============================================================================
#  IMPROVED DATA LOADING SECTION (Replaces original loading logic)
# ==============================================================================
print("\n" + "="*80)
print("Loading and Processing Data")
print("="*80)

# 1. Load raw CSV
df = pd.read_csv(config.metadata_file)
print(f"Raw data loaded: {len(df)} samples")

# 2. Convert to list of dictionaries for processing
all_data = df.to_dict('records')

# 3. Perform Stratified Split (Fixes missing classes in Val/Test)
train_raw, val_data, test_data = stratified_split_data(
    all_data, 
    val_ratio=0.15, 
    test_ratio=0.15
)

# 4. Balance Training Data (Fixes class imbalance)
# Target 30 samples ensures the model sees rare classes ~1 per batch on average
train_balanced = balance_training_data(train_raw, target_min_samples=30)

# 5. Convert back to DataFrames for compatibility
print(f"Training set size before adding synthetic data: {len(train_df)}")
train_df = pd.DataFrame(train_balanced)
synthetic_dir = "./scin_synthetic"
synthetic_df = augment_rare_classes(
    train_df, 
    config.images_dir, 
    synthetic_dir, 
    target_count=30
)

# Merge
train_df = pd.concat([train_df, synthetic_df], ignore_index=True)
print(f"New training set size after adding synthetic data: {len(train_df)}")

val_df = pd.DataFrame(val_data)
test_df = pd.DataFrame(test_data)

In [232]:
import pandas as pd
import numpy as np
from collections import Counter

def print_comprehensive_report(df, train_df, val_df, test_df):
    """
    Print a comprehensive data analysis report for SigLip fine-tuning
    """
    
    print("="*80)
    print(" DATASET ANALYSIS FOR SIGLIP FINE-TUNING")
    print("="*80)
    
    # Overall statistics
    print("\n1. OVERALL DATASET STATISTICS")
    print("-" * 80)
    print(f"Total samples: {len(df):,}")
    print(f"  - Training:   {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
    print(f"  - Validation: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
    print(f"  - Test:       {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")
    print(f"\nUnique conditions: {df['condition'].nunique()}")
    print(f"Unique coarse categories: {df['coarse_category'].nunique()}")
    
    # Label distribution by split
    print("\n2. LABEL DISTRIBUTION BY SPLIT")
    print("-" * 80)
    
    # Get all unique conditions
    all_conditions = df['condition'].unique()
    
    # Create distribution table
    distribution_data = []
    for condition in sorted(all_conditions):
        train_count = len(train_df[train_df['condition'] == condition])
        val_count = len(val_df[val_df['condition'] == condition])
        test_count = len(test_df[test_df['condition'] == condition])
        total = train_count + val_count + test_count
        
        distribution_data.append({
            'Condition': condition,
            'Train': train_count,
            'Val': val_count,
            'Test': test_count,
            'Total': total
        })
    
    dist_df = pd.DataFrame(distribution_data).sort_values('Total', ascending=False)
    
    # Print top 15 most common conditions
    print("\nTop 15 Most Common Conditions:")
    print(f"{'Condition':<40} {'Train':>8} {'Val':>8} {'Test':>8} {'Total':>8}")
    print("-" * 80)
    for _, row in dist_df.head(15).iterrows():
        print(f"{row['Condition']:<40} {row['Train']:>8} {row['Val']:>8} {row['Test']:>8} {row['Total']:>8}")
    
    # Print conditions with imbalanced splits (potential issues)
    print("\n3. POTENTIAL TRAINING ISSUES")
    print("-" * 80)
    
    # Classes with very few samples
    rare_classes = dist_df[dist_df['Total'] < 10]
    print(f"\n‚ö†Ô∏è  Conditions with <10 total samples: {len(rare_classes)}")
    if len(rare_classes) > 0:
        print("(These may cause overfitting or poor generalization)")
        for _, row in rare_classes.iterrows():
            print(f"  - {row['Condition']:<40} (Total: {row['Total']})")
    
    # Classes missing from validation or test
    missing_val = dist_df[dist_df['Val'] == 0]
    missing_test = dist_df[dist_df['Test'] == 0]
    
    print(f"\n‚ö†Ô∏è  Conditions missing from validation set: {len(missing_val)}")
    if len(missing_val) > 0 and len(missing_val) <= 10:
        for _, row in missing_val.iterrows():
            print(f"  - {row['Condition']}")
    
    print(f"\n‚ö†Ô∏è  Conditions missing from test set: {len(missing_test)}")
    if len(missing_test) > 0 and len(missing_test) <= 10:
        for _, row in missing_test.iterrows():
            print(f"  - {row['Condition']}")
    
    # Multi-label analysis
    print("\n4. MULTI-LABEL ANALYSIS")
    print("-" * 80)
    
    def count_labels(df_subset):
        """Count images by number of labels"""
        import ast
        label_counts = []
        for all_conds in df_subset['all_conditions']:
            if pd.notna(all_conds):
                try:
                    conds = ast.literal_eval(all_conds) if isinstance(all_conds, str) else all_conds
                    label_counts.append(len(conds))
                except:
                    label_counts.append(1)
            else:
                label_counts.append(1)
        return Counter(label_counts)
    
    train_label_dist = count_labels(train_df)
    val_label_dist = count_labels(val_df)
    test_label_dist = count_labels(test_df)
    
    print(f"{'Labels per Image':<20} {'Train':>10} {'Val':>10} {'Test':>10}")
    print("-" * 80)
    all_label_counts = sorted(set(list(train_label_dist.keys()) + list(val_label_dist.keys()) + list(test_label_dist.keys())))
    for num_labels in all_label_counts:
        print(f"{num_labels} label(s): {train_label_dist.get(num_labels, 0):>10} {val_label_dist.get(num_labels, 0):>10} {test_label_dist.get(num_labels, 0):>10}")
    
    # Calculate multi-label percentage
    multi_label_train = sum(count for labels, count in train_label_dist.items() if labels > 1)
    multi_label_pct = multi_label_train / len(train_df) * 100 if len(train_df) > 0 else 0
    print(f"\nMulti-label samples in training: {multi_label_train:,} ({multi_label_pct:.1f}%)")
    
    # Coarse category distribution
    print("\n5. COARSE CATEGORY DISTRIBUTION")
    print("-" * 80)
    print(f"{'Category':<40} {'Train':>8} {'Val':>8} {'Test':>8} {'Total':>8}")
    print("-" * 80)
    
    for category in sorted(df['coarse_category'].unique()):
        train_count = len(train_df[train_df['coarse_category'] == category])
        val_count = len(val_df[val_df['coarse_category'] == category])
        test_count = len(test_df[test_df['coarse_category'] == category])
        total = train_count + val_count + test_count
        print(f"{category:<40} {train_count:>8} {val_count:>8} {test_count:>8} {total:>8}")
    
    # Class imbalance metrics
    print("\n6. CLASS IMBALANCE METRICS")
    print("-" * 80)
    
    train_condition_counts = train_df['condition'].value_counts()
    max_samples = train_condition_counts.max()
    min_samples = train_condition_counts.min()
    imbalance_ratio = max_samples / min_samples if min_samples > 0 else float('inf')
    
    print(f"Training set class imbalance:")
    print(f"  - Most common class:  {train_condition_counts.index[0]:<40} ({max_samples} samples)")
    print(f"  - Least common class: {train_condition_counts.index[-1]:<40} ({min_samples} samples)")
    print(f"  - Imbalance ratio:    {imbalance_ratio:.1f}:1")
    print(f"\nMedian samples per class: {train_condition_counts.median():.0f}")
    print(f"Mean samples per class:   {train_condition_counts.mean():.1f}")
    
    # Recommendations
    print("\n7. RECOMMENDATIONS FOR SIGLIP FINE-TUNING")
    print("="*80)
    
    print("\nüìä Data Considerations:")
    if imbalance_ratio > 100:
        print("  ‚ö†Ô∏è  SEVERE class imbalance detected (>100:1)")
        print("     ‚Üí Consider: class weights, focal loss, or oversampling rare classes")
    elif imbalance_ratio > 20:
        print("  ‚ö†Ô∏è  Significant class imbalance (>20:1)")
        print("     ‚Üí Consider: balanced sampling or class weights")
    
    if multi_label_pct > 20:
        print(f"\n  ‚ÑπÔ∏è  High multi-label percentage ({multi_label_pct:.1f}%)")
        print("     ‚Üí Use SIGMOID loss (not contrastive) for multi-label classification")
    else:
        print(f"\n  ‚ÑπÔ∏è  Low multi-label percentage ({multi_label_pct:.1f}%)")
        print("     ‚Üí Either SIGMOID or CONTRASTIVE loss can work")
    
    if len(rare_classes) > 0:
        print(f"\n  ‚ö†Ô∏è  {len(rare_classes)} classes with <10 samples")
        print("     ‚Üí These may not learn well; consider grouping into coarse categories")
    
    print("\n" + "="*80)

# Call the report function
print_comprehensive_report(df, train_df, val_df, test_df)

In [233]:
class CreateDataset(Dataset):
    """
    Dataset for fine-grained contrastive learning with SigLIP.
    Prioritizes verbose descriptions and robust image handling.
    """

    def __init__(self, df: pd.DataFrame, images_dir: Path, processor, mode='train'):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.processor = processor
        self.mode = mode
        
        # --- NEW: Load Synthetic Visual Descriptions ---
        self.visual_descriptions = {}
        try:
            with open(config.visual_descriptions_path, 'r') as f:
                self.visual_descriptions = json.load(f)
            print(f"[{mode}] Loaded {len(self.visual_descriptions)} visual descriptions.")
        except FileNotFoundError:
            print(f"[{mode}] Warning: Visual descriptions file not found. Using basic labels.")
    

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load Image
        image_path = self.images_dir / row['image_path']
        try:
            image = Image.open(image_path).convert('RGB')
        except:
            image = Image.new('RGB', (224, 224), color='black')

        # --- NEW TEXT GENERATION ---
        condition = row['condition']
        
        # 1. Try to get the RICH visual description (e.g. "Red, scaly plaques...")
        visual_desc = self.visual_descriptions.get(condition, "")
        
        # 2. Fallback to dataset description
        raw_desc = str(row.get('description', ''))
        if raw_desc.lower() == 'nan': raw_desc = ""
        
        # Construct the prompt
        if self.mode == 'train':
            # Training: Give the model the rich visual features
            if visual_desc:
                text = f"A photo of {condition}: {visual_desc}"
            elif raw_desc:
                text = f"A photo of {condition}: {raw_desc}"
            else:
                text = f"A dermatological photo of {condition}"
        else:
            # Testing: Keep it standard but descriptive
            text = f"A photo of {condition}: {visual_desc if visual_desc else raw_desc}"

        # Label ID for Multi-Positive Loss
        label_id = hash(condition)

        inputs = self.processor(
            text=[text],
            images=image,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        )
        
        # Remove batch dimension
        for key in inputs:
            inputs[key] = inputs[key].squeeze(0)

        # Remove batch dimension (1, C, H, W) -> (C, H, W)
        return inputs

In [None]:
class SigLipModel(nn.Module):
    """SigLIP model for contrastive learning."""

    def __init__(self, model_name: str, temperature: float = 0.07):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.temperature = temperature
        self.logit_bias = nn.Parameter(torch.zeros(1))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / temperature))

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True
        )

        image_embeds = outputs.vision_model_output.pooler_output
        text_embeds = outputs.text_model_output.pooler_output

        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.t()
        logits_per_text = logits_per_image.t()

        batch_size = image_embeds.shape[0]
        
        # --- Switchable Loss Calculation ---
        if config.loss_type == "contrastive":
            labels = torch.arange(batch_size, device=image_embeds.device)
            loss_i = F.cross_entropy(logits_per_image, labels)
            loss_t = F.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2
        elif config.loss_type == "sigmoid":
            labels = torch.eye(batch_size, device=image_embeds.device)
            loss_i = F.binary_cross_entropy_with_logits(logits_per_image, labels)
            loss_t = F.binary_cross_entropy_with_logits(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2.0
        else:
            raise ValueError(f"Unknown loss_type: {config.loss_type}")

        return {
            'loss': loss,
            'logits_per_image': logits_per_image,
            'logits_per_text': logits_per_text,
            'image_embeds': image_embeds,
            'text_embeds': text_embeds,
        }

In [235]:
print("\n" + "="*80)
print("Loading Model and Creating Datasets")
print("="*80)

processor = AutoProcessor.from_pretrained(config.model_name)
model = SigLipModel(config.model_name, config.temperature)
model.to(config.device)

print(f"‚úì Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

# Create datasets
# Note: Ensure your df has 'condition' and 'description' columns
train_dataset = CreateDataset(train_df, config.images_dir, processor, mode='train')
val_dataset = CreateDataset(val_df, config.images_dir, processor, mode='val')
test_dataset = CreateDataset(test_df, config.images_dir, processor, mode='test')

print(f"‚úì Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

# Show sample texts
print("\nSample fine-grained texts:")
for i in range(3):
    sample_text = train_df.iloc[i]['description']
    condition = train_df.iloc[i]['condition']
    print(f"  [{condition}] {sample_text[:80]}...")

In [240]:
# Custom trainer
class SigLipTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Extract labels but IGNORE them in the matrix construction
        label_ids = inputs.pop("label_ids", None) # Keep this to pop the tensor out of inputs
        
        outputs = model(**inputs)
        
        # Logits extraction (robust)
        if isinstance(outputs, dict):
            logits = outputs["logits_per_image"]
        else:
            logits = outputs.logits_per_image
            
        batch_size = logits.shape[0]
        device = logits.device
        
        # --- BIAS INCORPORATION FIX ---
        # 1. Retrieve the logit_bias parameter safely.
        #    We use getattr to safely handle models that might not have this parameter.
        logit_bias = getattr(model, 'logit_bias', None)
        
        if logit_bias is not None:
            # 2. Apply the bias to the logits
            #    The bias is broadcasted across the entire [B, B] logit matrix.
            logits = logits + logit_bias
        # ----------------------------
        
        # --- CRITICAL FIX: DIAGONAL-ONLY LABELS ---
        # We must ignore label_ids and enforce the standard SigLIP diagonal.
        # This stops the conflicting gradients (Image A vs Image B of same class).
        labels = torch.eye(batch_size, device=device, dtype=logits.dtype)
        
        # --- POSITIVE WEIGHT CALCULATION ---
        # The weight remains, forcing the model to value the diagonal pair highly.
        pos_weight_value = float(batch_size - 1) if batch_size > 1 else 1.0
        pos_weight = torch.tensor([pos_weight_value], device=device, dtype=logits.dtype)
        
        # --- WEIGHTED SIGMOID LOSS ---
        # This is now standard SigLIP loss, but with heavy positive weighting.
        loss = F.binary_cross_entropy_with_logits(
            logits, 
            labels, 
            pos_weight=pos_weight
        )
        
        return (loss, outputs) if return_outputs else loss


    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Custom prediction step for contrastive learning."""
        with torch.no_grad():
            outputs = model(**inputs)
            loss = outputs['loss']

        if prediction_loss_only:
            return (loss, None, None)

        return (loss, outputs['logits_per_image'].cpu(), None)

## 1. Loss Analysis Diagnosis Table

| Training Loss | Validation Loss | Diagnosis | Action Required |
|---------------|-----------------|-----------|-----------------|
| High | High | **Underfitting** | Increase model complexity, train longer |
| Low | High (Increasing) | **Overfitting** | Add regularization (Dropout/L2), get more data, Early Stopping |
| Low | Low | **Good Fit** | Save model weights |
| High | Low | **Data Issue** | Check regularization, data augmentation, or data split |

## 2. Loss Value Interpretation Table

| Value Range | Meaning | Status |
|-------------|---------|--------|
| > 0.60 | **High** | Underfitting / Starting: Model is guessing |
| 0.40 - 0.60 | **Medium** | Learning: Model is picking up patterns |
| < 0.30 | **Low** | Good Fit: Model is confident in its predictions |

## 3. Additional Metrics for Multi-Label Classification

| Metric | Good Range | Interpretation |
|--------|------------|----------------|
| **Per-class Accuracy** | > 0.70 | Percentage of correct predictions per condition |
| **Macro F1-Score** | > 0.60 | Average F1 across all classes (handles imbalance) |
| **Hamming Loss** | < 0.20 | Fraction of wrong labels (lower is better) |
| **Subset Accuracy** | > 0.40 | Exact match of all labels (strict metric) |

## 4. SigLip Fine-tuning Specific Guidelines

| Training Stage | Expected Train Loss | Expected Val Loss | What to Look For |
|----------------|-------------------|-------------------|------------------|
| **Early (Epoch 1-3)** | 0.80 - 1.20 | 0.70 - 1.10 | Both should decrease rapidly |
| **Mid (Epoch 4-8)** | 0.40 - 0.70 | 0.45 - 0.75 | Val loss should track train loss closely |
| **Late (Epoch 9-15)** | 0.20 - 0.45 | 0.30 - 0.55 | Gap should be small (<0.10) |
| **Converged** | < 0.30 | < 0.40 | Both stable, gap <0.10 means good generalization |

## 5. Warning Signs During Training

| Observation | Problem | Solution |
|-------------|---------|----------|
| Val loss increases after epoch 5 | **Overfitting** | Enable early stopping, increase dropout |
| Both losses stuck > 0.70 | **Underfitting** | Decrease LoRA rank, increase learning rate |
| Val loss oscillates wildly | **Unstable training** | Reduce learning rate, increase batch size |
| Train loss = 0.0, Val loss > 1.0 | **Severe overfitting** | Add strong regularization, reduce model capacity |
| Gap between train/val > 0.20 | **Poor generalization** | More data augmentation, check data leakage |

## 6. Recommended Actions Based on Your Dataset

Given your characteristics:
- **79.4% multi-label samples** ‚Üí Focus on Hamming Loss and Macro F1
- **758:1 class imbalance** ‚Üí Monitor per-class metrics, not just overall accuracy
- **9 rare classes (<10 samples)** ‚Üí These will have high loss; group to coarse categories

| Target Metrics (Epoch 15) | Realistic Goal | Stretch Goal |
|---------------------------|----------------|--------------|
| **Validation Loss** | < 0.40 | < 0.30 |
| **Macro F1-Score** | > 0.55 | > 0.65 |
| **Top-3 Accuracy** | > 0.80 | > 0.90 |
| **Hamming Loss** | < 0.25 | < 0.15 |

In [237]:

def compute_metrics(eval_pred):
    # predictions (logits) are tuple: (logits_per_image, logits_per_text)
    logits = eval_pred.predictions 
    # Robustly handle tuple unpacking
    if isinstance(logits, tuple):
        logits = logits[0]
    
    # For evaluation, we treat the diagonal as the ground truth class
    # (retrieval task: given image i, find text i)
    batch_size = logits.shape[0]
    labels = np.arange(batch_size)
    
    preds = np.argmax(logits, axis=1)
    
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=0)
    
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Training arguments
training_args = TrainingArguments(
    output_dir=str(config.output_dir),
    num_train_epochs=config.num_epochs,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    learning_rate=config.learning_rate,
    weight_decay=config.weight_decay,
    warmup_steps=config.warmup_steps,
    fp16=config.fp16,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    save_total_limit=2, #Was 3
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    report_to="none",
    prediction_loss_only=False, # CRITICAL for compute_metrics to run!
)

# Apply LoRA
lora_config = LoraConfig(
        r=config.lora_alpha,
        lora_alpha=config.lora_alpha,
        target_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
        lora_dropout=0.1,
        bias="none",
    )

# Apply LoRA if specified
if config.finetune_mode == "LoRA":
    print("\nApplying LoRA fine-tuning...")
    model = get_peft_model(model, lora_config)  

# 1. Create the callback instance
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)

trainer = SigLipTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[early_stopping_callback],
    compute_metrics=compute_metrics,
)


print("\n" + "="*80)
print("STARTING FINE-GRAINED TRAINING")
print("="*80)
print("Training with fine-grained condition labels...")
print("Focus on discriminating between 66 specific conditions")
print()

train_result = trainer.train()

print(f"\n‚úì Training complete! Final loss: {train_result.training_loss:.4f}")

# Evaluate
test_results = trainer.evaluate(test_dataset)
print(f"‚úì Test loss: {test_results['eval_loss']:.4f}")

# Save
final_path = config.output_dir / 'final_model'
trainer.save_model(str(final_path))
processor.save_pretrained(str(final_path))
print(f"‚úì Model saved to: {final_path}")



Epoch,Training Loss,Validation Loss
1,1.3354,0.741521
2,1.3341,0.734111
3,1.3291,0.726259
4,1.3262,0.719635
5,1.3217,0.714725
6,1.3177,0.711181
7,1.3121,0.709177
8,1.3077,0.707673
9,1.3014,0.706846
10,1.2974,0.706235


In [238]:
def test_retrieval(model, processor, test_df, images_dir, n_samples=20, debug_limit=5):
    """
    Test fine-grained retrieval capabilities with detailed debugging.
    
    Args:
        debug_limit (int): Number of samples to print detailed logs for.
    """
    model.eval()

    # Ensure we don't sample more than available
    n_samples = min(n_samples, len(test_df))
    sample_indices = np.random.choice(len(test_df), n_samples, replace=False)

    accuracies = []
    condition_accuracies = {}
    debug_counter = 0

    print("\n" + "="*80)
    print(f"New Testing retrieval on {n_samples} samples...")
    print(f"Detailed debug info will be printed for the first {debug_limit} samples.\n")
    print("\n" + "="*80)
    
    for idx in tqdm(sample_indices, desc="Testing retrieval"):
        row = test_df.iloc[idx]

        # 1. Load Image
        image_path = images_dir / row['image_path']
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            continue

        # 2. Prepare Candidates
        desc = str(row['description'])
        if desc.lower() == 'nan' or not desc:
            correct_text = f"A dermatological image of {row['condition']}"
        else:
            correct_text = f"A dermatological image of {row['condition']}: {desc}"
            
        correct_condition = row['condition']

        # Get distractors
        other_conditions = test_df[test_df['condition'] != correct_condition]['condition'].unique()
        
        n_distractors = min(9, len(other_conditions))
        if n_distractors < 1: continue

        distractor_conditions = np.random.choice(other_conditions, n_distractors, replace=False)

        distractor_texts = []
        for dist_cond in distractor_conditions:
            dist_rows = test_df[test_df['condition'] == dist_cond]
            if len(dist_rows) > 0:
                dist_row = dist_rows.sample(1).iloc[0]
                # Clean description logic for distractors too
                d_desc = str(dist_row['description'])
                if d_desc.lower() == 'nan' or not d_desc:
                     d_text = f"A dermatological image of {dist_row['condition']}"
                else:
                     d_text = f"A dermatological image of {dist_row['condition']}: {d_desc}"
                distractor_texts.append(d_text)

        # Index 0 is ALWAYS the correct answer in this list
        all_texts = [correct_text] + distractor_texts

        # 3. Process Inputs
        inputs = processor(
            text=all_texts,
            images=[image] * len(all_texts),
            padding=True, 
            truncation=True, 
            max_length=64, 
            return_tensors="pt"
        ).to(config.device)

        # 4. Get Predictions
        with torch.no_grad():
            outputs = model(**inputs)
            
            if isinstance(outputs, dict):
                logits_per_image = outputs['logits_per_image']
            else:
                logits_per_image = outputs.logits_per_image
            
            # Get logits for the single image against all texts
            logits = logits_per_image[0]
            
            # Apply Sigmoid to get readable probabilities (0.0 to 1.0)
            probs = torch.sigmoid(logits)

        # 5. Check accuracy
        predicted_idx = probs.argmax().item()
        correct = (predicted_idx == 0)
        accuracies.append(correct)

        # Track per-condition accuracy
        if correct_condition not in condition_accuracies:
            condition_accuracies[correct_condition] = []
        condition_accuracies[correct_condition].append(correct)

        # --- DEBUGGING OUTPUT ---
        if debug_counter < debug_limit:
            print(f"\n{'='*80}")
            print(f"DEBUG SAMPLE {debug_counter + 1} | Condition: {correct_condition}")
            print(f"Image Path: {image_path}")
            print(f"Result: {'‚úÖ CORRECT' if correct else '‚ùå INCORRECT'}")
            print(f"{'-'*40}")
            
            # Get Top 5 predictions for this image
            top_k = min(5, len(all_texts))
            top_probs, top_indices = torch.topk(probs, top_k)
            
            print(f"Model's Top {top_k} Guesses:")
            for score, idx in zip(top_probs, top_indices):
                idx = idx.item()
                score = score.item()
                
                is_ground_truth = (idx == 0)
                marker = "üëà TRUE LABEL" if is_ground_truth else ""
                
                # Truncate text for cleaner printing
                display_text = all_texts[idx][:100] + "..." if len(all_texts[idx]) > 100 else all_texts[idx]
                
                print(f"  [{score:.4f}] {display_text} {marker}")

            # If correct answer wasn't in top 5, print it explicitly
            if 0 not in top_indices.tolist():
                true_score = probs[0].item()
                print(f"  ...\n  [{true_score:.4f}] {correct_text} üëà TRUE LABEL (Ranked > {top_k})")
                
            print(f"{'='*80}\n")
            debug_counter += 1

    if not accuracies:
        return 0.0

    accuracy = np.mean(accuracies)
    print(f"\n‚úì Fine-grained retrieval accuracy: {accuracy:.2%}")
    
    return accuracy

In [239]:
# Test retrieval
print("\n" + "="*80)
print("Testing Fine-Grained Retrieval")
print("="*80)

retrieval_acc = test_retrieval(model, processor, test_df, config.images_dir)

# Save summary
summary = {
    'approach': 'fine_grained',
    'num_conditions': int(df['condition'].nunique()),
    'train_samples': len(train_df),
    'final_loss': float(train_result.training_loss),
    'test_loss': float(test_results['eval_loss']),
    'retrieval_accuracy': float(retrieval_acc),
}

with open(config.output_dir / 'training_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\n" + "="*80)
print("FINE-GRAINED TRAINING COMPLETE")
print("="*80)
print("\nAdvantages of fine-grained approach:")
print("  ‚Ä¢ Direct discrimination between 66 specific conditions")
print("  ‚Ä¢ No hierarchical complexity")
print("  ‚Ä¢ Focused on exact condition matching")
print("  ‚Ä¢ Should improve retrieval accuracy")
print(f"\nFinal metrics:")
print(f"  Test loss: {test_results['eval_loss']:.4f}")
print(f"  Retrieval accuracy: {retrieval_acc:.2%}")

Testing retrieval: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:01<00:00, 10.22it/s]
