**Using custom loss function**

https://www.youtube.com/watch?v=Hm8_PgVTFuc

In [None]:
%pip install torch torchvision transformers datasets peft accelerate Pillow

In [1]:
import torch
import os
import torch.nn.functional as F
from datasets import load_dataset
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModel,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

# --- Configuration --- #
# This section defines key parameters that control the model, training, and LoRA configuration.
# Adjusting these values can significantly impact performance, training time, and resource usage.
MODEL_ID = "google/siglip-base-patch16-224" # The base pre-trained SigLIP model to fine-tune.
                                           # Importance: This is the foundation of our fine-tuned model.
                                           # Rationale: SigLIP models are good for vision-language tasks.
OUTPUT_DIR = "./siglip-scin-lora"         # Directory to save the fine-tuned LoRA adapter and processor.
                                           # Importance: Essential for model persistence and later inference.
BATCH_SIZE = 16                            # Number of samples processed in parallel during training and evaluation.
                                           # Impact (Higher): Faster training per epoch, but higher memory usage.
                                           #                  Can lead to poorer generalization if too large.
                                           # Impact (Lower): Slower training, less memory usage.
                                           #                 Can lead to more noisy gradients but potentially better generalization.
LEARNING_RATE = 1e-4                       # The initial learning rate for the optimizer.
                                           # Importance: Controls the step size during model weight updates.
                                           # Impact (Higher): Model may converge faster but risk overshooting the optimum (divergence).
                                           # Impact (Lower): Slower convergence, but potentially more stable training and a better optimum.
LORA_RANK = 16                             # The rank (r) of the low-rank matrices in LoRA.
                                           # Importance: Determines the capacity of the LoRA adapter. Higher rank allows more expressiveness.
                                           # Impact (Higher): More trainable parameters, higher memory usage, potentially better performance
                                           #                  but increased risk of overfitting and slower training.
                                           # Impact (Lower): Fewer trainable parameters, lower memory usage, faster training,
                                           #                 but might not capture complex relationships (underfitting).
LORA_ALPHA = 16                            # Scaling factor for the LoRA update. Often `lora_alpha = lora_rank`.
                                           # Importance: Scales the impact of the LoRA weights.
                                           # Impact (Higher): LoRA adapters have a stronger influence on the base model.
                                           # Impact (Lower): LoRA adapters have a weaker influence.
# --------------------- #


# ===================================================================
#  Part 1: The Dataset Loader Code
# ===================================================================

class SCIN_Dataset(torch.utils.data.IterableDataset):
    """
    Custom PyTorch IterableDataset for the google/scin dataset.
    This class STREAMS the dataset to prevent OOM errors.

    Importance: Streaming datasets are crucial for large datasets that cannot fit into memory,
                preventing Out-Of-Memory (OOM) errors and allowing training on resource-constrained systems.
    Rationale: The `IterableDataset` does not preload the entire dataset. Instead, it loads items
               one by one as requested, making it memory-efficient.
    """
    def __init__(self, split="train"):
        print(f"Loading 'google/scin' dataset in STREAMING mode: {split}...")
        try:
            # Loads the 'google/scin' dataset in streaming mode.
            # 'streaming=True' ensures that data is loaded on-the-fly, not all at once.
            self.dataset = load_dataset("google/scin", split=split, streaming=True)
        except Exception as e:
            print(f"Failed to load dataset 'google/scin'. Error: {e}")
            raise

        # Defines the columns in the dataset that contain image paths.
        # The iterator will loop through these to find valid image data.
        self.image_columns = ["image_1_path", "image_2_path", "image_3_path"]

    def __iter__(self):
        """
        Yields dictionaries containing processed image and text data.
        Importance: This method defines how individual data samples are provided to the DataLoader.
        """
        for item in self.dataset:
            # Extracts the 'related_category' as the text label.
            text = item.get("related_category")
            # Skips items without a valid text label.
            if not text or not isinstance(text, str):
                continue

            # Iterates through potential image columns to find an actual image.
            for img_col in self.image_columns:
                image = item.get(img_col)
                # If a valid PIL Image is found, yield it along with the text.
                if image and isinstance(image, Image.Image):
                    yield {
                        "image": image,
                        "text": text
                    }

