In [1]:
import torch
import os
import torch.nn.functional as F
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModel,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model
from collections import defaultdict
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import requests
from io import BytesIO
from tqdm import tqdm # Import tqdm for the progress bar

# --- Configuration --- #
MODEL_ID = "google/siglip-base-patch16-224"
OUTPUT_DIR = "./siglip-scin-lora"
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
LORA_RANK = 16
LORA_ALPHA = 16
MAX_STEPS = 500

# --- [NEW] Experiment Toggle ---
# Set this to "contrastive" or "sigmoid"
LOSS_TYPE = "sigmoid" # or "contrastive"

# ===================================================================
#  Part 1: The Dataset Loader Code
# ===================================================================

class SCIN_Iterable_Dataset(torch.utils.data.IterableDataset):
    """
    Custom PyTorch IterableDataset for the google/scin dataset.
    This class STREAMS the dataset to prevent OOM errors.

    This is used for the TRAINING set.
    """
    def __init__(self, dataset_iterable):
        print(f"Initializing SCIN_Iterable_Dataset with provided iterable...")
        self.dataset = dataset_iterable
        self.image_columns = ["image_1_path", "image_2_path", "image_3_path"]

    def __iter__(self):
        """
        Yields dictionaries containing processed image and text data.
        """
        for item in self.dataset:
            text = item.get("related_category")
            if not text or not isinstance(text, str):
                continue

            for img_col in self.image_columns:
                image = item.get(img_col)
                if image and isinstance(image, Image.Image):
                    # [FIX] Force convert to RGB to handle PNGs with alpha
                    # or other non-standard image modes.
                    try:
                        yield {
                            "image": image.convert("RGB"),
                            "text": text
                        }
                        break # Move to the next item once an image is found
                    except Exception as e:
                        print(f"Error converting image, skipping: {e}")
                        break


# --- [NEW] Map-style Dataset for Evaluation ---
class SCIN_List_Dataset(torch.utils.data.Dataset):
    """
    Custom PyTorch Dataset that wraps a simple list.
    This is used for the VALIDATION set to ensure it's reusable
    and works correctly with trainer.evaluate().
    """
    def __init__(self, data_list):
        print(f"Initializing SCIN_List_Dataset with {len(data_list)} pre-loaded samples.")
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch, processor):
    """
    [Robust, Item-by-Item] Data collator.
    This version includes explicit checks for None or empty content
    to prevent the processor from crashing.
    """
    processed_images = []
    processed_texts_input_ids = []
    skipped_count = 0

    for i, item in enumerate(batch):
        try:
            if item is None:
                # This should be rare, but good to check
                skipped_count += 1
                continue

            img = item.get("image")
            txt = item.get("text")

            # --- THIS IS THE FIX ---
            # Explicitly check for bad content *before* calling the processor.
            # The processor will crash on None or empty strings.
            if img is None:
                skipped_count += 1
                continue
            if txt is None or txt.strip() == "":
                skipped_count += 1
                continue
            # -------------------------

            # If we get here, img and txt are valid
            inputs = processor(
                text=[txt],
                images=[img],
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=64
            )

            processed_images.append(inputs["pixel_values"])
            processed_texts_input_ids.append(inputs["input_ids"])

        except Exception as e:
            # This now catches *other* errors, like a truly corrupt image file
            print(f"WARNING (collate_fn): Skipping item {i} due to UNEXPECTED error: {e}")
            skipped_count += 1

    if not processed_images:
        if len(batch) > 0:
            print(f"ERROR: Entire batch was skipped! ({skipped_count} items failed)")
        return {}

    try:
        batch_pixel_values = torch.cat(processed_images, dim=0)
        batch_input_ids = torch.cat(processed_texts_input_ids, dim=0)

        return {
            "pixel_values": batch_pixel_values,
            "input_ids": batch_input_ids
            # No attention_mask needed for SigLIP
        }

    except Exception as e:
        print(f"Error during final batch collation: {e}")
        return {}


# ===================================================================
#  Part 2: The Custom Trainer (with Heatmap & Metrics)
# ===================================================================

