### **Save the Sample for evaluation**

In [8]:
import torch
import traceback
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
import os

# --- Configuration --- #
N_VAL_SAMPLES = 1000 # Number of samples to save
OUTPUT_FILE = "eval_data.pt" # The file we will save

def prepare_and_save_data():
    """
    Downloads, sanitizes, and saves the validation dataset.
    Run this script ONCE.
    """
    print("Starting data preparation...")

    # --- STEP 1: LOAD DATASET STREAM ---
    print("Loading SCIN dataset (streaming)...")
    try:
        base_iterable = load_dataset("google/scin", split="train", streaming=True)
    except Exception as e:
        print(f"Failed to load dataset: {e}. Check dataset name/internet.")
        return

    eval_data_list = []
    image_columns = ["image_1_path", "image_2_path", "image_3_path"]

    print(f"Pre-loading and sanitizing {N_VAL_SAMPLES} samples...")

    # --- STEP 2: SANITIZE AND COLLECT DATA ---
    for item in tqdm(base_iterable.take(N_VAL_SAMPLES), total=N_VAL_SAMPLES, desc="Loading eval samples"):
        text = item.get("related_category")

        # [FIX] Check for None AND empty strings
        if not text or text.strip() == "":
            continue

        for img_col in image_columns:
            image = item.get(img_col)
            if image and isinstance(image, Image.Image):
                try:
                    # [FIX] Force convert to RGB and store the sanitized item
                    eval_data_list.append({"image": image.convert("RGB"), "text": text.strip()})
                except Exception as e:
                    print(f"Error converting image, skipping: {e}")
                # Found a valid image, break from inner loop
                break

    print(f"Collected {len(eval_data_list)} valid samples out of {N_VAL_SAMPLES}.")

    # --- STEP 3: SAVE TO FILE ---
    if not eval_data_list:
        print("ERROR: No data was collected. Aborting save.")
        return

    try:
        print(f"Saving {len(eval_data_list)} samples to {OUTPUT_FILE}...")
        torch.save(eval_data_list, OUTPUT_FILE)
        print("Data preparation complete.")
        print(f"\nYou can now run 'test_baseline_eval.py' repeatedly.")
    except Exception as e:
        print(f"Error saving data: {e}")

if __name__ == "__main__":
    prepare_and_save_data()

Starting data preparation...
Loading SCIN dataset (streaming)...


Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Pre-loading and sanitizing 1000 samples...


Loading eval samples: 100%|██████████| 1000/1000 [01:03<00:00, 15.77it/s]


Collected 747 valid samples out of 1000.
Saving 747 samples to eval_data.pt...
Data preparation complete.

You can now run 'test_baseline_eval.py' repeatedly.


### **Test the eval_data.pl file for correctness**

In [7]:
import torch

EVAL_DATA_FILE = "eval_data.pt"

print(f"--- INSPECTING {EVAL_DATA_FILE} ---")

try:
    # Load the data
    eval_data_list = torch.load(EVAL_DATA_FILE, weights_only=False)
    
    # Check what was loaded
    print(f"Type of loaded data: {type(eval_data_list)}")
    
    if isinstance(eval_data_list, list):
        print(f"Total items in list: {len(eval_data_list)}")
        print("\n--- FIRST 5 ITEMS ---")
        print(eval_data_list[:5])
        print("---------------------")

        # Check for 'None'
        none_count = sum(1 for item in eval_data_list if item is None)
        print(f"Count of 'None' items: {none_count} / {len(eval_data_list)}")
    else:
        print("ERROR: Data file is not a list as expected.")

except Exception as e:
    print(f"Failed to load or inspect file. Error: {e}")

--- INSPECTING eval_data.pt ---
Type of loaded data: <class 'list'>
Total items in list: 747

--- FIRST 5 ITEMS ---
[{'image': <PIL.Image.Image image mode=RGB size=810x779 at 0x70A8451BE050>, 'text': 'RASH'}, {'image': <PIL.Image.Image image mode=RGB size=810x1080 at 0x70A8451BFED0>, 'text': 'OTHER_ISSUE_DESCRIPTION'}, {'image': <PIL.Image.Image image mode=RGB size=810x1080 at 0x70A8451BE010>, 'text': 'OTHER_ISSUE_DESCRIPTION'}, {'image': <PIL.Image.Image image mode=RGB size=810x1080 at 0x70A8451BD090>, 'text': 'RASH'}, {'image': <PIL.Image.Image image mode=RGB size=810x1080 at 0x70A8451BED90>, 'text': 'RASH'}]
---------------------
Count of 'None' items: 0 / 747