def collate_fn(batch, processor):
    """
    Data collator for batching.
    Importance: This function takes a list of individual samples (output of __iter__)
                and combines them into a single batch suitable for model input.
    Rationale: Models typically process data in batches for efficiency. This function
               handles tasks like tokenization, image preprocessing, and padding to create uniform batches.
    """
    # Separates images and texts from the incoming batch.
    images = [item["image"] for item in batch]
    texts = [item["text"] for item in batch]

    try:
        # Uses the AutoProcessor to prepare both text and images for the model.
        # `text`: List of text labels.
        # `images`: List of PIL images.
        # `return_tensors="pt"`: Returns PyTorch tensors.
        # `padding="max_length"`: Pads sequences to the `max_length`.
        # `truncation=True`: Truncates sequences longer than `max_length`.
        # `max_length=64`: The maximum sequence length for tokenization.
        #                  Importance: Matches the model's expected input size.
        #                  Impact (Higher): Can capture more context but increases memory and computation.
        #                  Impact (Lower): Faster, less memory, but might lose important information if text is long.
        inputs = processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=64 # Correct max length for SigLIP
        )
        return inputs
    except Exception as e:
        print(f"Error during processing batch: {e}")
        return {}


# ===================================================================
#  Part 2: The Custom Trainer
# ===================================================================

class CustomTrainer(Trainer):
    """
    Custom Trainer to compute the contrastive loss (SigLIP loss).

    Importance: Overriding `compute_loss` allows us to implement specific loss functions
                that are not directly provided by the standard Hugging Face `Trainer`.
    Rationale: SigLIP models are trained using a contrastive loss, specifically a symmetric
               cross-entropy loss between image and text embeddings. This custom trainer
               implements that specific loss formulation.
    """
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Get the model's outputs (logits).
        # The model takes processed inputs (image and text features) and outputs
        # similarity scores (logits) between images and texts.
        outputs = model(**inputs)

        # `logits_per_image` are the similarity scores where each row corresponds to an image
        # and columns correspond to texts.
        # `logits_per_text` are the similarity scores where each row corresponds to a text
        # and columns correspond to images.
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # Create the contrastive labels.
        # For contrastive learning, the diagonal elements of the similarity matrix represent
        # positive pairs (e.g., image[i] with text[i]), and off-diagonal are negative pairs.
        # `labels` is an identity matrix implicitly created by `torch.arange`, where `labels[i]`
        # corresponds to the correct text for `image[i]` (and vice-versa).
        batch_size = logits_per_image.shape[0]
        labels = torch.arange(batch_size, device=model.device) # e.g., [0, 1, 2, ..., batch_size-1]

        # Calculate the symmetric cross-entropy loss.
        # `F.cross_entropy` computes the loss between the logits and the true labels.
        # `loss_images`: Measures how well images predict their corresponding texts.
        # `loss_text`: Measures how well texts predict their corresponding images.
        loss_images = F.cross_entropy(logits_per_image, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)

        # The total loss is the average of the two cross-entropy losses, making it symmetric.
        loss = (loss_images + loss_text) / 2.0

        return (loss, outputs) if return_outputs else loss


# ===================================================================
#  Part 3: The Main Training Logic
# ===================================================================