def compute_metrics(eval_pred):
    """
    Standard Hugging Face method to calculate metrics.
    This runs ONCE at the end of evaluation on the accumulated predictions.
    """
    # eval_pred.predictions is the tuple (logits_per_image, logits_per_text)
    logits = eval_pred.predictions[0]

    # Handle case where it might not be a tuple (though it should be)
    if isinstance(logits, tuple):
        logits = logits[0]

    # In CLIP/SigLIP, the standard task is: given an image, find the correct text in the batch.
    predictions = np.argmax(logits, axis=1)

    # Generate ground truth labels (diagonal alignment)
    true_labels = np.arange(len(predictions))

    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='macro', zero_division=0
    )
    acc = accuracy_score(true_labels, predictions)

    # IMPORTANT: The keys MUST start with "eval_"
    return {
        "eval_accuracy": acc,
        "eval_precision": precision,
        "eval_recall": recall,
        "eval_f1": f1,
    }

class CustomTrainer(Trainer):
    """
    Custom Trainer to:
    1. Compute the switchable loss (Contrastive or Sigmoid).
    2. Accumulate gradient norms for heatmap generation.
    3. Correctly compute loss AND metrics during evaluation by overriding prediction_step.
    """
    def __init__(self, *args, loss_type="contrastive", **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        print(f"CustomTrainer initialized with loss_type: {self.loss_type}")

        # --- Heatmap Data ---
        self.gradient_accumulator = defaultdict(float)
        self.step_count = 0

        # --- Track batch stats ---
        self.successful_batches = 0
        self.skipped_batches_eval = 0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        This is now ONLY used for the TRAINING loss calculation.
        Evaluation loss is handled in prediction_step.
        """
        # --- Handle empty batches from collate_fn ---
        if not inputs or "pixel_values" not in inputs:
             dummy_loss = torch.tensor(0.0, device=model.device, requires_grad=True)
             return (dummy_loss, {}) if return_outputs else dummy_loss

        outputs = model(**inputs)

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text
        batch_size = logits_per_image.shape[0]

        if batch_size <= 1:
             dummy_loss = torch.tensor(0.0, device=model.device, requires_grad=True)
             return (dummy_loss, {}) if return_outputs else dummy_loss

        # --- Switchable Loss Calculation ---
        if self.loss_type == "contrastive":
            labels = torch.arange(batch_size, device=model.device)
            loss_images = F.cross_entropy(logits_per_image, labels)
            loss_text = F.cross_entropy(logits_per_text, labels)
            loss = (loss_images + loss_text) / 2.0
        elif self.loss_type == "sigmoid":
            labels = torch.eye(batch_size, device=model.device)
            loss_images = F.binary_cross_entropy_with_logits(logits_per_image, labels)
            loss_text = F.binary_cross_entropy_with_logits(logits_per_text, labels)
            loss = (loss_images + loss_text) / 2.0
        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type}")

        # Metric calculation is REMOVED from here.

        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch):
        """
        Overrides the training step to hook into gradient accumulation.
        """
        loss = super().training_step(model, inputs, num_items_in_batch)

        if loss is not None:
            self.step_count += 1
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if param.grad is not None and param.requires_grad:
                        self.gradient_accumulator[name] += param.grad.norm().item()
        return loss

    # --- THIS IS THE CRITICAL NEW FUNCTION ---
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        [CRITICAL FIX]
        This function is called by trainer.evaluate() when compute_metrics is set.
        We MUST override it to manually compute and return the loss.
        """
        # Handle empty batches from our robust collate_fn
        if not inputs or "pixel_values" not in inputs:
            self.skipped_batches_eval += 1
            return (None, None, None) # Return Nones to skip

        # We need to manually compute loss and get outputs
        with torch.no_grad():
            outputs = model(**inputs)

            logits_per_image = outputs.logits_per_image
            logits_per_text = outputs.logits_per_text
            batch_size = logits_per_image.shape[0]

            # Check for NaN/Inf to prevent crashes (good practice)
            if torch.isnan(logits_per_image).any() or torch.isinf(logits_per_image).any():
                print("WARNING: NaN or Inf DETECTED in model logits during eval.")
                self.skipped_batches_eval += 1
                return (None, None, None)

            # Compute loss
            loss = None
            if batch_size <= 1:
                # Can't compute contrastive loss, so we skip
                self.skipped_batches_eval += 1
            else:
                # --- Re-calculate loss just like in compute_loss ---
                if self.loss_type == "contrastive":
                    labels = torch.arange(batch_size, device=model.device)
                    loss_images = F.cross_entropy(logits_per_image, labels)
                    loss_text = F.cross_entropy(logits_per_text, labels)
                    loss = (loss_images + loss_text) / 2.0
                elif self.loss_type == "sigmoid":
                    labels = torch.eye(batch_size, device=model.device)
                    loss_images = F.binary_cross_entropy_with_logits(logits_per_image, labels)
                    loss_text = F.binary_cross_entropy_with_logits(logits_per_text, labels)
                    loss = (loss_images + loss_text) / 2.0

                self.successful_batches += 1

        # Return (loss, logits, labels)
        # We return the logits tuple for compute_metrics
        # We return None for labels, as compute_metrics generates them
        logits_tuple = (logits_per_image.cpu(), logits_per_text.cpu())

        return (loss, logits_tuple, None)

    # --- Heatmap Helper Functions (Keep these as they are) ---
    def _extract_layer_index(self, name_parts):
        for part in name_parts:
            if part.isdigit():
                return int(part)
        return None

    def _extract_component_name(self, name_parts):
        name_str = ".".join(name_parts)
        if "lora_A" in name_str:
            if "q_proj" in name_str: return "LoRA A (Query)"
            if "v_proj" in name_str: return "LoRA A (Value)"
        elif "lora_B" in name_str:
            if "q_proj" in name_str: return "LoRA B (Query)"
            if "v_proj" in name_str: return "LoRA B (Value)"

        if "q_proj" in name_str: return "Query Proj"
        if "v_proj" in name_str: return "Value Proj"
        if "k_proj" in name_str: return "Key Proj"
        if "fc1" in name_str: return "MLP Layer 1"
        if "fc2" in name_str: return "MLP Layer 2"
        return None

    def _process_gradients_for_heatmap(self):
        if self.step_count == 0:
            print("No training steps recorded. Skipping heatmap.")
            return None, None, []

        vision_data = defaultdict(lambda: defaultdict(float))
        text_data = defaultdict(lambda: defaultdict(float))
        skipped_params = []

        for name, avg_grad_norm in self.gradient_accumulator.items():
            avg_norm = avg_grad_norm / self.step_count
            parts = name.split('.')
            layer_idx = self._extract_layer_index(parts)
            component = self._extract_component_name(parts)

            if layer_idx is None or component is None:
                if "lora_" in name:
                    skipped_params.append(name)
                continue

            if "vision_model" in name:
                vision_data[layer_idx][component] = avg_norm
            elif "text_model" in name:
                text_data[layer_idx][component] = avg_norm
            else:
                if "lora_" in name:
                    skipped_params.append(name)

        vision_df = pd.DataFrame.from_dict(vision_data, orient='index').sort_index()
        text_df = pd.DataFrame.from_dict(text_data, orient='index').sort_index()

        return vision_df, text_df, skipped_params

    def plot_final_heatmap(self, save_path):
        print("\nGenerating final gradient heatmaps...")
        vision_df, text_df, skipped = self._process_gradients_for_heatmap()

        if vision_df is None or (vision_df.empty and text_df.empty):
            print("No gradient data collected. Skipping heatmap file.")
            return

        if skipped:
            print(f"[WARN] Skipped {len(skipped)} LoRA params (couldn't parse name):")
            for i, name in enumerate(skipped[:5]):
                print(f"  ... {name}")
            if len(skipped) > 5: print(f"  ... and {len(skipped)-5} more.")

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))

        vmin = 0.0
        vmax = max(vision_df.max().max() if not vision_df.empty else 0,
                   text_df.max().max() if not text_df.empty else 0)
        if vmax == 0:
            vmax = 1.0

        if not vision_df.empty:
            sns.heatmap(vision_df, ax=ax1, cmap="magma", annot=True, fmt=".2e",
                        linewidths=.5, vmin=vmin, vmax=vmax)
            ax1.set_title("Vision Encoder Impact (Avg. Gradient Norm)", fontsize=16)
            ax1.set_ylabel("Layer Depth", fontsize=12)
            ax1.set_xlabel("Transformer Component (LoRA)", fontsize=12)
        else:
            ax1.text(0.5, 0.5, "No Vision Gradients Found", ha='center', va='center')
            ax1.set_title("Vision Encoder Impact (Avg. Gradient Norm)", fontsize=16)

        if not text_df.empty:
            sns.heatmap(text_df, ax=ax2, cmap="magma", annot=True, fmt=".2e",
                        linewidths=.5, vmin=vmin, vmax=vmax)
            ax2.set_title("Text Encoder Impact (Avg. Gradient Norm)", fontsize=16)
            ax2.set_ylabel("Layer Depth", fontsize=12)
            ax2.set_xlabel("Transformer Component (LoRA)", fontsize=12)
        else:
            ax2.text(0.5, 0.5, "No Text Gradients Found", ha='center', va='center')
            ax2.set_title("Text Encoder Impact (Avg. Gradient Norm)", fontsize=16)

        plt.tight_layout()
        plt.savefig(save_path)
        plt.close()
        print(f"Heatmap saved to: {save_path}")