### **Run baseline Evaluation on eval_data.pt file**

In [10]:
import torch
import os
import torch.nn.functional as F
from PIL import Image

from torchvision.transforms.functional import to_pil_image
from transformers import (
    AutoProcessor,
    AutoModel,
    TrainingArguments,
    Trainer
)
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import numpy as np

# --- Configuration --- #
MODEL_ID = "google/siglip-base-patch16-224"
BATCH_SIZE = 16
EVAL_DATA_FILE = "eval_data.pt" # The file created by prepare_data.py

# We will use it to clear the logs just once
global_batch_counter = 0

# ===================================================================
#  Part 1: The Dataset & Collate Code
# ===================================================================

class SCIN_List_Dataset(torch.utils.data.Dataset):
    """
    Custom PyTorch Dataset that wraps a simple list.
    This is used for the VALIDATION set to ensure it's reusable
    and works correctly with trainer.evaluate().
    """
    def __init__(self, data_list):
        print(f"Initializing SCIN_List_Dataset with {len(data_list)} pre-loaded samples.")
        self.data = data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch, processor):
    """
    [Robust, Item-by-Item] Data collator.
    This version includes explicit checks for None or empty content
    to prevent the processor from crashing.
    """
    processed_images = []
    processed_texts_input_ids = []
    skipped_count = 0

    for i, item in enumerate(batch):
        try:
            if item is None:
                print(f"WARNING (collate_fn): Skipping item {i} because it is 'None'.")
                skipped_count += 1
                continue
            
            img = item.get("image")
            txt = item.get("text")
            
            # --- THIS IS THE FIX ---
            # Explicitly check for bad content *before* calling the processor.
            # The processor will crash on None or empty strings.
            
            if img is None:
                print(f"WARNING (collate_fn): Skipping item {i} due to 'None' image.")
                skipped_count += 1
                continue
                
            if txt is None or txt.strip() == "":
                print(f"WARNING (collate_fn): Skipping item {i} due to 'None' or empty text.")
                skipped_count += 1
                continue
            # -------------------------

            # If we get here, img and txt are valid
            inputs = processor(
                text=[txt],
                images=[img],
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=64
            )
            
            processed_images.append(inputs["pixel_values"])
            processed_texts_input_ids.append(inputs["input_ids"])
            
        except Exception as e:
            # This now catches *other* errors, like a truly corrupt image file
            print(f"WARNING (collate_fn): Skipping item {i} due to UNEXPECTED error: {e}")
            skipped_count += 1
            
    if not processed_images:
        # This will still happen if an entire batch is bad
        print(f"ERROR: Entire batch was skipped! ({skipped_count} items failed)")
        return {}

    if skipped_count > 0:
        print(f"Collate: Skipped {skipped_count}/{len(batch)} items in this batch")

    try:
        batch_pixel_values = torch.cat(processed_images, dim=0)
        batch_input_ids = torch.cat(processed_texts_input_ids, dim=0)
        
        return {
            "pixel_values": batch_pixel_values,
            "input_ids": batch_input_ids
        }
    
    except Exception as e:
        print(f"Error during final batch collation: {e}")
        return {}


# ===================================================================
#  Part 2: The Corrected Logic
# ===================================================================

def compute_metrics(eval_pred):
    """
    Standard Hugging Face method to calculate metrics.
    This runs ONCE at the end of evaluation on the accumulated predictions.
    """
    # eval_pred.predictions is a tuple of (logits_per_image, logits_per_text)
    # We care about image-to-text matching (logits_per_image)
    logits = eval_pred.predictions
    
    # Handle tuple case
    if isinstance(logits, tuple):
        logits = logits[0]
    
    # In CLIP/SigLIP, the standard task is: given an image, find the correct text in the batch.
    # So labels are the diagonal (0, 1, 2, 3...).
    predictions = np.argmax(logits, axis=1)
    
    # Generate ground truth labels (diagonal alignment)
    true_labels = np.arange(len(predictions))
    
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='macro', zero_division=0
    )
    acc = accuracy_score(true_labels, predictions)
    
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