def main_training():

    # Determine the available device (GPU or CPU) for computation.
    # Importance: Utilizing a GPU (CUDA or MPS) significantly speeds up training for deep learning models.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # --- STEP 1: LOAD MODEL AND PROCESSOR ---
    print(f"Loading base model and processor from: {MODEL_ID}")

    # The processor handles image transformations (e.g., resizing, normalization) and text tokenization.
    # Importance: Ensures input data is in the correct format for the model.
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # Set torch_dtype based on device to leverage mixed-precision training on CUDA devices.
    # Importance: `float16` (half-precision) reduces memory usage and speeds up computations on compatible hardware
    #             (like NVIDIA GPUs with Tensor Cores), while maintaining sufficient accuracy.
    # Rationale: Training in `float32` (full-precision) for non-CUDA devices to ensure compatibility and stability.
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Loads the pre-trained SigLIP model from the Hugging Face Hub.
    # Importance: Provides a strong starting point for fine-tuning, leveraging knowledge learned from massive datasets.
    model = AutoModel.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype # Applies the selected precision.
    )

    # --- STEP 2: CONFIGURE AND APPLY LORA ---
    print("Applying LoRA configuration...")

    # Configures Low-Rank Adaptation (LoRA).
    # Importance: LoRA significantly reduces the number of trainable parameters during fine-tuning,
    #             making it more memory-efficient and faster than full fine-tuning, while achieving comparable performance.
    # Rationale: Instead of updating all model weights, LoRA injects small, trainable low-rank matrices into the model.
    lora_config = LoraConfig(
        r=LORA_RANK,                         # Rank of the update matrices. See `LORA_RANK` above.
        lora_alpha=LORA_ALPHA,               # LoRA scaling factor. See `LORA_ALPHA` above.
        target_modules=["q_proj", "v_proj"], # Modules within the base model to apply LoRA to.
                                           # Importance: These are typically attention projection layers where LoRA is effective.
        lora_dropout=0.1,                    # Dropout probability for the LoRA layers.
                                           # Importance: Helps prevent overfitting by randomly setting some LoRA weights to zero during training.
                                           # Impact (Higher): More regularization, less likely to overfit but can underfit.
                                           # Impact (Lower): Less regularization, more likely to overfit.
        bias="none",                         # How to handle bias terms. "none" means no bias is added to LoRA layers.
                                           # Importance: Simplifies the LoRA layers and often sufficient.
    )

    # Applies the LoRA configuration to the base model, creating a PeftModel.
    # This wraps the original model with trainable LoRA layers.
    model = get_peft_model(model, lora_config)
    model = model.to(device) # Moves the model to the selected device.

    print("Model configured with LoRA. Trainable parameters:")
    # Prints a summary of trainable parameters, highlighting the significant reduction due to LoRA.
    model.print_trainable_parameters()

    # --- STEP 3: LOAD DATASET ---
    print("Loading SCIN dataset (streaming)...")

    # Initializes the streaming datasets for training and evaluation.
    # Rationale: Using the custom `SCIN_Dataset` ensures memory-efficient loading.
    # Note: For this example, eval_dataset also uses 'train' split for simplicity,
    #       but in a real scenario, you'd use a separate 'validation' or 'test' split.
    train_dataset = SCIN_Dataset(split="train")
    eval_dataset = SCIN_Dataset(split="train") # Using train for eval as per original notebook

    print("Dataset iterators created.")

    # --- STEP 4: SET UP TRAINING ---
    print("Setting up training arguments...")

    # Enable fp16 only if on CUDA to utilize mixed-precision benefits.
    use_fp16 = True if device == "cuda" else False

    # Defines the training arguments using Hugging Face's TrainingArguments.
    # Importance: Configures various aspects of the training loop.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,                     # Directory for model checkpoints and logs.
        per_device_train_batch_size=BATCH_SIZE,    # Batch size per GPU/CPU for training. See `BATCH_SIZE` above.
        per_device_eval_batch_size=BATCH_SIZE,     # Batch size per GPU/CPU for evaluation. See `BATCH_SIZE` above.
        max_steps=500,                             # Total number of training steps.
                                                   # Importance: Controls the duration of training.
                                                   # Impact (Higher): More training, potentially better performance, but longer time and risk of overfitting.
                                                   # Impact (Lower): Faster training, but potential underfitting if model doesn't learn enough.
        weight_decay=0.01,                         # Strength of L2 regularization.
                                                   # Importance: Helps prevent overfitting by penalizing large weights.
                                                   # Impact (Higher): Stronger regularization, can lead to underfitting.
                                                   # Impact (Lower): Weaker regularization, can lead to overfitting.
        learning_rate=LEARNING_RATE,               # Initial learning rate. See `LEARNING_RATE` above.
        warmup_steps=50,                           # Number of steps for linear learning rate warmup.
                                                   # Importance: Gradually increases learning rate from zero, helping stabilize training at the beginning.
        logging_steps=10,                          # Log training metrics every N steps.
        save_strategy="steps",                     # Save model checkpoint based on steps.
        save_steps=250,                            # Save a checkpoint every N steps.
        eval_strategy="steps",                     # Evaluate model performance based on steps.
        eval_steps=250,                            # Run evaluation every N steps.

        # --- [THE FIX] --- #
        # We set this to False because our IterableDataset doesn't produce the 'eval_loss' metric
        # in a way that `load_best_model_at_end` can reliably track. This prevents errors.
        # Importance: Essential when using iterable datasets or custom evaluation metrics.
        load_best_model_at_end=False,

        fp16=use_fp16,                             # Enables mixed-precision training if CUDA is available.
        report_to="none",                          # Disables integration with experiment tracking tools (e.g., Weights & Biases).
                                                   # Rationale: Simplifies the example by not requiring external logging setup.
        remove_unused_columns=False,               # Prevents the Trainer from removing columns not used by the model's forward pass.
                                                   # Importance: Necessary when your dataset has extra columns you might need later or for custom processing.
    )

    # Initializes the custom trainer with the LoRA-enabled model, arguments, datasets, and collator.
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=lambda data: collate_fn(data, processor), # Uses our custom collate function.
    )

    # --- STEP 5: RUN FINE-TUNING ---
    print("Starting fine-tuning...")
    trainer.train() # Initiates the training loop.

    # Define the path to save the final LoRA adapter.
    final_adapter_path = os.path.join(OUTPUT_DIR, "final-adapter")
    # Saves only the trainable LoRA weights, not the entire base model.
    # Importance: This is a key advantage of LoRA, as it produces a small, portable adapter file.
    model.save_pretrained(final_adapter_path)
    # Saves the processor used during training, ensuring consistency during inference.
    processor.save_pretrained(final_adapter_path)

    print(f"Training complete. LoRA adapter saved to: {final_adapter_path}")

