In [1]:
# Cell 1: Imports and Configuration
import torch
import os
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModel,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import requests
from io import BytesIO
from tqdm import tqdm
import pickle
from pathlib import Path

# Configuration
CONFIG = {
    'MODEL_ID': "google/siglip-base-patch16-224",
    'OUTPUT_DIR': "./siglip-scin-lora",
    'DATA_DIR': "./data/scin_cache",  # Local cache directory
    'BATCH_SIZE': 16,
    'LEARNING_RATE': 1e-4,
    'LORA_RANK': 16,
    'LORA_ALPHA': 16,
    'MAX_STEPS': 500,
    'LOSS_TYPE': "sigmoid",  # or "contrastive"
    'N_VAL_SAMPLES': 1000,
    'N_TRAIN_SAMPLES': 5000,  # Set to None to use all available
}

# Create directories
os.makedirs(CONFIG['OUTPUT_DIR'], exist_ok=True)
os.makedirs(CONFIG['DATA_DIR'], exist_ok=True)

# Device setup
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
print(f"Configuration: {CONFIG}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Configuration: {'MODEL_ID': 'google/siglip-base-patch16-224', 'OUTPUT_DIR': './siglip-scin-lora', 'DATA_DIR': './data/scin_cache', 'BATCH_SIZE': 16, 'LEARNING_RATE': 0.0001, 'LORA_RANK': 16, 'LORA_ALPHA': 16, 'MAX_STEPS': 500, 'LOSS_TYPE': 'sigmoid', 'N_VAL_SAMPLES': 1000, 'N_TRAIN_SAMPLES': 5000}


In [2]:
# Cell 2: Data Loading and Caching Functions

def download_and_cache_dataset(n_train=5000, n_val=1000, force_redownload=False):
    """
    Download dataset from HuggingFace and cache locally.
    
    Args:
        n_train: Number of training samples (None for all)
        n_val: Number of validation samples
        force_redownload: Force redownload even if cache exists
    
    Returns:
        tuple: (train_data_list, val_data_list)
    """
    train_cache = Path(CONFIG['DATA_DIR']) / f"train_{n_train}.pkl"
    val_cache = Path(CONFIG['DATA_DIR']) / f"val_{n_val}.pkl"
    
    # Check if cache exists
    if not force_redownload and train_cache.exists() and val_cache.exists():
        print(f"Loading cached dataset from {CONFIG['DATA_DIR']}")
        with open(train_cache, 'rb') as f:
            train_data = pickle.load(f)
        with open(val_cache, 'rb') as f:
            val_data = pickle.load(f)
        print(f"Loaded {len(train_data)} training samples and {len(val_data)} validation samples from cache")
        return train_data, val_data
    
    # Download from HuggingFace
    print("Downloading dataset from HuggingFace...")
    try:
        base_iterable = load_dataset("google/scin", split="train", streaming=True)
    except Exception as e:
        print(f"Failed to load dataset. Error: {e}")
        print("Please ensure you have an internet connection and have accepted 'google/scin' terms if any.")
        raise
    
    image_columns = ["image_1_path", "image_2_path", "image_3_path"]
    
    # Load validation data
    print(f"Loading {n_val} validation samples...")
    val_data = []
    for item in tqdm(base_iterable.take(n_val), total=n_val, desc="Loading val samples"):
        text = item.get("related_category")
        if not text or not isinstance(text, str):
            continue
        
        for img_col in image_columns:
            image = item.get(img_col)
            if image and isinstance(image, Image.Image):
                try:
                    val_data.append({"image": image.convert("RGB"), "text": text})
                    break
                except Exception as e:
                    print(f"Error converting image, skipping: {e}")
                    break
    
    # Load training data
    train_iterable = base_iterable.skip(n_val)
    print(f"Loading training samples (max: {n_train if n_train else 'all'})...")
    train_data = []
    
    if n_train:
        iterator = tqdm(train_iterable.take(n_train), total=n_train, desc="Loading train samples")
    else:
        iterator = tqdm(train_iterable, desc="Loading train samples")
    
    for item in iterator:
        text = item.get("related_category")
        if not text or not isinstance(text, str):
            continue
        
        for img_col in image_columns:
            image = item.get(img_col)
            if image and isinstance(image, Image.Image):
                try:
                    train_data.append({"image": image.convert("RGB"), "text": text})
                    break
                except Exception as e:
                    print(f"Error converting image, skipping: {e}")
                    break
        
        if n_train and len(train_data) >= n_train:
            break
    
    # Save to cache
    print(f"Saving {len(train_data)} training samples to cache...")
    with open(train_cache, 'wb') as f:
        pickle.dump(train_data, f)
    
    print(f"Saving {len(val_data)} validation samples to cache...")
    with open(val_cache, 'wb') as f:
        pickle.dump(val_data, f)
    
    print(f"Dataset cached to {CONFIG['DATA_DIR']}")
    return train_data, val_data


def clear_cache():
    """Clear all cached dataset files."""
    cache_dir = Path(CONFIG['DATA_DIR'])
    if cache_dir.exists():
        for file in cache_dir.glob("*.pkl"):
            file.unlink()
            print(f"Deleted {file}")
        print("Cache cleared")
    else:
        print("No cache to clear")


# Load or download dataset
train_data, val_data = download_and_cache_dataset(
    n_train=CONFIG['N_TRAIN_SAMPLES'],
    n_val=CONFIG['N_VAL_SAMPLES'],
    force_redownload=False  # Set to True to force redownload
)

print(f"\nDataset loaded: {len(train_data)} train, {len(val_data)} val samples")

Downloading dataset from HuggingFace...


Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Loading 1000 validation samples...


Loading val samples: 100%|██████████| 1000/1000 [01:06<00:00, 15.03it/s]


Loading training samples (max: 5000)...


Loading train samples:  81%|████████  | 4033/5000 [04:32<01:05, 14.78it/s]


Saving 3032 training samples to cache...
Saving 747 validation samples to cache...
Dataset cached to ./data/scin_cache

Dataset loaded: 3032 train, 747 val samples


In [None]:
# Cell 3: Dataset Classes

class SCIN_Dataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for SCIN data.
    Works with pre-loaded data lists (from cache or download).
    """
    def __init__(self, data_list):
        print(f"Initializing SCIN_Dataset with {len(data_list)} samples.")
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch, processor):
    """
    Data collator with robust error handling.
    """
    processed_images = []
    processed_texts_input_ids = []
    skipped_count = 0

    for i, item in enumerate(batch):
        try:
            if item is None:
                skipped_count += 1
                continue

            img = item.get("image")
            txt = item.get("text")

            # Check for invalid content
            if img is None:
                skipped_count += 1
                continue
            if txt is None or txt.strip() == "":
                skipped_count += 1
                continue

            # Process valid items
            inputs = processor(
                text=[txt],
                images=[img],
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=64
            )

            processed_images.append(inputs["pixel_values"])
            processed_texts_input_ids.append(inputs["input_ids"])

        except Exception as e:
            print(f"WARNING (collate_fn): Skipping item {i} due to error: {e}")
            skipped_count += 1

    if not processed_images:
        if len(batch) > 0:
            print(f"ERROR: Entire batch was skipped! ({skipped_count} items failed)")
        return {}

    try:
        batch_pixel_values = torch.cat(processed_images, dim=0)
        batch_input_ids = torch.cat(processed_texts_input_ids, dim=0)

        return {
            "pixel_values": batch_pixel_values,
            "input_ids": batch_input_ids
        }

    except Exception as e:
        print(f"Error during final batch collation: {e}")
        return {}


# Create datasets
train_dataset = SCIN_Dataset(train_data)
val_dataset = SCIN_Dataset(val_data)

print(f"Datasets created: {len(train_dataset)} train, {len(val_dataset)} val")

In [None]:
# Cell 4: Model Loading

def load_models_and_processor(model_id, device):
    """
    Load processor and two model instances (base and tunable).
    
    Returns:
        tuple: (processor, base_model, model_to_tune)
    """
    print(f"Loading processor and models from: {model_id}")
    
    processor = AutoProcessor.from_pretrained(model_id)
    dtype = torch.float16 if device == "cuda" else torch.float32
    
    # Base model (for baseline evaluation)
    base_model = AutoModel.from_pretrained(
        model_id,
        torch_dtype=dtype
    ).to(device)
    
    # Model to fine-tune
    model_to_tune = AutoModel.from_pretrained(
        model_id,
        torch_dtype=dtype
    )
    
    print("Models and processor loaded successfully")
    return processor, base_model, model_to_tune


def apply_lora(model, rank=16, alpha=16):
    """
    Apply LoRA configuration to a model.
    
    Args:
        model: Base model
        rank: LoRA rank
        alpha: LoRA alpha
    
    Returns:
        Model with LoRA applied
    """
    print(f"Applying LoRA configuration (rank={rank}, alpha={alpha})...")
    
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
    )
    
    model = get_peft_model(model, lora_config)
    
    print("LoRA applied. Trainable parameters:")
    model.print_trainable_parameters()
    
    return model


# Load models
processor, base_model, model_to_tune = load_models_and_processor(
    CONFIG['MODEL_ID'], 
    device
)

# Apply LoRA to the tunable model
model_to_tune = apply_lora(
    model_to_tune,
    rank=CONFIG['LORA_RANK'],
    alpha=CONFIG['LORA_ALPHA']
)
model_to_tune = model_to_tune.to(device)

print(f"\nModels ready on device: {device}")

In [None]:
# Cell 5: Metrics and Loss Functions

def compute_metrics(eval_pred):
    """
    Calculate evaluation metrics.
    """
    logits = eval_pred.predictions
    predictions = np.argmax(logits, axis=1)
    true_labels = np.arange(len(predictions))
    
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='macro', zero_division=0
    )
    acc = accuracy_score(true_labels, predictions)
    
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def compute_loss_function(logits_per_image, logits_per_text, loss_type="sigmoid", device="cpu"):
    """
    Compute loss based on loss type.
    
    Args:
        logits_per_image: Image logits
        logits_per_text: Text logits
        loss_type: "contrastive" or "sigmoid"
        device: Device for computation
    
    Returns:
        Combined loss
    """
    batch_size = logits_per_image.shape[0]
    
    if batch_size <= 1:
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    if loss_type == "contrastive":
        labels = torch.arange(batch_size, device=device)
        loss_images = F.cross_entropy(logits_per_image, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss = (loss_images + loss_text) / 2.0
    elif loss_type == "sigmoid":
        labels = torch.eye(batch_size, device=device)
        loss_images = F.binary_cross_entropy_with_logits(logits_per_image, labels)
        loss_text = F.binary_cross_entropy_with_logits(logits_per_text, labels)
        loss = (loss_images + loss_text) / 2.0
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    
    return loss


print("Metrics and loss functions defined")

In [None]:
# Cell 6: Custom Trainer Class

class CustomTrainer(Trainer):
    """
    Custom Trainer with:
    1. Switchable loss (Contrastive or Sigmoid)
    2. Gradient accumulation for heatmap
    3. Proper evaluation with loss and metrics
    """
    def __init__(self, *args, loss_type="contrastive", **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        print(f"CustomTrainer initialized with loss_type: {self.loss_type}")
        
        # Gradient tracking for heatmap
        self.gradient_accumulator = defaultdict(float)
        self.step_count = 0
        
        # Batch stats
        self.successful_batches = 0
        self.skipped_batches_eval = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Training loss computation."""
        if not inputs or "pixel_values" not in inputs:
            dummy_loss = torch.tensor(0.0, device=model.device, requires_grad=True)
            return (dummy_loss, {}) if return_outputs else dummy_loss

        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text
        
        loss = compute_loss_function(
            logits_per_image,
            logits_per_text,
            loss_type=self.loss_type,
            device=model.device
        )

        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch):
        """Training step with gradient tracking."""
        loss = super().training_step(model, inputs, num_items_in_batch)

        if loss is not None:
            self.step_count += 1
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if param.grad is not None and param.requires_grad:
                        self.gradient_accumulator[name] += param.grad.norm().item()
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluation step with loss computation."""
        if not inputs or "pixel_values" not in inputs:
            self.skipped_batches_eval += 1
            return (None, None, None)
        
        with torch.no_grad():
            outputs = model(**inputs)
            
            logits_per_image = outputs.logits_per_image
            logits_per_text = outputs.logits_per_text
            batch_size = logits_per_image.shape[0]
            
            # Check for NaN/Inf
            if torch.isnan(logits_per_image).any() or torch.isinf(logits_per_image).any():
                print("WARNING: NaN or Inf detected in logits during eval.")
                self.skipped_batches_eval += 1
                return (None, None, None)
            
            # Compute loss
            loss = None
            if batch_size <= 1:
                self.skipped_batches_eval += 1
            else:
                loss = compute_loss_function(
                    logits_per_image,
                    logits_per_text,
                    loss_type=self.loss_type,
                    device=model.device
                )
                self.successful_batches += 1
        
        predictions = logits_per_image.cpu()
        return (loss, predictions, None)

    def _extract_layer_index(self, name_parts):
        """Extract layer index from parameter name."""
        for part in name_parts:
            if part.isdigit():
                return int(part)
        return None

    def _extract_component_name(self, name_parts):
        """Extract component name from parameter name."""
        name_str = ".".join(name_parts)
        if "lora_A" in name_str:
            if "q_proj" in name_str: return "LoRA A (Query)"
            if "v_proj" in name_str: return "LoRA A (Value)"
        elif "lora_B" in name_str:
            if "q_proj" in name_str: return "LoRA B (Query)"
            if "v_proj" in name_str: return "LoRA B (Value)"
        if "q_proj" in name_str: return "Query Proj"
        if "v_proj" in name_str: return "Value Proj"
        if "k_proj" in name_str: return "Key Proj"
        if "fc1" in name_str: return "MLP Layer 1"
        if "fc2" in name_str: return "MLP Layer 2"
        return None

    def _process_gradients_for_heatmap(self):
        """Process accumulated gradients for heatmap generation."""
        if self.step_count == 0:
            print("No training steps recorded. Skipping heatmap.")
            return None, None, []

        vision_data = defaultdict(lambda: defaultdict(float))
        text_data = defaultdict(lambda: defaultdict(float))
        skipped_params = []

        for name, avg_grad_norm in self.gradient_accumulator.items():
            avg_norm = avg_grad_norm / self.step_count
            parts = name.split('.')
            layer_idx = self._extract_layer_index(parts)
            component = self._extract_component_name(parts)
            
            if layer_idx is None or component is None:
                if "lora_" in name:
                    skipped_params.append(name)
                continue
            
            if "vision_model" in name:
                vision_data[layer_idx][component] = avg_norm
            elif "text_model" in name:
                text_data[layer_idx][component] = avg_norm
            else:
                if "lora_" in name:
                    skipped_params.append(name)

        vision_df = pd.DataFrame.from_dict(vision_data, orient='index').sort_index()
        text_df = pd.DataFrame.from_dict(text_data, orient='index').sort_index()
        return vision_df, text_df, skipped_params

    def plot_final_heatmap(self, save_path):
        """Generate and save gradient impact heatmap."""
        print("\nGenerating final gradient heatmaps...")
        vision_df, text_df, skipped = self._process_gradients_for_heatmap()
        
        if vision_df is None or (vision_df.empty and text_df.empty):
            print("No gradient data collected. Skipping heatmap file.")
            return
        
        if skipped:
            print(f"[WARN] Skipped {len(skipped)} LoRA params (couldn't parse name)")

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))
        vmin = 0.0
        vmax = max(
            vision_df.max().max() if not vision_df.empty else 0,
            text_df.max().max() if not text_df.empty else 0
        )
        if vmax == 0:
            vmax = 1.0

        if not vision_df.empty:
            sns.heatmap(vision_df, ax=ax1, cmap="magma", annot=True, fmt=".2e",
                        linewidths=.5, vmin=vmin, vmax=vmax)
            ax1.set_title("Vision Encoder Impact (Avg. Gradient Norm)", fontsize=16)
            ax1.set_ylabel("Layer Depth", fontsize=12)
            ax1.set_xlabel("Transformer Component (LoRA)", fontsize=12)
        else:
            ax1.text(0.5, 0.5, "No Vision Gradients Found", ha='center', va='center')
            ax1.set_title("Vision Encoder Impact", fontsize=16)

        if not text_df.empty:
            sns.heatmap(text_df, ax=ax2, cmap="magma", annot=True, fmt=".2e",
                        linewidths=.5, vmin=vmin, vmax=vmax)
            ax2.set_title("Text Encoder Impact (Avg. Gradient Norm)", fontsize=16)
            ax2.set_ylabel("Layer Depth", fontsize=12)
            ax2.set_xlabel("Transformer Component (LoRA)", fontsize=12)
        else:
            ax2.text(0.5, 0.5, "No Text Gradients Found", ha='center', va='center')
            ax2.set_title("Text Encoder Impact", fontsize=16)

        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
        print(f"Heatmap saved to: {save_path}")


print("CustomTrainer class defined")

In [None]:
# Cell 7: Training Setup

def create_training_args(config, device, debug_mode=False):
    """
    Create training arguments.
    
    Args:
        config: Configuration dictionary
        device: Device string
        debug_mode: If True, use minimal steps for quick testing
    
    Returns:
        TrainingArguments
    """
    use_fp16 = True if device == "cuda" else False
    
    if debug_mode:
        print("Creating DEBUG training arguments (quick run)...")
        args = TrainingArguments(
            output_dir=config['OUTPUT_DIR'],
            per_device_train_batch_size=config['BATCH_SIZE'],
            per_device_eval_batch_size=config['BATCH_SIZE'],
            max_steps=2,
            eval_steps=1,
            logging_steps=1,
            warmup_steps=1,
            weight_decay=0.01,
            learning_rate=config['LEARNING_RATE'],
            save_strategy="no",
            load_best_model_at_end=False,
            fp16=use_fp16,
            report_to="none",
            remove_unused_columns=False,
            prediction_loss_only=False,
        )
    else:
        print("Creating FULL training arguments...")
        args = TrainingArguments(
            output_dir=config['OUTPUT_DIR'],
            per_device_train_batch_size=config['BATCH_SIZE'],
            per_device_eval_batch_size=config['BATCH_SIZE'],
            max_steps=config['MAX_STEPS'],
            weight_decay=0.01,
            learning_rate=config['LEARNING_RATE'],
            warmup_steps=50,
            logging_steps=50,
            save_strategy="steps",
            save_steps=250,
            eval_strategy="steps",
            eval_steps=250,
            load_best_model_at_end=False,
            fp16=use_fp16,
            report_to="none",
            remove_unused_columns=False,
            prediction_loss_only=False,
        )
    
    return args


# Create training arguments
# Set debug_mode=True for quick testing, False for full training
training_args = create_training_args(CONFIG, device, debug_mode=True)

# Initialize trainer
trainer = CustomTrainer(
    model=model_to_tune,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=lambda data: collate_fn(data, processor),
    loss_type=CONFIG['LOSS_TYPE'],
    compute_metrics=compute_metrics
)

print("Trainer initialized and ready")
print(f"Loss type: {CONFIG['LOSS_TYPE']}")
print(f"Training steps: {training_args.max_steps}")

In [None]:
# Cell 8: Run Baseline Evaluation

print("="*60)
print("RUNNING BASELINE EVALUATION (BEFORE FINE-TUNING)")
print("="*60)

# Temporarily swap to base model for baseline
trainer.model = base_model
baseline_metrics = trainer.evaluate()

print("\nBaseline Evaluation Metrics:")
for key, value in baseline_metrics.items():
    print(f"  {key}: {value:.4f}" if isinstance(value, float) else f"  {key}: {value}")

# Swap back to the model we want to tune
trainer.model = model_to_tune.to(device)

print("\nBaseline evaluation complete. Ready for fine-tuning.")

In [None]:
# Cell 9: Run Fine-Tuning

print("="*60)
print("STARTING FINE-TUNING")
print("="*60)

# Train the model
trainer.train()

print("\n" + "="*60)
print("FINE-TUNING COMPLETE")
print("="*60)

In [None]:
# Cell 10: Run Final Evaluation

print("="*60)
print("RUNNING FINAL EVALUATION (AFTER FINE-TUNING)")
print("="*60)

final_metrics = trainer.evaluate()

print("\nFinal Evaluation Metrics:")
for key, value in final_metrics.items():
    print(f"  {key}: {value:.4f}" if isinstance(value, float) else f"  {key}: {value}")

print("\n" + "="*60)
print("COMPARISON: BASELINE vs FINE-TUNED")
print("="*60)

comparison_keys = ['eval_loss', 'eval_accuracy', 'eval_precision', 'eval_recall', 'eval_f1']
for key in comparison_keys:
    if key in baseline_metrics and key in final_metrics:
        baseline_val = baseline_metrics[key]
        final_val = final_metrics[key]
        change = final_val - baseline_val
        print(f"{key:20s}: {baseline_val:.4f} -> {final_val:.4f} (Δ {change:+.4f})")

In [None]:
# Cell 11: Save Model and Generate Heatmap

print("="*60)
print("SAVING MODEL AND GENERATING VISUALIZATIONS")
print("="*60)

# Save LoRA adapter
final_adapter_path = os.path.join(CONFIG['OUTPUT_DIR'], "final-adapter")
model_to_tune.save_pretrained(final_adapter_path)
processor.save_pretrained(final_adapter_path)
print(f"\nLoRA adapter saved to: {final_adapter_path}")

# Generate and save heatmap
heatmap_path = os.path.join(CONFIG['OUTPUT_DIR'], "gradient_impact_heatmap.png")
trainer.plot_final_heatmap(save_path=heatmap_path)

print("\nModel and visualizations saved successfully!")

In [None]:
# Cell 12: Qualitative Analysis Functions

def get_similarity_scores(model, processor, image, text_probes, device):
    """
    Get model similarity scores for an image and text probes.
    
    Args:
        model: Model to evaluate
        processor: Processor for inputs
        image: PIL Image
        text_probes: List of text strings
        device: Device string
    
    Returns:
        numpy array of similarity scores
    """
    inputs = processor(
        text=text_probes,
        images=[image],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64
    ).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        scores = torch.sigmoid(outputs.logits_per_image).cpu().numpy().flatten()
    return scores


def plot_similarity_scores(base_scores, tuned_scores, probes, true_category, save_path):
    """
    Generate bar chart comparing baseline and fine-tuned similarity scores.
    
    Args:
        base_scores: Baseline model scores
        tuned_scores: Fine-tuned model scores
        probes: List of text probes
        true_category: The correct category
        save_path: Path to save plot
    """
    df_data = {
        "Text Probe": probes * 2,
        "Similarity Score": np.concatenate([base_scores, tuned_scores]),
        "Model": ["Baseline"] * len(probes) + ["Fine-Tuned"] * len(probes)
    }
    df = pd.DataFrame(df_data)

    plt.figure(figsize=(15, 7))
    sns.barplot(
        data=df,
        x="Text Probe",
        y="Similarity Score",
        hue="Model",
        palette={"Baseline": "lightblue", "Fine-Tuned": "darkblue"}
    )

    # Highlight correct category
    ax = plt.gca()
    for i, probe in enumerate(probes):
        if probe == true_category:
            ax.get_xticklabels()[i].set_color("red")
            ax.get_xticklabels()[i].set_fontweight("bold")

    plt.title(f"Qualitative Similarity Test (True Category: {true_category})", fontsize=16)
    plt.ylabel("Similarity Score (Sigmoid)", fontsize=12)
    plt.xlabel("Text Probes", fontsize=12)
    plt.xticks(rotation=15)
    plt.legend(title="Model", fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"Similarity plot saved to: {save_path}")


def run_qualitative_test(base_model, tuned_model, processor, device, output_dir):
    """
    Run qualitative similarity test on a sample image.
    
    Args:
        base_model: Baseline model
        tuned_model: Fine-tuned model
        processor: Processor
        device: Device string
        output_dir: Directory to save results
    """
    print("\n" + "="*60)
    print("RUNNING QUALITATIVE SIMILARITY TEST")
    print("="*60)
    
    # Test image and probes
    SAMPLE_IMG_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Seborrhoeic_keratosis_-_close-up.jpg/800px-Seborrhoeic_keratosis_-_close-up.jpg"
    SAMPLE_IMG_CATEGORY = "benign keratosis"
    TEXT_PROBES = [
        SAMPLE_IMG_CATEGORY,
        "melanoma",
        "nevus",
        "eczema",
        "basal cell carcinoma"
    ]
    
    print(f"Test image: {SAMPLE_IMG_URL}")
    print(f"True category: {SAMPLE_IMG_CATEGORY}")
    print(f"Text probes: {TEXT_PROBES}")
    
    try:
        # Download and prepare image
        response = requests.get(SAMPLE_IMG_URL)
        image = Image.open(BytesIO(response.content)).convert("RGB")
        
        # Get scores from both models
        print("\nComputing similarity scores...")
        base_scores = get_similarity_scores(base_model, processor, image, TEXT_PROBES, device)
        tuned_scores = get_similarity_scores(tuned_model, processor, image, TEXT_PROBES, device)
        
        # Print scores
        print("\nSimilarity Scores:")
        print(f"{'Probe':<25} {'Baseline':<12} {'Fine-Tuned':<12} {'Change':<12}")
        print("-" * 60)
        for i, probe in enumerate(TEXT_PROBES):
            change = tuned_scores[i] - base_scores[i]
            marker = " *" if probe == SAMPLE_IMG_CATEGORY else ""
            print(f"{probe:<25} {base_scores[i]:<12.4f} {tuned_scores[i]:<12.4f} {change:+.4f}{marker}")
        
        # Plot and save
        sim_report_path = os.path.join(output_dir, "similarity_report.png")
        plot_similarity_scores(base_scores, tuned_scores, TEXT_PROBES, SAMPLE_IMG_CATEGORY, sim_report_path)
        
        print("\nQualitative test complete!")
        return sim_report_path
        
    except Exception as e:
        print(f"ERROR during qualitative test: {e}")
        return None


print("Qualitative analysis functions defined")

In [None]:
# Cell 13: Run Qualitative Test

similarity_plot_path = run_qualitative_test(
    base_model=base_model,
    tuned_model=model_to_tune.to(device),
    processor=processor,
    device=device,
    output_dir=CONFIG['OUTPUT_DIR']
)

In [None]:
# Cell 14: Generate Final Report

def generate_final_report(baseline_metrics, final_metrics, output_dir):
    """
    Generate a markdown report with all results.
    
    Args:
        baseline_metrics: Dictionary of baseline metrics
        final_metrics: Dictionary of final metrics
        output_dir: Directory to save report
    """
    print("\n" + "="*60)
    print("GENERATING FINAL REPORT")
    print("="*60)
    
    report_path = os.path.join(output_dir, "final_report.md")
    
    report_content = "# Fine-Tuning Experiment Report\n\n"
    report_content += f"**Model:** {CONFIG['MODEL_ID']}\n"
    report_content += f"**Loss Type:** {CONFIG['LOSS_TYPE']}\n"
    report_content += f"**LoRA Rank:** {CONFIG['LORA_RANK']}\n"
    report_content += f"**LoRA Alpha:** {CONFIG['LORA_ALPHA']}\n"
    report_content += f"**Training Steps:** {CONFIG['MAX_STEPS']}\n"
    report_content += f"**Learning Rate:** {CONFIG['LEARNING_RATE']}\n\n"
    
    # Quantitative metrics
    report_content += "## 1. Quantitative Metrics\n\n"
    report_content += "Comparison of model performance on the validation set before and after fine-tuning.\n\n"
    report_content += "| Metric | Baseline (Before) | Fine-Tuned (After) | Change |\n"
    report_content += "| :--- | :--- | :--- | :--- |\n"
    
    def get_metric(metrics, key, precision=4):
        val = metrics.get(key)
        if val is None:
            return "N/A"
        return f"{val:.{precision}f}"
    
    def get_change(baseline, final, key, precision=4):
        b = baseline.get(key)
        f = final.get(key)
        if b is None or f is None:
            return "N/A"
        change = f - b
        sign = "+" if change >= 0 else ""
        return f"{sign}{change:.{precision}f}"
    
    metric_keys = [
        ("eval_loss", "Eval Loss"),
        ("eval_accuracy", "Accuracy"),
        ("eval_precision", "Precision (Macro)"),
        ("eval_recall", "Recall (Macro)"),
        ("eval_f1", "F1-Score (Macro)"),
        ("eval_runtime", "Eval Runtime (s)"),
    ]
    
    for key, name in metric_keys:
        b_val = get_metric(baseline_metrics, key)
        f_val = get_metric(final_metrics, key)
        c_val = get_change(baseline_metrics, final_metrics, key)
        report_content += f"| **{name}** | {b_val} | {f_val} | {c_val} |\n"
    
    # Qualitative analysis
    report_content += "\n## 2. Qualitative Analysis (Similarity Test)\n\n"
    report_content += "This test shows how the model's understanding of specific concepts changed.\n\n"
    report_content += "![Similarity Plot](similarity_report.png)\n\n"
    report_content += "**Interpretation:** The fine-tuned model should show higher similarity scores for the correct category.\n"
    
    # Gradient heatmap
    report_content += "\n## 3. Gradient Impact Heatmap\n\n"
    report_content += "This heatmap shows which parts of the model were modified most during fine-tuning.\n\n"
    report_content += "![Gradient Impact Heatmap](gradient_impact_heatmap.png)\n\n"
    report_content += "**Interpretation:** Brighter colors indicate layers heavily modified by fine-tuning.\n"
    
    # Save report
    try:
        with open(report_path, "w") as f:
            f.write(report_content)
        print(f"Report saved to: {report_path}")
    except Exception as e:
        print(f"ERROR saving report: {e}")


# Generate the report
generate_final_report(
    baseline_metrics=baseline_metrics,
    final_metrics=final_metrics,
    output_dir=CONFIG['OUTPUT_DIR']
)

print("\n" + "="*60)
print("ALL TASKS COMPLETE!")
print("="*60)
print(f"Results saved to: {CONFIG['OUTPUT_DIR']}")
print("Files generated:")
print("  - final-adapter/ (LoRA weights)")
print("  - gradient_impact_heatmap.png")
print("  - similarity_report.png")
print("  - final_report.md")

In [None]:
# BONUS CELL: Utility Functions for Debugging and Analysis

def inspect_dataset_sample(dataset, n=3):
    """
    Inspect first n samples from dataset.
    
    Args:
        dataset: Dataset to inspect
        n: Number of samples to show
    """
    print(f"Inspecting first {n} samples from dataset:")
    print("="*60)
    
    for i in range(min(n, len(dataset))):
        sample = dataset[i]
        print(f"\nSample {i}:")
        print(f"  Text: {sample['text']}")
        print(f"  Image: {sample['image'].size}, mode={sample['image'].mode}")
        
        # Display image inline (if in Jupyter)
        try:
            from IPython.display import display
            display(sample['image'])
        except:
            print("  (Image display not available)")


def test_single_batch(trainer, dataset, n_samples=4):
    """
    Test processing a single batch through the model.
    
    Args:
        trainer: Trainer instance
        dataset: Dataset to sample from
        n_samples: Batch size to test
    """
    print(f"Testing single batch with {n_samples} samples...")
    print("="*60)
    
    # Get a small batch
    batch = [dataset[i] for i in range(min(n_samples, len(dataset)))]
    
    # Process through collate_fn
    inputs = trainer.data_collator(batch)
    
    if not inputs:
        print("ERROR: Batch processing failed!")
        return
    
    print(f"Batch processed successfully!")
    print(f"  pixel_values shape: {inputs['pixel_values'].shape}")
    print(f"  input_ids shape: {inputs['input_ids'].shape}")
    
    # Test forward pass
    try:
        trainer.model.eval()
        with torch.no_grad():
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = trainer.model(**inputs)
            print(f"  logits_per_image shape: {outputs.logits_per_image.shape}")
            print(f"  logits_per_text shape: {outputs.logits_per_text.shape}")
        print("\nForward pass successful!")
    except Exception as e:
        print(f"\nERROR in forward pass: {e}")


def compare_model_predictions(base_model, tuned_model, processor, image, text, device):
    """
    Compare predictions from base and tuned models on a single example.
    
    Args:
        base_model: Baseline model
        tuned_model: Fine-tuned model
        processor: Processor
        image: PIL Image
        text: Text string
        device: Device string
    """
    print("Comparing model predictions...")
    print("="*60)
    print(f"Text: {text}")
    
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64
    ).to(device)
    
    # Base model
    base_model.eval()
    with torch.no_grad():
        base_outputs = base_model(**inputs)
        base_score = torch.sigmoid(base_outputs.logits_per_image).item()
    
    # Tuned model
    tuned_model.eval()
    with torch.no_grad():
        tuned_outputs = tuned_model(**inputs)
        tuned_score = torch.sigmoid(tuned_outputs.logits_per_image).item()
    
    print(f"Baseline similarity: {base_score:.4f}")
    print(f"Fine-tuned similarity: {tuned_score:.4f}")
    print(f"Change: {tuned_score - base_score:+.4f}")


def plot_training_history(trainer):
    """
    Plot training loss over time.
    
    Args:
        trainer: Trainer instance with logged history
    """
    history = trainer.state.log_history
    
    # Extract loss values
    train_losses = []
    eval_losses = []
    steps = []
    
    for entry in history:
        if 'loss' in entry:
            train_losses.append(entry['loss'])
            steps.append(entry['step'])
        if 'eval_loss' in entry:
            eval_losses.append(entry['eval_loss'])
    
    if not train_losses:
        print("No training history to plot")
        return
    
    plt.figure(figsize=(12, 5))
    
    # Training loss
    plt.subplot(1, 2, 1)
    plt.plot(steps, train_losses, 'b-', label='Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.grid(True)
    
    # Evaluation loss
    if eval_losses:
        plt.subplot(1, 2, 2)
        eval_steps = [entry['step'] for entry in history if 'eval_loss' in entry]
        plt.plot(eval_steps, eval_losses, 'r-', label='Eval Loss')
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.title('Evaluation Loss')
        plt.legend()
        plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(os.path.join(CONFIG['OUTPUT_DIR'], 'training_history.png'))
    plt.show()
    print(f"Training history plot saved")


def load_saved_adapter(adapter_path, base_model_id, device):
    """
    Load a saved LoRA adapter.
    
    Args:
        adapter_path: Path to saved adapter
        base_model_id: Base model ID
        device: Device string
    
    Returns:
        Loaded model with adapter
    """
    from peft import PeftModel
    
    print(f"Loading adapter from {adapter_path}...")
    
    base_model = AutoModel.from_pretrained(base_model_id).to(device)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    
    print("Adapter loaded successfully")
    return model


print("Utility functions defined!")
print("\nAvailable utilities:")
print("  - inspect_dataset_sample(dataset, n=3)")
print("  - test_single_batch(trainer, dataset, n_samples=4)")
print("  - compare_model_predictions(base_model, tuned_model, processor, image, text, device)")
print("  - plot_training_history(trainer)")
print("  - load_saved_adapter(adapter_path, base_model_id, device)")