class EvalTrainer(Trainer):
    """
    This is the final debug trainer.
    It includes a check for NaN/Inf values from the model's output.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_count = 0
        self.successful_batches = 0
        self.skipped_batches = 0
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        self.batch_count += 1
        
        print(f"\n=== Batch {self.batch_count} ===")
        print(f"Inputs keys: {inputs.keys() if inputs else 'NONE'}")
        
        if not inputs or "pixel_values" not in inputs:
             print(f"WARNING: Empty batch received in compute_loss (batch {self.batch_count})")
             self.skipped_batches += 1
             # CRITICAL: Don't return dummy loss, raise exception to truly skip
             raise ValueError("Empty batch - skipping")

        print(f"Pixel values shape: {inputs['pixel_values'].shape}")
        print(f"Input ids shape: {inputs['input_ids'].shape}")
        
        outputs = model(**inputs)
        
        print(f"Model output keys: {outputs.keys() if hasattr(outputs, 'keys') else type(outputs)}")
        
        # --- FINAL NaN/Inf CHECK ---
        if torch.isnan(outputs.logits_per_image).any() or torch.isinf(outputs.logits_per_image).any():
            print("DEBUG: NaN or Inf DETECTED in model logits.")
            self.skipped_batches += 1
            raise ValueError("NaN/Inf in logits - skipping")
        # ---------------------------

        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text
        batch_size = logits_per_image.shape[0]
        
        print(f"Batch size: {batch_size}")
        print(f"Logits_per_image shape: {logits_per_image.shape}")
        
        if batch_size <= 1:
             print(f"WARNING: Batch size is {batch_size}, cannot compute contrastive loss")
             self.skipped_batches += 1
             raise ValueError(f"Batch size too small: {batch_size}")

        labels = torch.eye(batch_size, device=model.device)
        loss_images = F.binary_cross_entropy_with_logits(logits_per_image, labels)
        loss_text = F.binary_cross_entropy_with_logits(logits_per_text, labels)
        loss = (loss_images + loss_text) / 2.0
        
        # print(f"✓ Computed loss: {loss.item():.4f}")
        self.successful_batches += 1
        
        return (loss, outputs) if return_outputs else loss
    
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Override to ensure we compute loss AND return predictions for compute_metrics.
        This is called during evaluation, NOT compute_loss!
        """
        self.batch_count += 1
        
        # print(f"\n=== Batch {self.batch_count} (prediction_step) ===")
        # print(f"Inputs keys: {inputs.keys() if inputs else 'NONE'}")
        
        # Handle empty batches
        if not inputs or "pixel_values" not in inputs:
            print(f"WARNING: Empty batch in prediction_step (batch {self.batch_count})")
            self.skipped_batches += 1
            return (None, None, None)
        
        # print(f"Pixel values shape: {inputs['pixel_values'].shape}")
        # print(f"Input ids shape: {inputs['input_ids'].shape}")
        
        # We need to manually compute loss and get outputs
        with torch.no_grad():
            outputs = model(**inputs)
            
            # print(f"Model output type: {type(outputs)}")
            
            logits_per_image = outputs.logits_per_image
            logits_per_text = outputs.logits_per_text
            batch_size = logits_per_image.shape[0]
            
            # print(f"Batch size: {batch_size}")
            # print(f"Logits_per_image shape: {logits_per_image.shape}")
            
            # Check for NaN/Inf
            if torch.isnan(logits_per_image).any() or torch.isinf(logits_per_image).any():
                print("WARNING: NaN or Inf DETECTED in model logits.")
                self.skipped_batches += 1
                return (None, None, None)
            
            # Compute loss
            if batch_size <= 1:
                print(f"WARNING: Batch size is {batch_size}, cannot compute contrastive loss")
                self.skipped_batches += 1
                # Still return outputs for metrics, but no loss
                loss = None
            else:
                labels = torch.eye(batch_size, device=model.device)
                loss_images = F.binary_cross_entropy_with_logits(logits_per_image, labels)
                loss_text = F.binary_cross_entropy_with_logits(logits_per_text, labels)
                loss = (loss_images + loss_text) / 2.0
                # print(f"✓ Computed loss: {loss.item():.4f}")
                self.successful_batches += 1
        
        # Return (loss, logits, labels)
        # For contrastive learning, we don't have explicit labels, so return None
        # The logits are what compute_metrics will use
        logits_tuple = (logits_per_image.cpu(), logits_per_text.cpu())
        
        return (loss, logits_tuple, None)

# ===================================================================
#  Part 3: The Main Test Logic
# ===================================================================