# ===================================================================
#  Part 4: Run the Code
# ===================================================================

if __name__ == "__main__":
    # Sets the start method for multiprocessing to 'spawn'.
    # Importance: This is often necessary in PyTorch when using CUDA with multiprocessing
    #             to prevent issues like deadlocks or unexpected behavior.
    # Rationale: 'spawn' creates fresh, independent child processes, avoiding resource conflicts.
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        pass # Ignore if it's already set or not needed.

    main_training() # Calls the main training function to start the process.

Using device: cuda
Loading base model and processor from: google/siglip-base-patch16-224


preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/711 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

Applying LoRA configuration...
Model configured with LoRA. Trainable parameters:
trainable params: 1,179,648 || all params: 204,335,618 || trainable%: 0.5773
Loading SCIN dataset (streaming)...
Loading 'google/scin' dataset in STREAMING mode: train...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Loading 'google/scin' dataset in STREAMING mode: train...


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Dataset iterators created.
Setting up training arguments...
Starting fine-tuning...


Step,Training Loss,Validation Loss
250,2.7439,No log


Step,Training Loss,Validation Loss
250,2.7439,No log
500,2.7915,No log


Training complete. LoRA adapter saved to: ./siglip-scin-lora/final-adapter


🚀 How to Load the Model for Inference

