<a href="https://colab.research.google.com/github/Vinooj/health-kiosk/blob/main/notebooks/fine_tune_for_image_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Fine-tuning SigLIP with the SCIN dataset

### Get access to SigLIP

Before you get started, make sure that you have access to the SigLIP model on
Hugging Face:

1.  If you don't already have a Hugging Face account, you can create one for
    free by clicking [here](https://huggingface.co/join).
2.  Head over to the
    [SigLIP model page](https://huggingface.co/google/siglip-base-patch16-224) and
    accept the usage conditions.

### Install dependencies

In [None]:
! pip install --upgrade --quiet accelerate datasets evaluate tensorboard transformers

## Prepare fine-tuning dataset

This notebook uses the
[SCIN dataset](https://github.com/google-research-datasets/scin) to fine-tune SigLIP to classify the ten most common dermatology conditions:

```
["Eczema", "Allergic Contact Dermatitis", "Insect Bite", "Urticaria", "Psoriasis", "Folliculitis", "Irritant Contact Dermatitis", "Tinea", "Herpes Zoster", "Drug Rash"]
```

**Note:** The SCIN dataset was included in MedSigLIP's training data. It is used as a fine-tuning dataset in this notebook for demonstration purposes.

Download the dataset from
[Cloud Storage](https://console.cloud.google.com/storage/browser/dx-scin-public-data).

In [None]:
# Skip authentication since this dataset is public
! gcloud config set auth/disable_credentials True && gcloud config set user_output_enabled False

! mkdir -p dataset
! gcloud storage cp -R gs://dx-scin-public-data/dataset/* dataset/

Load the dataset as a `pandas.DataFrame`.

In [None]:
import pandas as pd

SCIN_GCS_CASES_CSV = "dataset/scin_cases.csv"
SCIN_GCS_LABELS_CSV = "dataset/scin_labels.csv"


def initialize_df_with_metadata(csv_path: str):
    df = pd.read_csv(csv_path, dtype={"case_id": str})
    df["case_id"] = df["case_id"].astype(str)
    return df


def augment_metadata_with_labels(df: pd.DataFrame, csv_path: str):
    labels_df = pd.read_csv(csv_path, dtype={"case_id": str})
    print(f"Loaded labels with {len(labels_df)} rows.")
    labels_df["case_id"] = labels_df["case_id"].astype(str)
    merged_df = pd.merge(df, labels_df, on="case_id")
    return merged_df


scin_no_label_df = initialize_df_with_metadata(SCIN_GCS_CASES_CSV)
scin_df = augment_metadata_with_labels(scin_no_label_df, SCIN_GCS_LABELS_CSV)
scin_df.set_index("case_id", inplace=True)

scin_df

Process the DataFrame so that each row corresponds to a training example with an `image` and `label`.

Filter out examples with insufficient image quality and
low-confidence labels. In the original dataset, each contributor provides up to three images per case, so turn each image and its corresponding label into a separate example in the resulting training dataset.

In [None]:
import ast

from sklearn.preprocessing import MultiLabelBinarizer

CONDITIONS = ["Eczema", "Allergic Contact Dermatitis", "Insect Bite", "Urticaria", "Psoriasis", "Folliculitis", "Irritant Contact Dermatitis", "Tinea", "Herpes Zoster", "Drug Rash"]
MINIMUM_CONFIDENCE = 0


def remove_low_confidence_labels(row: pd.Series):
    # The CSV stores these columns as stringified Python lists.
    # ast.literal_eval parses them without executing arbitrary code.
    labels = ast.literal_eval(row.dermatologist_skin_condition_on_label_name)
    confidences = ast.literal_eval(row.dermatologist_skin_condition_confidence)

    row_labels = []
    for label, confidence in zip(labels, confidences):
        if label in CONDITIONS and confidence >= MINIMUM_CONFIDENCE:
            row_labels.append(label)
    return row_labels


# Filter examples with insufficient image quality
scin_df = scin_df[scin_df.dermatologist_gradable_for_skin_condition_1 == "DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT"].copy()

# Remove labels that are below a minimum confidence
scin_df["label"] = scin_df.apply(remove_low_confidence_labels, axis=1)

# Make each image (if it exists) into a separate example.
# Also create a new dataframe with only images and labels.
image_cols = ["image_1_path", "image_2_path", "image_3_path"]
scin_df = pd.melt(
    scin_df, id_vars=["label"], value_vars=image_cols, value_name="image"
).drop(columns=["variable"]).dropna(subset=["image"]).set_index("image")

# Convert labels from e.g. [["Eczema"], ["Urticaria", "Insect Bite"]] to
# [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0, 0, 0, 0, 0]]
mlb = MultiLabelBinarizer(classes=CONDITIONS)
scin_df["label"] = mlb.fit_transform(scin_df["label"]).tolist()

# Drop missing image (https://github.com/google-research-datasets/scin/issues/1)
scin_df.drop(index=["dataset/images/-2243186711511406658.png"], inplace=True)

scin_df
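
To make the label encoding concrete, here is a minimal pure-Python sketch of the multi-hot vectors that `MultiLabelBinarizer(classes=CONDITIONS)` is expected to produce. The helper name `to_multi_hot` is hypothetical and exists only for illustration; the notebook itself relies on scikit-learn.

```python
CONDITIONS = [
    "Eczema", "Allergic Contact Dermatitis", "Insect Bite", "Urticaria",
    "Psoriasis", "Folliculitis", "Irritant Contact Dermatitis", "Tinea",
    "Herpes Zoster", "Drug Rash",
]


def to_multi_hot(labels, classes=CONDITIONS):
    """Return a 0/1 vector with a 1 at the position of every class in `labels`.

    Hypothetical helper for illustration only.
    """
    present = set(labels)
    return [1 if name in present else 0 for name in classes]


examples = [["Eczema"], ["Urticaria", "Insect Bite"], []]
encoded = [to_multi_hot(labels) for labels in examples]
for labels, vector in zip(examples, encoded):
    print(labels, "->", vector)
```

An example whose labels all fall outside `CONDITIONS` is encoded as an all-zero vector, so it stays in the dataset as a negative for every class.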

Load the DataFrame as a Hugging Face `Dataset` with `from_pandas()`. Then, create train, validation, and test splits.

After merging labels, `scin_df` has 5033 rows (one per case), while the final dataset has 6451 image examples. Two steps account for the difference:

1. **Multiple images per case:** The original `scin_df` has `image_1_path`, `image_2_path`, and `image_3_path` columns. `pd.melt` creates a separate row for each image associated with a case, so a case with 3 images becomes 3 separate image examples for the model.
2. **Filtering:** Cases with insufficient image quality and rows with missing images are removed. After both steps, 6451 valid image examples remain.

These **6451 image examples** are split as follows:

- Initial split (`test_size=0.2`):
  - Training set: 80% of 6451 ≈ 5160 samples.
  - Temporary validation/test pool: 20% of 6451 ≈ 1291 samples.
- Second split (`test_size=0.5`) of the temporary pool:
  - Validation set: 50% of 1291 ≈ 645 samples.
  - Test set: 50% of 1291 ≈ 646 samples.
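
The split arithmetic above can be checked with a short sketch. The rounding rule used here (test portion rounded up, the remainder going to the other split) is an assumption chosen because it reproduces the quoted counts; `split_sizes` is a hypothetical helper, not part of the `datasets` API.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumed rounding: the test split is rounded up and the train split
    receives the remaining samples.
    """
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


# First split: 80% train, 20% temporary pool.
n_train, n_pool = split_sizes(6451, 0.2)
# Second split: the "train" half of the pool becomes validation.
n_validation, n_test = split_sizes(n_pool, 0.5)

print(n_train, n_validation, n_test)
```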

In [None]:
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

# Re-initializing for accurate counts
SCIN_GCS_CASES_CSV = "dataset/scin_cases.csv"
SCIN_GCS_LABELS_CSV = "dataset/scin_labels.csv"

def initialize_df_with_metadata(csv_path: str):
    df = pd.read_csv(csv_path, dtype={"case_id": str})
    df["case_id"] = df["case_id"].astype(str)
    return df

def augment_metadata_with_labels(df: pd.DataFrame, csv_path: str):
    labels_df = pd.read_csv(csv_path, dtype={"case_id": str})
    labels_df["case_id"] = labels_df["case_id"].astype(str)
    merged_df = pd.merge(df, labels_df, on="case_id")
    return merged_df

# Step 1: Initial load of cases and labels
scin_no_label_df_temp = initialize_df_with_metadata(SCIN_GCS_CASES_CSV)
scin_df_temp = augment_metadata_with_labels(scin_no_label_df_temp, SCIN_GCS_LABELS_CSV)
scin_df_temp.set_index("case_id", inplace=True)
print(f"Initial number of cases (rows) after merging labels: {len(scin_df_temp)}")

# Step 2: Filter examples with insufficient image quality
original_cases_count = len(scin_df_temp)
scin_df_temp = scin_df_temp[scin_df_temp.dermatologist_gradable_for_skin_condition_1 == "DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT"]
filtered_cases_count = len(scin_df_temp)
print(f"Number of cases (rows) after filtering for image quality: {filtered_cases_count}")
print(f"Number of cases filtered out by image quality: {original_cases_count - filtered_cases_count}")

# Step 3: Remove low confidence labels and melt to individual image examples
CONDITIONS = ["Eczema", "Allergic Contact Dermatitis", "Insect Bite", "Urticaria", "Psoriasis", "Folliculitis", "Irritant Contact Dermatitis", "Tinea", "Herpes Zoster", "Drug Rash"]
MINIMUM_CONFIDENCE = 0

def remove_low_confidence_labels(row: pd.Series):
    labels = eval(row.dermatologist_skin_condition_on_label_name)
    confidences = eval(row.dermatologist_skin_condition_confidence)

    row_labels = []
    for label, confidence in zip(labels, confidences):
        if label in CONDITIONS and confidence >= MINIMUM_CONFIDENCE:
            row_labels.append(label)
    return row_labels

scin_df_temp["label"] = scin_df_temp.apply(remove_low_confidence_labels, axis=1)

image_cols = ["image_1_path", "image_2_path", "image_3_path"]

scin_df_temp = pd.melt(
    scin_df_temp, id_vars=["label"], value_vars=image_cols, value_name="image"
).drop(columns=["variable"])

# Count rows after melt, before dropna
print(f"Number of image rows after melting (before dropping NaN images): {len(scin_df_temp)}")

scin_df_temp_before_dropna = len(scin_df_temp)
scin_df_temp = scin_df_temp.dropna(subset=["image"])
print(f"Number of image rows after dropping NaN images: {len(scin_df_temp)}")
print(f"Number of image rows filtered out by dropping NaN images: {scin_df_temp_before_dropna - len(scin_df_temp)}")

# Step 4: Drop the known missing image
# (https://github.com/google-research-datasets/scin/issues/1)
MISSING_IMAGE = "dataset/images/-2243186711511406658.png"
rows_before_specific_drop = len(scin_df_temp)
if (scin_df_temp["image"] == MISSING_IMAGE).any():
    scin_df_temp = scin_df_temp[scin_df_temp["image"] != MISSING_IMAGE]
else:
    print(f"Missing image '{MISSING_IMAGE}' not found in DataFrame; nothing to drop.")

print(f"Final number of image rows: {len(scin_df_temp)}")
print(f"Number of image rows filtered out by specific image drop: {rows_before_specific_drop - len(scin_df_temp)}")

In [None]:
from datasets import Dataset, Image

data = Dataset.from_pandas(scin_df)
# Decode image paths as PIL images
data = data.cast_column("image", Image())
data = data.train_test_split(
    test_size=0.2,
    shuffle=True,
    seed=42,
)
validation_test_data = data.pop("test").train_test_split(
    test_size=0.5,
    shuffle=True,
    seed=42,
)
data["validation"] = validation_test_data["train"]
data["test"] = validation_test_data["test"]

# Display dataset details
data

Inspect a sample data point, which contains:

* `image`: dermatology image as a `PIL` image object
* `label`: corresponding labels as a multi-hot encoded vector (one position per condition)

In [None]:
data["train"][1]["image"]

In [None]:
data["train"][1]["label"]

Preprocess the input images.

The model expects the input images to be resized to 224x224 with pixel values rescaled to the range [-1, 1].

Note that the input images are also zero-padded to square before resizing to preserve aspect ratio. This step is included for consistency with the original MedSigLIP training data preprocessing.
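
As a rough illustration of the value rescaling and the square padding, the following pure-Python sketch mirrors the arithmetic on single numbers. It assumes a mean and standard deviation of 0.5 per channel; the helper names are hypothetical and the real pipeline uses `torchvision` transforms.

```python
def rescale_pixel(value, mean=0.5, std=0.5):
    """Map an 8-bit pixel value to [-1, 1] (assumes mean = std = 0.5)."""
    scaled = value / 255.0        # [0, 255] -> [0, 1], as ToTensor does
    return (scaled - mean) / std  # [0, 1] -> [-1, 1], as Normalize does


def padded_square_side(width, height):
    """Side length of the zero-padded square: the larger image dimension."""
    return max(width, height)


for value in (0, 128, 255):
    print(value, "->", round(rescale_pixel(value), 4))

print(padded_square_side(640, 480))
```

Padding to a square before resizing keeps the aspect ratio of the skin region intact, at the cost of black borders on the shorter side.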

In [None]:
from torchvision.transforms import Compose, CenterCrop, Resize, ToTensor, Normalize, InterpolationMode
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")

size = image_processor.size["height"]  # 224
mean = image_processor.image_mean  # 0.5
std = image_processor.image_std  # 0.5

_transform = Compose([
    Resize((size, size), interpolation=InterpolationMode.BILINEAR),
    # Convert PIL image to PyTorch tensor and rescale pixel values from the
    # range [0, 255] to [0, 1]
    ToTensor(),
    # Scale pixel values to the range [-1, 1]
    Normalize(mean=mean, std=std),
])


def preprocess(examples):
    examples["pixel_values"] = [
        # CenterCrop effectively zero pads the image to a square with size equal
        # to the larger dimension
        _transform(CenterCrop(max(image.size))(image.convert("RGB")))
        for image in examples["image"]
    ]
    return examples


train_data = data["train"].map(preprocess, batched=True, remove_columns=["image"])
validation_data = data["validation"].map(preprocess, batched=True, remove_columns=["image"])

## Fine-tune the model

This notebook demonstrates fine-tuning the SigLIP vision encoder for a multi-label image classification task on image and structured label data using the `Trainer` from the Hugging Face `Transformers` library.

Load the SigLIP vision encoder with an image classification head on top.

In [None]:
from transformers import AutoModelForImageClassification

model_id = "google/siglip-base-patch16-224"

# Define the label mappings for the classification task
id2label = {i: label for i, label in enumerate(CONDITIONS)}
label2id = {label: i for i, label in enumerate(CONDITIONS)}

model = AutoModelForImageClassification.from_pretrained(
    model_id,
    problem_type="multi_label_classification",
    num_labels=len(CONDITIONS),
    id2label=id2label,
    label2id=label2id,
)

Define a data collator to prepare batches of training examples.

In [None]:
import torch


def collate_fn(examples):
    pixel_values = torch.tensor([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples], dtype=torch.float)
    return {"pixel_values": pixel_values, "labels": labels}

Define evaluation metrics to be computed during training. The function takes in an [`EvalPrediction`](https://huggingface.co/docs/transformers/en/internal/trainer_utils#transformers.EvalPrediction) which contains the model predictions (logits) and labels.

Similar to MedSigLIP's reported metrics, this example uses the macro-averaged one-vs-rest ROC AUC (Area Under the Receiver Operating Characteristic Curve) score to evaluate multi-label classification performance. See the [documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html) for more details.
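
To show what "macro-averaged one-vs-rest" means, here is a small pure-Python sketch that computes a binary ROC AUC per condition and then takes the unweighted mean. It uses the pairwise-ranking definition of AUC (the fraction of positive/negative pairs ranked correctly, with ties counted as half). The function names and toy data are hypothetical, and the notebook itself uses the `evaluate` library.

```python
def binary_roc_auc(y_true, y_score):
    """AUC as the probability that a random positive outranks a random negative."""
    positives = [s for t, s in zip(y_true, y_score) if t == 1]
    negatives = [s for t, s in zip(y_true, y_score) if t == 0]
    if not positives or not negatives:
        raise ValueError("AUC is undefined unless both classes are present")
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in positives
        for n in negatives
    )
    return wins / (len(positives) * len(negatives))


def macro_roc_auc(labels, scores):
    """Unweighted mean of the per-class (one-vs-rest) AUCs."""
    n_classes = len(labels[0])
    per_class = [
        binary_roc_auc([row[c] for row in labels], [row[c] for row in scores])
        for c in range(n_classes)
    ]
    return sum(per_class) / n_classes


# Toy data: 4 examples, 2 conditions.
toy_labels = [[1, 0], [0, 1], [1, 1], [0, 0]]
toy_scores = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.4], [0.1, 0.5]]
macro_auc = macro_roc_auc(toy_labels, toy_scores)
print(macro_auc)
```

Because every class contributes equally to the mean, rare conditions influence the macro average as much as common ones.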

In [None]:
import evaluate
import numpy as np

roc_auc_score = evaluate.load("roc_auc", "multilabel")


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def metrics_fn(eval_pred):
    logits, labels = eval_pred
    scores = sigmoid(logits)
    return roc_auc_score.compute(
        prediction_scores=scores,
        references=labels,
        average="macro",
        multi_class="ovr",
    )

Define a weighted loss function to address class imbalance within the dataset.

This is the [default loss](https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/siglip/modeling_siglip.py#L1205) used for multi-label classification but utilizes `pos_weight` to assign a per-class weight to positive examples, effectively treating minority positive classes with greater importance in the loss calculation. Refer to the [documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html) for more details.

**Note:** There are other balancing methods such as oversampling that may be used depending on your dataset and classification task.
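
The effect of `pos_weight` can be sketched without PyTorch. The example below computes per-class weights as the ratio of negatives to positives and applies them in the binary cross-entropy formula for a single logit. The helper names and the toy label matrix are hypothetical; the sketch assumes every class has at least one positive example.

```python
import math


def pos_weights(label_rows):
    """Per-class weight = number of negatives / number of positives."""
    n_samples = len(label_rows)
    weights = []
    for c in range(len(label_rows[0])):
        n_positive = sum(row[c] for row in label_rows)
        # Assumes n_positive > 0; a class with no positives would divide by zero.
        weights.append((n_samples - n_positive) / n_positive)
    return weights


def weighted_bce(logit, target, pos_weight):
    """Binary cross-entropy on one logit with the positive term scaled."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * target * math.log(p) + (1 - target) * math.log(1 - p))


# Toy data: 8 examples, 2 classes (2 positives and 1 positive respectively).
toy_rows = [[1, 0], [0, 0], [0, 1], [0, 0], [1, 0], [0, 0], [0, 0], [0, 0]]
weights = pos_weights(toy_rows)
print(weights)

# A missed positive costs `pos_weight` times more than a missed negative.
print(weighted_bce(0.0, 1, weights[0]), weighted_bce(0.0, 0, weights[0]))
```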

In [None]:
from torch.nn import BCEWithLogitsLoss

labels = torch.tensor(train_data["label"])
num_samples = labels.shape[0]
num_positive = labels.sum(axis=0)
num_negative = num_samples - num_positive
POS_WEIGHT = num_negative / num_positive


def loss_fn(outputs, labels, num_items_in_batch):
    logits = outputs.get("logits")
    pos_weight = POS_WEIGHT.to(logits.device)
    loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight)
    return loss_fct(logits, labels)

Configure training parameters in
[`TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments).

In [None]:
from transformers import TrainingArguments

num_train_epochs = 3  # @param {type: "number"}
learning_rate = 5e-5  # @param {type: "number"}

training_args = TrainingArguments(
    output_dir="siglip-224-scin-classification",  # Directory and Hub repository id to save the model to
    num_train_epochs=num_train_epochs,               # Number of training epochs
    per_device_train_batch_size=8,                   # Batch size per device during training
    per_device_eval_batch_size=8,                    # Batch size per device during evaluation
    gradient_accumulation_steps=8,                   # Number of batches to accumulate gradients over before each optimizer update
    logging_steps=40,                                # Number of steps between logs
    save_strategy="epoch",                           # Save checkpoint every epoch
    eval_strategy="steps",                           # Evaluate every `eval_steps`
    eval_steps=40,                                   # Number of steps between evaluations
    learning_rate=learning_rate,                     # Learning rate
    weight_decay=0.01,                               # Weight decay to apply
    warmup_steps=5,                                  # Number of steps for linear warmup from 0 to learning rate
    lr_scheduler_type="cosine",                      # Use cosine learning rate scheduler
    push_to_hub=False,                               # Keep the model local; set to True to push it to the Hub
    report_to="tensorboard",                         # Report metrics to tensorboard
)
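
As a back-of-the-envelope check on these settings, the sketch below works out the effective batch size and the approximate number of optimizer updates per epoch. It assumes a single GPU and the 5160-example training split quoted earlier; the exact step count reported by `Trainer` may differ slightly because of rounding.

```python
per_device_train_batch_size = 8
gradient_accumulation_steps = 8
num_devices = 1        # assumption: training on a single GPU
train_examples = 5160  # training split size quoted earlier in this notebook

# Examples that contribute to each optimizer update.
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)

# Approximate optimizer updates per epoch (exact rounding is up to Trainer).
approx_updates_per_epoch = train_examples / effective_batch_size

print(effective_batch_size)
print(round(approx_updates_per_epoch, 1))
```

With `eval_steps=40`, this works out to roughly two evaluations per epoch under these assumptions.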

Construct a [`Trainer`](https://huggingface.co/docs/transformers/trainer) using the previously defined training parameters, data collator, metrics function, and weighted loss function.

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=validation_data,
    data_collator=collate_fn,
    compute_metrics=metrics_fn,
    compute_loss_func=loss_fn,
)

Launch the fine-tuning process.

**Note:** This may take around 3 hours to run using the default configuration.

In [None]:
trainer.train()

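Save the final model to the local `output_dir`. Because `push_to_hub=False` in the training arguments, nothing is uploaded to the Hugging Face Hub.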

In [None]:
trainer.save_model()

If you set `push_to_hub=True` in the training arguments, you can use the link below to navigate to your model repository and click on the "Training metrics" tab to view training curves.

In [None]:
from huggingface_hub import HfApi

api = HfApi()
username = api.whoami()["name"]
print(f"https://huggingface.co/{username}/{training_args.output_dir}")

## Evaluate the fine-tuned model

### Set up for evaluation

Load the ROC AUC (Area Under the Receiver Operating Characteristic Curve) and additional accuracy metrics to evaluate the model's performance on the classification task.

You can use other accuracy metrics based on your use case and performance requirements.

In [None]:
from typing import Union

import evaluate
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

roc_auc_score = evaluate.load("roc_auc", "multilabel")

# Ground-truth labels
REFERENCES = data["test"]["label"]


def compute_metrics(
    prediction_scores: np.ndarray,
    threshold: float,
) -> dict[str, float]:
    metrics = {}
    metrics.update(roc_auc_score.compute(
        prediction_scores=prediction_scores,
        references=REFERENCES,
        average="macro",
        multi_class="ovr",
    ))
    predictions = (prediction_scores > threshold).astype(int)
    mcm = multilabel_confusion_matrix(
        y_true=REFERENCES,
        y_pred=predictions,
    )
    tn = mcm[:, 0, 0]
    tp = mcm[:, 1, 1]
    fn = mcm[:, 1, 0]
    fp = mcm[:, 0, 1]
    metrics.update({
        "sensitivity": tp / (tp + fn),
        "specificity": tn / (tn + fp),
    })
    return metrics


def print_metrics(metrics: dict[str, Union[float, np.ndarray]]) -> None:
    print(f"Macro-averaged one-vs-rest ROC AUC: {metrics['roc_auc']:.2f}")
    for metric in ["sensitivity", "specificity"]:
        print(f"\n{metric.capitalize()}:")
        for i, condition in enumerate(CONDITIONS):
            print(f"{condition}: {metrics[metric][i]:.4f}")
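
For reference, the sensitivity and specificity formulas can be traced on a single condition with plain Python. This is a minimal sketch with hypothetical toy scores; it returns `nan` when a denominator is zero, which is an illustrative choice rather than the behavior of the cell above.

```python
def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity) for one condition."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return sensitivity, specificity


toy_truth = [1, 1, 1, 0, 0, 0, 0, 0]
toy_scores = [0.9, 0.6, 0.3, 0.7, 0.2, 0.1, 0.4, 0.05]
toy_threshold = 0.5
toy_pred = [int(score > toy_threshold) for score in toy_scores]

sens, spec = sensitivity_specificity(toy_truth, toy_pred)
print(f"sensitivity={sens:.4f} specificity={spec:.4f}")
```

Lowering the threshold generally raises sensitivity and lowers specificity, which is why a fixed 0.5 cutoff is only a starting point.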

### Compute metrics on the fine-tuned model

Load the fine-tuned model.

In [None]:
ft_model = AutoModelForImageClassification.from_pretrained(
    training_args.output_dir,
    problem_type="multi_label_classification",
    num_labels=len(CONDITIONS),
    id2label=id2label,
    label2id=label2id,
    device_map="auto",
)

Run batch inference on the test dataset.

First, as a quick sanity check, the following code picks a random image from the training set, runs it through the fine-tuned model, and displays the model's predicted probabilities, the conditions it identified (based on a 0.5 threshold), and the original ground-truth labels for that image. It then verifies whether any ground-truth label appears among the top 5 predicted conditions.

In [None]:
import random
import torch

# Choose a random index from the training data
random_index = random.randint(0, len(train_data) - 1)

# Get the original example before preprocessing. The "image" column is
# decoded to a PIL image object, not a file path.
original_data_point = data["train"][random_index]
original_image = original_data_point["image"]
original_label = original_data_point["label"]

# Preprocess the image for the model
image = original_image.convert("RGB")
inputs = image_processor(images=[image], return_tensors="pt").to("cuda")

# Get model predictions
with torch.no_grad():
    outputs = ft_model(**inputs)

logits = outputs.logits
probabilities = torch.sigmoid(logits).cpu().numpy().flatten()

# Define a threshold for identifying conditions (same as used for evaluation)
threshold = 0.5
identified_conditions_binary = (probabilities > threshold).astype(int)

print(f"--- Analysis for Random Image (Index: {random_index}) ---")
print(f"Image: {original_image}")
print("\nOriginal Labels (Ground Truth):")
original_condition_names = []
for i, val in enumerate(original_label):
    if val == 1:
        original_condition_names.append(CONDITIONS[i])
        print(f"- {CONDITIONS[i]}")

print("\nModel Predicted Probabilities and Identified Conditions (Threshold=0.5) (Sorted by Probability):")

# Combine probabilities, condition names, and identified status
prediction_details = []
for i, prob in enumerate(probabilities):
    condition_name = CONDITIONS[i]
    is_identified = bool(identified_conditions_binary[i])
    prediction_details.append((prob, condition_name, is_identified))

# Sort by probability in descending order
prediction_details.sort(key=lambda x: x[0], reverse=True)

identified_conditions_list = []
for prob, condition_name, is_identified in prediction_details:
    print(f"- {condition_name}: {prob:.4f} (Identified: {is_identified})")
    if is_identified:
        identified_conditions_list.append(condition_name)

if not identified_conditions_list:
    print("  No conditions identified by the model for this image.")
else:
    print(f"\nModel identified the following conditions: {', '.join(identified_conditions_list)}")

# --- Additional Verification: Check for intersection with top 5 predicted ---
print("\n--- Verification with Top 5 Predicted Conditions ---")

top_5_predicted_conditions = [detail[1] for detail in prediction_details[:5]]

print(f"Top 5 Model Predicted Conditions: {', '.join(top_5_predicted_conditions)}")
print(f"Original Labels: {', '.join(original_condition_names)}")

# Find the intersection
intersection = set(original_condition_names).intersection(set(top_5_predicted_conditions))

if intersection:
    print(f"Match Successful! Common conditions found: {', '.join(intersection)}")
else:
    print("Match Failed. No common conditions between original labels and top 5 predicted.")

Let's break down the line `outputs = ft_model(**inputs)` for a batch of 64 images, using a simple analogy.

Imagine your `ft_model` is a highly specialized team of ten doctors, one for each of the dermatology conditions you are trying to classify (Eczema, Insect Bite, etc.).

**`inputs`:** This is like a stack of 64 patient folders. Each folder contains one image that has already been prepared (resized, normalized, etc.) so the doctors can read it.

**`ft_model(**inputs)`:** The `**` unpacks the `inputs` dictionary into keyword arguments, so the call is equivalent to `ft_model(pixel_values=...)`. Instead of handing over the folders one by one, you give the entire stack of 64 to the doctor team at once.

**How the doctors (model) process the batch:** The team doesn't look at one image, finish the diagnosis, and then move to the next. Instead, all 64 images are processed together in a single forward pass. Each doctor (representing a specific condition) forms an initial opinion about whether their condition is present in each image.

**`outputs`:** This is the collective report from all ten doctors for all 64 patients. It's a structured collection of their findings.

**`logits`:** Within the `outputs` report, `logits` are the raw suspicion scores from each doctor, one number for every one of the 64 images and every one of the ten conditions (shape 64 x 10). This number isn't yet a probability (like "80% chance of Eczema"). It's a raw, unscaled score that indicates how strongly that doctor suspects their condition is present in that image. A larger positive number means stronger suspicion, a negative number means weaker suspicion, and zero is neutral (a probability of 0.5). The sigmoid function later converts these scores into probabilities between 0 and 1.
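
The two mechanics described above, keyword unpacking with `**` and converting logits to probabilities with a sigmoid, can be tried in isolation. In this minimal sketch, `toy_model` is a hypothetical stand-in that returns fixed logits; it is not the fine-tuned model and only mimics the call pattern.

```python
import math


def toy_model(pixel_values):
    """Hypothetical stand-in: one fixed row of logits per input image."""
    return [[2.0, 0.0, -2.0] for _ in pixel_values]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


inputs = {"pixel_values": ["image_0", "image_1"]}

# `**inputs` unpacks the dict, so this equals
# toy_model(pixel_values=inputs["pixel_values"]).
logits = toy_model(**inputs)

# Each condition gets its own independent probability (multi-label setup).
probabilities = [[sigmoid(x) for x in row] for row in logits]
predictions = [[int(p > 0.5) for p in row] for row in probabilities]

print([round(p, 4) for p in probabilities[0]])
print(predictions[0])
```

Unlike softmax, the sigmoid is applied to each condition separately, so the probabilities in a row do not need to sum to 1.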

In [None]:
from PIL import Image
import torch

prediction_scores = []
for batch in data["test"].batch(batch_size=64):
    images = [Image.open(image["path"]) for image in batch["image"]]
    inputs = image_processor(images=images, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = ft_model(**inputs)

    logits = outputs.logits
    scores = torch.sigmoid(logits)
    prediction_scores.extend(scores)

prediction_scores = torch.stack(prediction_scores).cpu().numpy()

Compute metrics.

**Note:** This notebook demonstrates a sample training run and the metrics below have not been optimized. Further tuning will be needed to achieve desired performance.

In [None]:
# Default threshold used to convert probability scores into class predictions.
# Note that optimal threshold selection is not demonstrated in this notebook.
threshold = 0.5

metrics = compute_metrics(prediction_scores, threshold)
print_metrics(metrics)

## Next steps

Explore the other [notebooks](https://github.com/google-health/medsiglip/blob/main/notebooks) to learn what else you can do with the model.

## Evaluate Model on Test Set

Iterate through the entire `data['test']` dataset. For each image, run the fine-tuned model and extract the original labels and the top 5 predicted conditions. An example counts as a successful match if at least one of its original labels is among the top 5 predicted conditions. The "Top 5 Intersection Accuracy" is the percentage of test examples with a successful match.



In [None]:
successful_matches = 0

for i in range(len(data["test"])):
    example = data["test"][i]
    image = example["image"]
    original_label_onehot = example["label"]

    # Preprocess the image
    preprocessed_image = image_processor(images=[image.convert("RGB")], return_tensors="pt").to("cuda")

    # Get model predictions
    with torch.no_grad():
        outputs = ft_model(**preprocessed_image)

    logits = outputs.logits
    probabilities = torch.sigmoid(logits).cpu().numpy().flatten()

    # Get top 5 predicted conditions
    top_5_indices = probabilities.argsort()[-5:][::-1] # Get indices of top 5 probabilities, in descending order
    top_5_predicted_conditions = [CONDITIONS[idx] for idx in top_5_indices]

    # Convert original label to condition names
    original_condition_names = [CONDITIONS[idx] for idx, val in enumerate(original_label_onehot) if val == 1]

    # Check for intersection
    if set(original_condition_names).intersection(set(top_5_predicted_conditions)):
        successful_matches += 1

print(f"Total test examples: {len(data['test'])}")
print(f"Successful matches (at least one original label in top 5 predictions): {successful_matches}")

top_5_intersection_accuracy = (successful_matches / len(data["test"])) * 100
print(f"Top 5 Intersection Accuracy: {top_5_intersection_accuracy:.2f}%")
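
The per-example check used in the loop above can be isolated as a small pure function, which makes its edge cases easier to see. This is a sketch with a hypothetical helper name (`top_k_hit`) and toy class names, not part of the notebook's pipeline.

```python
def top_k_hit(probabilities, one_hot_label, class_names, k=5):
    """Return True if any ground-truth class is among the k highest scores."""
    ranked = sorted(
        range(len(probabilities)), key=lambda i: probabilities[i], reverse=True
    )
    top_k = {class_names[i] for i in ranked[:k]}
    truth = {class_names[i] for i, v in enumerate(one_hot_label) if v == 1}
    return bool(truth & top_k)


toy_classes = ["A", "B", "C", "D", "E", "F"]
toy_probs = [0.05, 0.9, 0.3, 0.6, 0.2, 0.4]

print(top_k_hit(toy_probs, [0, 0, 1, 0, 0, 0], toy_classes))  # "C" is ranked 4th
print(top_k_hit(toy_probs, [1, 0, 0, 0, 0, 0], toy_classes))  # "A" is ranked last
print(top_k_hit(toy_probs, [0, 0, 0, 0, 0, 0], toy_classes))  # no labels at all
```

Two properties are worth keeping in mind when reading the resulting accuracy: with ten conditions, the top 5 covers half of the label space, and an example with no labels among `CONDITIONS` can never count as a match.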