# ===================================================================
#  Part 3: The Main Training Logic
# ===================================================================

def main_training():

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # --- STEP 1: LOAD MODEL AND PROCESSOR ---
    print(f"Loading base model and processor from: {MODEL_ID}")

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    dtype = torch.float16 if device == "cuda" else torch.float32

    # --- [NEW] Load TWO models: one for baseline, one to tune ---
    base_model = AutoModel.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype
    ).to(device)

    model_to_tune = AutoModel.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype
    ) # LoRA will be applied to this one

    # --- STEP 2: CONFIGURE AND APPLY LORA ---
    print("Applying LoRA configuration...")
    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "v_proj"], # Target Q and V projections
        lora_dropout=0.1,
        bias="none",
    )

    model_to_tune = get_peft_model(model_to_tune, lora_config)
    model_to_tune = model_to_tune.to(device)

    print("Model configured with LoRA. Trainable parameters:")
    model_to_tune.print_trainable_parameters()

    # --- STEP 3: LOAD DATASET ---
    print("Loading SCIN dataset (streaming)...")

    # --- [FIX] Load 'train' split ONCE and split it ---
    # We must do this because IterableDataset can't be reused.
    try:
        base_iterable = load_dataset("google/scin", split="train", streaming=True)
    except Exception as e:
        print(f"Failed to load dataset. Error: {e}")
        print("Please ensure you have an internet connection and have accepted 'google/scin' terms if any.")
        return

    # Split the stream: 1000 for validation, the rest for training
    N_VAL_SAMPLES = 1000
    print(f"Splitting 'train' stream: first {N_VAL_SAMPLES} for eval, rest for train.")

    # --- [NEW] Build the eval_data_list in memory ---
    eval_data_list = []
    image_columns = ["image_1_path", "image_2_path", "image_3_path"]

    print(f"Pre-loading {N_VAL_SAMPLES} samples for validation set...")
    # Use tqdm for a progress bar
    for item in tqdm(base_iterable.take(N_VAL_SAMPLES), total=N_VAL_SAMPLES, desc="Loading eval samples"):
        text = item.get("related_category")
        if not text or not isinstance(text, str):
            continue

        for img_col in image_columns:
            image = item.get(img_col)
            if image and isinstance(image, Image.Image):
                # [FIX] Force convert to RGB to handle alpha channels, etc.
                try:
                    eval_data_list.append({"image": image.convert("RGB"), "text": text})
                except Exception as e:
                    print(f"Error converting image, skipping: {e}")
                break # Move to next item

    # Create the training iterable (skipping the ones we took)
    train_iterable = base_iterable.skip(N_VAL_SAMPLES)

    # --- [NEW] Use the correct Dataset classes ---
    train_dataset = SCIN_Iterable_Dataset(dataset_iterable=train_iterable)
    eval_dataset = SCIN_List_Dataset(data_list=eval_data_list)

    print(f"Dataset iterators created. Train (streaming), Eval ({len(eval_dataset)} samples).")


    # --- STEP 4: SET UP TRAINING ---
    print("Setting up training arguments...")
    use_fp16 = True if device == "cuda" else False

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        max_steps=MAX_STEPS,
        weight_decay=0.01,
        learning_rate=LEARNING_RATE,
        warmup_steps=50,
        logging_steps=50,
        save_strategy="steps",
        save_steps=250,
        eval_strategy="steps",
        eval_steps=250,
        load_best_model_at_end=False,
        fp16=use_fp16,
        report_to="none",
        remove_unused_columns=False,
        prediction_loss_only=False,
    )

    # Initialize our CustomTrainer
    trainer = CustomTrainer(
        model=model_to_tune,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=lambda data: collate_fn(data, processor),
        loss_type=LOSS_TYPE,
        compute_metrics=compute_metrics
    )

    # --- STEP 5: [NEW] RUN BASELINE EVALUATION ---
    print("\n" + "="*50)
    print("RUNNING BASELINE EVALUATION (BEFORE FINE-TUNING)...")
    print("="*50)

    trainer.model = base_model # Temporarily swap to the base model
    baseline_metrics = trainer.evaluate()
    print("Baseline Evaluation Metrics:")
    print(baseline_metrics)

    # Swap back to the model we actually want to tune
    trainer.model = model_to_tune.to(device) # Ensure it's on the device

    # --- STEP 6: RUN FINE-TUNING ---
    print("\n" + "="*50)
    print("STARTING FINE-TUNING...")
    print("="*50)
    trainer.train()

    # --- STEP 7: RUN FINAL EVALUATION ---
    print("\n" + "="*50)
    print("RUNNING FINAL EVALUATION (AFTER FINE-TUNING)...")
    print("="*50)
    final_metrics = trainer.evaluate()
    print("Final Evaluation Metrics:")
    print(final_metrics)

    # --- STEP 8: SAVE FINAL MODEL & GENERATE REPORT ---
    print("\n" + "="*50)
    print("SAVING MODEL AND GENERATING REPORT...")
    print("="*50)

    # Save LoRA adapter
    final_adapter_path = os.path.join(OUTPUT_DIR, "final-adapter")
    model_to_tune.save_pretrained(final_adapter_path)
    processor.save_pretrained(final_adapter_path)
    print(f"Training complete. LoRA adapter saved to: {final_adapter_path}")

    # Save heatmap
    heatmap_path = os.path.join(OUTPUT_DIR, "gradient_impact_heatmap.png")
    trainer.plot_final_heatmap(save_path=heatmap_path)

    # Generate and save the final text report
    generate_final_report(
        baseline_metrics,
        final_metrics,
        base_model,
        model_to_tune.to(device), # Ensure tuned model is on device
        processor,
        OUTPUT_DIR,
        device
    )