In [2]:
import torch
from transformers import AutoModel, AutoProcessor
from peft import PeftModel
from PIL import Image

# 1. Define your models and device
base_model_id = "google/siglip-base-patch16-224"
adapter_path = "./siglip-scin-lora/final-adapter"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# 2. Load the base model (the original, large model)
base_model = AutoModel.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
)

# 3. Load the processor (this was saved with your adapter)
processor = AutoProcessor.from_pretrained(adapter_path)

# 4. Load and apply the LoRA adapter
print(f"Loading LoRA adapter from: {adapter_path}")
model = PeftModel.from_pretrained(base_model, adapter_path)

# 5. [Recommended] Merge for faster inference
# This combines the LoRA weights back into the base model.
# After this, it's just like a regular model.
model = model.merge_and_unload()
model = model.to(device)
model.eval() # Set model to evaluation mode

print("Model loaded and ready for inference!")

Loading LoRA adapter from: ./siglip-scin-lora/final-adapter
Model loaded and ready for inference!


✅ How to Test If It Trained Correctly

In [3]:
import torch
from transformers import AutoModel, AutoProcessor
from peft import PeftModel
from PIL import Image
from datasets import load_dataset
import warnings

# Suppress harmless warnings
warnings.filterwarnings("ignore")

@torch.no_grad() # We don't need to calculate gradients for testing
def test_model():
    # --- 1. Load the fine-tuned model (same as above) ---
    base_model_id = "google/siglip-base-patch16-224"
    adapter_path = "./siglip-scin-lora/final-adapter"
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    base_model = AutoModel.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )
    processor = AutoProcessor.from_pretrained(adapter_path)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model = model.merge_and_unload()
    model = model.to(device)
    model.eval()
    print("Fine-tuned model loaded.")

    # --- 2. Get a test image ---
    # We'll stream one item from the dataset to use as a test
    # This is an image the model *might* have seen, but it's a good first check.
    # For a *real* test, you should use a completely new, unseen image.
    try:
        test_data = next(iter(load_dataset("google/scin", split="train", streaming=True)))

        # Find the first valid image in the test item
        test_image = None
        for col in ["image_1_path", "image_2_path", "image_3_path"]:
            if test_data[col] and isinstance(test_data[col], Image.Image):
                test_image = test_data[col]
                break

        if test_image is None:
            print("Error: Could not load a test image from the dataset.")
            return

        print(f"Test image loaded. The correct label is: {test_data['related_category']}")
    except Exception as e:
        print(f"Failed to load test image: {e}")
        print("Please provide your own image by using: test_image = Image.open('path/to/your/image.jpg')")
        return

    # --- 3. Define your text labels ---
    # These are the "classes" we want to choose from
    text_labels = [
        "an image of ACNE",
        "an image of ECZEMA",
        "an image of a RASH",
        "a photo of a MOLE",
        "a picture of PSORIASIS",
        "a photo of healthy, normal skin"
    ]
    print(f"Testing against labels: {text_labels}")

    # --- 4. Process the image and text ---
    # Note: We process the image once and the text labels all at once
    inputs = processor(
        text=text_labels,
        images=[test_image], # Pass the image as a list
        return_tensors="pt",
        padding="max_length", # Pad text to the max length
        truncation=True,
        max_length=64 # Use the 64-token limit
    ).to(device)

    # --- 5. Get model predictions ---
    outputs = model(**inputs)

    # This gives us the similarity scores
    logits_per_image = outputs.logits_per_image

    # We apply softmax to turn scores into probabilities (0% to 100%)
    probs = logits_per_image.softmax(dim=1)

    # --- 6. Show the results ---
    print("\n--- Test Results ---")

    # Get the top 3 predictions
    top_k_values, top_k_indices = torch.topk(probs, 3)

    for i in range(top_k_values.shape[1]):
        value = top_k_values[0, i].item() * 100 # as percentage
        label_index = top_k_indices[0, i].item()
        label_name = text_labels[label_index]
        print(f"{i+1}. Predicted Label: {label_name:<25} | Confidence: {value:2.2f}%")

    print("\nTest complete. If the top prediction matches the 'correct label', the model is working!")