def run_baseline_test():

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # --- STEP 1: LOAD MODEL AND PROCESSOR ---
    print(f"Loading base model and processor from: {MODEL_ID}")
    
    try:
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        
        # Load the model in its default precision
        base_model = AutoModel.from_pretrained(MODEL_ID)
        
        # NOW, explicitly cast the ENTIRE model to float32 and move to device
        base_model = base_model.to(device).float()
        
        # This debug line will PROVE the model is in float32
        print(f"DEBUG: Model parameters are on device: {next(base_model.parameters()).device}")
        print(f"DEBUG: Model parameters are dtype: {next(base_model.parameters()).dtype}")

    except Exception as e:
        print(f"Error loading model: {e}. Check internet connection/model ID.")
        return

    # --- STEP 2: LOAD PRE-PREPARED DATA ---
    print(f"Loading pre-prepared data from {EVAL_DATA_FILE}...")
    
    if not os.path.exists(EVAL_DATA_FILE):
        print(f"ERROR: {EVAL_DATA_FILE} not found.")
        print("Please run 'prepare_data.py' first to create this file.")
        return
        
    try:
        eval_data_list = torch.load(EVAL_DATA_FILE, weights_only=False)
        eval_dataset = SCIN_List_Dataset(data_list=eval_data_list)
        print(f"Loaded {len(eval_dataset)} valid samples.")
    except Exception as e:
        print(f"Error loading {EVAL_DATA_FILE}: {e}")
        return
    
    if len(eval_dataset) == 0:
        print("ERROR: No valid data was loaded for evaluation. Stopping.")
        return

    # --- STEP 3: SET UP TRAINER ---
    print("Setting up training arguments and trainer...")
    
    training_args = TrainingArguments(
        output_dir="./temp_eval_test",
        per_device_eval_batch_size=BATCH_SIZE,
        report_to="none",
        remove_unused_columns=False,
        disable_tqdm=False,  # Enable tqdm to see progress
        # IMPORTANT: These ensure predictions are computed
        prediction_loss_only=False,
    )

    trainer = EvalTrainer(
        model=base_model,
        args=training_args,
        eval_dataset=eval_dataset,
        data_collator=lambda data: collate_fn(data, processor),
        compute_metrics=compute_metrics,
    )

    # --- STEP 4: RUN BASELINE EVALUATION ---
    print("\n" + "="*50)
    print("RUNNING BASELINE EVALUATION...")
    print("="*50)
    
    # This calls trainer.evaluate()
    baseline_metrics = trainer.evaluate()
    
    print("\n" + "="*50)
    print(f"BATCH STATISTICS:")
    print(f"  Total batches attempted: {trainer.batch_count}")
    print(f"  Successful batches: {trainer.successful_batches}")
    print(f"  Skipped batches: {trainer.skipped_batches}")
    print("="*50)
    
    print("\n" + "="*50)
    print("TEST COMPLETE. BASELINE METRICS:")
    print("="*50)
    for key, value in baseline_metrics.items():
        print(f"  {key}: {value}")
    print("\n")
    
    if "eval_loss" in baseline_metrics and baseline_metrics["eval_loss"] > 0:
        print("✓ SUCCESS: 'eval_loss' and other metrics were successfully calculated.")
    else:
        print("✗ FAILURE: 'eval_loss' is missing or zero. Batches were likely skipped.")
        print("\nDEBUG TIPS:")
        print("1. Check if any 'WARNING' messages appeared above")
        print("2. Verify your eval_data.pt file has valid image/text pairs")
        print("3. Try reducing BATCH_SIZE to see if smaller batches work")




if __name__ == "__main__":
    # Set multiprocessing context if needed
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        pass 

    run_baseline_test()

Using device: cuda
Loading base model and processor from: google/siglip-base-patch16-224
DEBUG: Model parameters are on device: cuda:0
DEBUG: Model parameters are dtype: torch.float32
Loading pre-prepared data from eval_data.pt...
Initializing SCIN_List_Dataset with 747 pre-loaded samples.
Loaded 747 valid samples.
Setting up training arguments and trainer...

RUNNING BASELINE EVALUATION...



BATCH STATISTICS:
  Total batches attempted: 47
  Successful batches: 47
  Skipped batches: 0

TEST COMPLETE. BASELINE METRICS:
  eval_loss: 0.6112768650054932
  eval_model_preparation_time: 0.0024
  eval_runtime: 6.6985
  eval_samples_per_second: 111.518
  eval_steps_per_second: 7.017


✓ SUCCESS: 'eval_loss' and other metrics were successfully calculated.