# ===================================================================
#  Part 4: Final Report Generation
# ===================================================================

def generate_final_report(baseline_metrics, final_metrics, base_model, tuned_model, processor, output_dir, device):
    """
    Generates a final markdown report with quantitative and qualitative results.
    """
    print("Generating final report...")
    report_path = os.path.join(output_dir, "final_report.md")

    # --- Part 1: Quantitative Metrics Table ---
    report_content = "# Fine-Tuning Experiment Report\n\n"
    report_content += "## 1. Quantitative Metrics\n\n"
    report_content += "Comparison of model performance on the validation set before and after fine-tuning.\n\n"

    report_content += "| Metric | Baseline (Before) | Fine-Tuned (After) | Change |\n"
    report_content += "| :--- | :--- | :--- | :--- |\n"

    # Helper to format and get metrics
    def get_metric(metrics, key, precision=4):
        val = metrics.get(key)
        if val is None:
            return "N/A"
        return f"{val:.{precision}f}"

    def get_change(baseline, final, key, precision=4):
        b = baseline.get(key)
        f = final.get(key)
        if b is None or f is None:
            return "N/A"
        change = f - b
        sign = "+" if change >= 0 else ""
        return f"{sign}{change:.{precision}f}"

    metric_keys = [
        ("eval_loss", "Eval Loss"),
        ("eval_accuracy", "Accuracy"),
        ("eval_precision", "Precision (Macro)"),
        ("eval_recall", "Recall (Macro)"),
        ("eval_f1", "F1-Score (Macro)"),
        ("eval_runtime", "Eval Runtime (s)"),
    ]

    for key, name in metric_keys:
        b_val = get_metric(baseline_metrics, key)
        f_val = get_metric(final_metrics, key)
        c_val = get_change(baseline_metrics, final_metrics, key)
        report_content += f"| **{name}** | {b_val} | {f_val} | {c_val} |\n"

    # --- Part 2: Qualitative Similarity Test ---
    report_content += "\n## 2. Qualitative Analysis (Similarity Test)\n\n"
    report_content += "This test shows how the model's understanding of specific concepts changed. We feed a test image and several text probes to both models and compare their similarity scores.\n\n"

    # Define a sample image and probes (relevant to SCIN dataset)
    # [FIX] Changed to a domain-relevant image (benign keratosis)
    SAMPLE_IMG_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Seborrhoeic_keratosis_-_close-up.jpg/800px-Seborrhoeic_keratosis_-_close-up.jpg"
    # [FIX] Use the exact label from the dataset
    SAMPLE_IMG_CATEGORY = "benign keratosis" # The correct label

    # [FIX] Changed to domain-relevant, dataset-matching text probes
    TEXT_PROBES = [
        SAMPLE_IMG_CATEGORY,
        "melanoma",
        "nevus",
        "eczema",
        "basal cell carcinoma"
    ]

    print(f"\nRunning qualitative similarity test on image: {SAMPLE_IMG_URL}")
    sim_report_path = "N/A"
    try:
        # Get similarity scores
        response = requests.get(SAMPLE_IMG_URL)
        image = Image.open(BytesIO(response.content))
        # [FIX] Must convert this image to RGB as well
        image = image.convert("RGB")

        base_scores = get_similarity_scores(base_model, processor, image, TEXT_PROBES, device)
        tuned_scores = get_similarity_scores(tuned_model, processor, image, TEXT_PROBES, device)

        # Plot and save the bar chart
        sim_report_path = os.path.join(output_dir, "similarity_report.png")
        plot_similarity_scores(base_scores, tuned_scores, TEXT_PROBES, SAMPLE_IMG_CATEGORY, sim_report_path)

    except Exception as e:
        print(f"[ERROR] Failed to generate similarity plot: {e}")
        report_content += f"**Failed to generate similarity plot:** `{e}`\n"

    if sim_report_path != "N/A":
        report_content += f"**Test Image URL:** {SAMPLE_IMG_URL}\n"
        report_content += f"**Test Image Category:** `{SAMPLE_IMG_CATEGORY}`\n\n"
        report_content += "![Similarity Plot](similarity_report.png)\n"
        report_content += "\n**Interpretation:** Ideally, the 'Baseline' model is confused (low, flat scores), while the 'Fine-Tuned' model shows a clear, high score for the correct category.\n"

    # --- Part 3: Gradient Impact Heatmap ---
    report_content += "\n## 3. Gradient Impact Heatmap\n\n"
    report_content += "This heatmap shows which parts of the model were changed the most during fine-tuning (average L2 norm of gradients).\n\n"
    report_content += "![Gradient Impact Heatmap](gradient_impact_heatmap.png)\n"
    report_content += "\n**Interpretation:** Brighter colors (e.g., yellow) indicate layers and components that were heavily modified by the fine-tuning process. Darker colors (e.g., purple) indicate parts of the model that were left mostly unchanged.\n"

    # --- Save the final report ---
    try:
        with open(report_path, "w") as f:
            f.write(report_content)
        print(f"Final report saved to: {report_path}")
    except Exception as e:
        print(f"[ERROR] Failed to save final report: {e}")