# Run the test
if __name__ == "__main__":
    test_model()

Using device: cuda
Fine-tuned model loaded.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Test image loaded. The correct label is: RASH
Testing against labels: ['an image of ACNE', 'an image of ECZEMA', 'an image of a RASH', 'a photo of a MOLE', 'a picture of PSORIASIS', 'a photo of healthy, normal skin']

--- Test Results ---
1. Predicted Label: an image of ECZEMA        | Confidence: 43.38%
2. Predicted Label: a picture of PSORIASIS    | Confidence: 40.45%
3. Predicted Label: an image of a RASH        | Confidence: 12.62%

Test complete. If the top prediction matches the 'correct label', the model is working!


**🧪 "Before" SigLip Training (Pre-Training)**

In [4]:
import torch
from transformers import AutoModel, AutoProcessor
# Note: We don't import PeftModel, as we are not loading an adapter
from PIL import Image
from datasets import load_dataset
import warnings

# Suppress harmless warnings
warnings.filterwarnings("ignore")

@torch.no_grad() # We don't need to calculate gradients for testing
def test_base_model():
    # --- 1. Load the ORIGINAL base model ---
    base_model_id = "google/siglip-base-patch16-224"
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    # We load the original model and processor directly from the Hugging Face Hub
    model = AutoModel.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    ).to(device)

    processor = AutoProcessor.from_pretrained(base_model_id)

    model.eval() # Set model to evaluation mode
    print("Original BASE model loaded.")

    # --- 2. Get a test image ---
    # This section is identical to your previous test
    try:
        test_data = next(iter(load_dataset("google/scin", split="train", streaming=True)))

        # Find the first valid image in the test item
        test_image = None
        for col in ["image_1_path", "image_2_path", "image_3_path"]:
            if test_data[col] and isinstance(test_data[col], Image.Image):
                test_image = test_data[col]
                break

        if test_image is None:
            print("Error: Could not load a test image from the dataset.")
            return

        print(f"Test image loaded. The correct label is: {test_data['related_category']}")
    except Exception as e:
        print(f"Failed to load test image: {e}")
        print("Please provide your own image by using: test_image = Image.open('path/to/your/image.jpg')")
        return

    # --- 3. Define your text labels ---
    # Identical to your previous test
    text_labels = [
        "an image of ACNE",
        "an image of ECZEMA",
        "an image of a RASH",
        "a photo of a MOLE",
        "a picture of PSORIASIS",
        "a photo of healthy, normal skin"
    ]
    print(f"Testing against labels: {text_labels}")

    # --- 4. Process the image and text ---
    # Identical to your previous test
    inputs = processor(
        text=text_labels,
        images=[test_image],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64
    ).to(device)

    # --- 5. Get model predictions ---
    # Identical to your previous test
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    # --- 6. Show the results ---
    # Identical to your previous test
    print("\n--- 'BEFORE' Test Results (Original Model) ---")

    top_k_values, top_k_indices = torch.topk(probs, 3)

    for i in range(top_k_values.shape[1]):
        value = top_k_values[0, i].item() * 100 # as percentage
        label_index = top_k_indices[0, i].item()
        label_name = text_labels[label_index]
        print(f"{i+1}. Predicted Label: {label_name:<25} | Confidence: {value:2.2f}%")

# Run the test
if __name__ == "__main__":
    test_base_model()

Using device: cuda
Original BASE model loaded.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Test image loaded. The correct label is: RASH
Testing against labels: ['an image of ACNE', 'an image of ECZEMA', 'an image of a RASH', 'a photo of a MOLE', 'a picture of PSORIASIS', 'a photo of healthy, normal skin']