def get_similarity_scores(model, processor, image, text_probes, device):
    """Helper function to get model similarity scores for a single image and text list."""
    inputs = processor(
        text=text_probes,
        images=[image], # Pass image as a list
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64
    ).to(device)

    model.eval() # Set to evaluation mode
    with torch.no_grad():
        outputs = model(**inputs)
        # We use logits_per_image. We have 1 image, N texts. Shape (1, N)
        # Sigmoid gives a 0-1 "probability"
        scores = torch.sigmoid(outputs.logits_per_image).cpu().numpy().flatten()
    return scores

def plot_similarity_scores(base_scores, tuned_scores, probes, true_category, save_path):
    """Generates a bar chart comparing baseline and fine-tuned scores."""

    df_data = {
        "Text Probe": probes * 2,
        "Similarity Score": np.concatenate([base_scores, tuned_scores]),
        "Model": ["Baseline"] * len(probes) + ["Fine-Tuned"] * len(probes)
    }
    df = pd.DataFrame(df_data)

    # Set colors
    colors = []
    for probe in probes:
        colors.append("red" if probe == true_category else "grey")
    colors = colors * 2 # Apply to both baseline and tuned

    plt.figure(figsize=(15, 7))
    sns.barplot(
        data=df,
        x="Text Probe",
        y="Similarity Score",
        hue="Model",
        palette={"Baseline": "lightblue", "Fine-Tuned": "darkblue"}
    )

    # Add highlighting to the correct category
    ax = plt.gca()
    for i, probe in enumerate(probes):
        if probe == true_category:
            ax.get_xticklabels()[i].set_color("red")
            ax.get_xticklabels()[i].set_fontweight("bold")

    plt.title(f"Qualitative Similarity Test (True Category: {true_category})", fontsize=16)
    plt.ylabel("Similarity Score (Sigmoid)", fontsize=12)
    plt.xlabel("Text Probes", fontsize=12)
    plt.xticks(rotation=15)
    plt.legend(title="Model", fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"Similarity plot saved to: {save_path}")


# ===================================================================
#  Part 5: Run the Code
# ===================================================================

if __name__ == "__main__":
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        pass

    main_training()

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Loading base model and processor from: google/siglip-base-patch16-224


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!


Applying LoRA configuration...
Model configured with LoRA. Trainable parameters:
trainable params: 1,179,648 || all params: 204,335,618 || trainable%: 0.5773
Loading SCIN dataset (streaming)...


Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Splitting 'train' stream: first 1000 for eval, rest for train.
Pre-loading 1000 samples for validation set...


Loading eval samples: 100%|██████████| 1000/1000 [01:06<00:00, 15.12it/s]


Initializing SCIN_Iterable_Dataset with provided iterable...
Initializing SCIN_List_Dataset with 747 pre-loaded samples.
Dataset iterators created. Train (streaming), Eval (747 samples).
Setting up training arguments...
CustomTrainer initialized with loss_type: sigmoid

RUNNING BASELINE EVALUATION (BEFORE FINE-TUNING)...


Baseline Evaluation Metrics:
{'eval_loss': 0.6139736771583557, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 5.6339, 'eval_samples_per_second': 132.59, 'eval_steps_per_second': 8.342}

STARTING FINE-TUNING...


Step,Training Loss,Validation Loss,Model Preparation Time
250,0.2363,0.239409,0.0025
500,0.2345,0.236773,0.0025



RUNNING FINAL EVALUATION (AFTER FINE-TUNING)...


Final Evaluation Metrics:
{'eval_loss': 0.23677285015583038, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 5.5861, 'eval_samples_per_second': 133.726, 'eval_steps_per_second': 8.414, 'epoch': 2.24}

SAVING MODEL AND GENERATING REPORT...
Training complete. LoRA adapter saved to: ./siglip-scin-lora/final-adapter

Generating final gradient heatmaps...
Heatmap saved to: ./siglip-scin-lora/gradient_impact_heatmap.png
Generating final report...

Running qualitative similarity test on image: https://upload.wikimedia.org/wikipedia/commons/thumb/6/67/Seborrhoeic_keratosis_-_close-up.jpg/800px-Seborrhoeic_keratosis_-_close-up.jpg
[ERROR] Failed to generate similarity plot: cannot identify image file <_io.BytesIO object at 0x7b82941f62f0>
Final report saved to: ./siglip-scin-lora/final_report.md