--- 'BEFORE' Test Results (Original Model) ---
1. Predicted Label: an image of a RASH        | Confidence: 88.77%
2. Predicted Label: an image of ECZEMA        | Confidence: 8.00%
3. Predicted Label: a picture of PSORIASIS    | Confidence: 2.24%


🩺 **"MedSigLIP" Test Script (For Comparison)**

In [5]:
import torch
from transformers import AutoModel, AutoProcessor
# Note: We don't import PeftModel
from PIL import Image
from datasets import load_dataset
import warnings

# Suppress harmless warnings
warnings.filterwarnings("ignore")

@torch.no_grad() # We don't need to calculate gradients for testing
def test_medsiglip_model():

    # --- 1. Load the MedSigLIP model ---

    # [THE ONLY CHANGE IS HERE]
    # We're now loading the pre-trained medical specialist model
    base_model_id = "google/medsiglip-448"

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")
    print(f"Loading base model: {base_model_id}")

    # We load the original model and processor directly from the Hugging Face Hub
    model = AutoModel.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    ).to(device)

    processor = AutoProcessor.from_pretrained(base_model_id)

    model.eval() # Set model to evaluation mode
    print("Original MedSigLIP model loaded.")

    # --- 2. Get a test image ---
    # This section is identical
    try:
        test_data = next(iter(load_dataset("google/scin", split="train", streaming=True)))

        test_image = None
        for col in ["image_1_path", "image_2_path", "image_3_path"]:
            if test_data[col] and isinstance(test_data[col], Image.Image):
                test_image = test_data[col]
                break

        if test_image is None:
            print("Error: Could not load a test image from the dataset.")
            return

        print(f"Test image loaded. The correct label is: {test_data['related_category']}")
    except Exception as e:
        print(f"Failed to load test image: {e}")
        print("Please provide your own image by using: test_image = Image.open('path/to/your/image.jpg')")
        return

    # --- 3. Define your text labels ---
    # Identical
    text_labels = [
        "an image of ACNE",
        "an image of ECZEMA",
        "an image of a RASH",
        "a photo of a MOLE",
        "a picture of PSORISASIS", # Corrected spelling for accuracy
        "a photo of healthy, normal skin"
    ]
    print(f"Testing against labels: {text_labels}")

    # --- 4. Process the image and text ---
    # Identical, max_length=64 is still correct
    inputs = processor(
        text=text_labels,
        images=[test_image],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64
    ).to(device)

    # --- 5. Get model predictions ---
    # Identical
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)

    # --- 6. Show the results ---
    print("\n--- 'MedSigLIP' Test Results (Original Model) ---")

    top_k_values, top_k_indices = torch.topk(probs, 3)

    for i in range(top_k_values.shape[1]):
        value = top_k_values[0, i].item() * 100 # as percentage
        label_index = top_k_indices[0, i].item()
        label_name = text_labels[label_index]
        print(f"{i+1}. Predicted Label: {label_name:<25} | Confidence: {value:2.2f}%")

# Run the test
if __name__ == "__main__":
    test_medsiglip_model()

Using device: cuda
Loading base model: google/medsiglip-448


config.json:   0%|          | 0.00/879 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.51G [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/360 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/809 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/455 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.40M [00:00<?, ?B/s]

Original MedSigLIP model loaded.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Some datasets params were ignored: ['splits', 'download_size', 'dataset_size']. Make sure to use only valid params for the dataset builder and to have a up-to-date version of the `datasets` library.


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Test image loaded. The correct label is: RASH
Testing against labels: ['an image of ACNE', 'an image of ECZEMA', 'an image of a RASH', 'a photo of a MOLE', 'a picture of PSORISASIS', 'a photo of healthy, normal skin']

--- 'MedSigLIP' Test Results (Original Model) ---
1. Predicted Label: an image of a RASH        | Confidence: 32.91%
2. Predicted Label: an image of ECZEMA        | Confidence: 28.37%
3. Predicted Label: an image of ACNE          | Confidence: 15.06%
