~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-health/medsiglip/blob/main/notebooks/train_data_efficient_classifier.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fgoogle-health%2Fmedsiglip%2Fmain%2Fnotebooks%2Ftrain_data_efficient_classifier.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-health/medsiglip/blob/main/notebooks/train_data_efficient_classifier.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/google/medsiglip-448">
      <img alt="Hugging Face logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on Hugging Face
    </a>
  </td>
</tr></tbody></table>

# Classifying skin conditions with MedSigLIP and the SCIN dataset

In this notebook, we will train a model to classify skin conditions from images in the [SCIN dataset](https://github.com/google-research-datasets/scin) based on embeddings generated by the [MedSigLIP model](https://developers.google.com/health-ai-developer-foundations/medsiglip).


The SCIN (Skin Condition Image Network) open access dataset contains 5,000+ volunteer contributions (10,000+ images) of common dermatology conditions. The SCIN dataset was collected from Google Search users in the United States through a voluntary, consented image donation application. Three dermatologists labeled each image with up to 3 conditions, and a confidence rating from 1 to 5 for each condition. For example:

Dermatologist labels:
- Dermatologist 1's label: Eczema with confidence 3, Acute and chronic dermatitis with confidence 2, Psoriasis vulgaris with confidence 1

- Dermatologist 2's label: Eczematous dermatitis with confidence 3

- Dermatologist 3's label: Eczematous dermatitis with confidence 4, Post-inflammatory hyperpigmentation with confidence 3

The resulting columns for this case would be:

- `dermatologist_skin_condition_on_label_name`: `['Eczema', 'Acute and chronic dermatitis', 'Psoriasis vulgaris', 'Eczematous dermatitis', 'Eczematous dermatitis', 'Post-inflammatory hyperpigmentation']`

- `dermatologist_skin_condition_confidence`: `[3, 2, 1, 3, 4, 3]`
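
The notebook reads these columns with `eval`, which suggests each is stored as a string containing a Python list. A minimal sketch, assuming that string format, that parses the example above with `ast.literal_eval` (which only evaluates literals, unlike `eval`) and pairs each condition with its confidence:

```python
import ast

# Example values, assuming the CSV stores each list as a Python-literal string.
names_str = ("['Eczema', 'Acute and chronic dermatitis', 'Psoriasis vulgaris', "
             "'Eczematous dermatitis', 'Eczematous dermatitis', "
             "'Post-inflammatory hyperpigmentation']")
confidences_str = "[3, 2, 1, 3, 4, 3]"

# literal_eval parses literals only, so it cannot run arbitrary code.
names = ast.literal_eval(names_str)
confidences = ast.literal_eval(confidences_str)

# The two lists are parallel: position i of each refers to the same label.
pairs = list(zip(names, confidences))
for name, conf in pairs:
  print(f'{name}: {conf}')
```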

We use MedSigLIP to generate rich embeddings for medical images, which lets us train a machine learning model with less data and compute than training from scratch. Visit the [MedSigLIP page](https://developers.google.com/health-ai-developer-foundations/medsiglip) on the HAI-DEF site to learn more about the model, and see the [SCIN demo notebook](https://github.com/google-research-datasets/scin/blob/main/scin_demo.ipynb) to learn more about the SCIN dataset.

We will frame this as a multi-label classification problem: the model takes a 1152-dimensional embedding as input, and the label indicates which of the following skin conditions the case was labeled with: `['Eczema', 'Allergic Contact Dermatitis', 'Insect Bite', 'Urticaria', 'Psoriasis', 'Folliculitis', 'Irritant Contact Dermatitis', 'Tinea', 'Herpes Zoster', 'Drug Rash']`. For the example above, the label would be `[1 0 0 0 1 0 0 0 0 0]`, since at least one dermatologist labeled the case with Eczema and with Psoriasis.
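
A minimal sketch of this multi-hot encoding in plain NumPy (the notebook itself uses scikit-learn's `MultiLabelBinarizer`); `to_multi_hot` is a hypothetical helper for illustration:

```python
import numpy as np

CONDITIONS_TO_PREDICT = ['Eczema', 'Allergic Contact Dermatitis', 'Insect Bite',
                         'Urticaria', 'Psoriasis', 'Folliculitis',
                         'Irritant Contact Dermatitis', 'Tinea', 'Herpes Zoster',
                         'Drug Rash']


def to_multi_hot(conditions, classes=CONDITIONS_TO_PREDICT):
  """Returns a 0/1 vector with a 1 at the index of each condition present."""
  vec = np.zeros(len(classes), dtype=int)
  for condition in conditions:
    if condition in classes:  # Conditions outside the list are ignored.
      vec[classes.index(condition)] = 1  # Repeats just set the same slot again.
  return vec


label = to_multi_hot(['Eczema', 'Psoriasis'])
print(label)  # [1 0 0 0 1 0 0 0 0 0]
```

Each position is an independent yes/no answer, so any number of positions can be 1 at the same time.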

## Setup

This notebook uses precomputed embeddings by default and can be run using a CPU runtime. If you are generating the embeddings, you can use a runtime with a GPU to speed up generation:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.

In [None]:
import collections
import io
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from google.cloud import storage

# @markdown Set `USE_PRECOMPUTED_EMBEDDINGS` to load precomputed embeddings (True by default). Uncheck this option if you want to generate the embeddings.
USE_PRECOMPUTED_EMBEDDINGS = True # @param {type:"boolean"}

## Load SCIN dataset

In [None]:
SCIN_GCS_BUCKET_NAME = 'dx-scin-public-data'
SCIN_GCS_CASES_CSV = 'dataset/scin_cases.csv'
SCIN_GCS_LABELS_CSV = 'dataset/scin_labels.csv'
SCIN_GCS_IMAGES_DIR = 'dataset/images/'


def initialize_df_with_metadata(bucket, csv_path):
  """Loads the cases CSV into a pd.DataFrame, keeping case_id as a string."""
  df = pd.read_csv(io.BytesIO(bucket.blob(csv_path).download_as_bytes()), dtype={'case_id': str})
  df['case_id'] = df['case_id'].astype(str)
  return df


def augment_metadata_with_labels(df, bucket, csv_path):
  """Loads the labels CSV and inner-joins it onto df by case_id."""
  labels_df = pd.read_csv(io.BytesIO(bucket.blob(csv_path).download_as_bytes()), dtype={'case_id': str})
  print(f'Loaded labels with {len(labels_df)} rows.')
  labels_df['case_id'] = labels_df['case_id'].astype(str)
  merged_df = pd.merge(df, labels_df, on='case_id')
  return merged_df


scin_bucket = storage.Client.create_anonymous_client().bucket(SCIN_GCS_BUCKET_NAME)

scin_no_label_df = initialize_df_with_metadata(scin_bucket, SCIN_GCS_CASES_CSV)
scin_df = augment_metadata_with_labels(scin_no_label_df, scin_bucket, SCIN_GCS_LABELS_CSV)
scin_df.set_index('case_id', inplace=True)
print(f'Loaded {len(scin_df)} rows.')

# scin_df is the main data frame we will be working with.
scin_df.head(5)

## Explore SCIN dataset

In [None]:
import ipywidgets as widgets
from IPython.display import display
from google.colab import output as colab_output


def display_image(bucket, image_path):
  """Downloads one image from the bucket and shows it."""
  image = Image.open(io.BytesIO(bucket.blob(image_path).download_as_bytes()))
  _, ax = plt.subplots(1, 1, figsize=(4, 4))
  ax.imshow(image, cmap='gray')  # cmap only affects single-channel images.
  ax.axis('off')
  plt.show()


def display_images_for_case(df, case_id):
  """Shows every image for a case, then prints its dermatologist labels."""
  # Each volunteer contributor submitted up to 3 images; missing paths are NaN.
  image_paths = [df.loc[case_id, 'image_1_path'], df.loc[case_id, 'image_2_path'], df.loc[case_id, 'image_3_path']]
  for path in image_paths:
    if isinstance(path, str):
      # Reuse the anonymous bucket client created when loading the dataset.
      display_image(scin_bucket, path)

  conditions = df.loc[case_id, 'dermatologist_skin_condition_on_label_name']
  confidence = df.loc[case_id, 'dermatologist_skin_condition_confidence']
  print(f'Skin conditions: {conditions}')
  print(f'Confidence: {confidence}')


def on_change(change):
  # Redraw the dropdown and the newly selected case.
  colab_output.clear()
  display(case_id_dropdown)
  display_images_for_case(scin_df, case_id=change['new'])


case_id_dropdown = widgets.Dropdown(options=scin_df.index, description="Case ID")
display(case_id_dropdown)
# Only react when the selected value changes.
case_id_dropdown.observe(on_change, names='value')
display_images_for_case(scin_df, case_id_dropdown.value)

In [None]:
import ast


def print_condition_distribution(df, top_n_conditions=50):
  """Prints how often each condition appears across all dermatologist labels."""
  condition_ctr = collections.Counter()
  print('Distribution of conditions in "dermatologist_skin_condition_on_label_name" column:')
  for entry in df['dermatologist_skin_condition_on_label_name'].dropna():
    # Each entry is a string holding a Python list; literal_eval parses it safely.
    condition_ctr.update(ast.literal_eval(entry))
  for condition, cnt in condition_ctr.most_common(top_n_conditions):
    print(f'  {condition}: {cnt}')


print_condition_distribution(scin_df)

## Clean and prepare the data

We will try to predict the 10 most common conditions:

`['Eczema', 'Allergic Contact Dermatitis', 'Insect Bite', 'Urticaria', 'Psoriasis', 'Folliculitis', 'Irritant Contact Dermatitis', 'Tinea', 'Herpes Zoster', 'Drug Rash']`


Our training inputs X will be embeddings of shape `(1152,)`, and our labels y will be binary vectors of shape `(10,)`, for example `[0, 1, 0, 0, 1, 0, 1, 0, 0, 1]`.

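We filter out cases that the dermatologists marked as having insufficient image quality. We also drop individual labels whose confidence is below a minimum threshold. For example, if the minimum confidence is 3 and a dermatologist labeled Eczema with confidence 2, we still include the example, but that Eczema label is ignored, so Eczema is set to 0 unless another dermatologist labeled it with confidence 3 or higher. The notebook sets `MINIMUM_CONFIDENCE = 0`, so by default no labels are dropped for low confidence.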
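
The per-label rule can be sketched as a small function; `filter_by_confidence` is a hypothetical helper for illustration that mirrors the per-label checks in `prepare_data`:

```python
def filter_by_confidence(names, confidences, conditions_to_predict, minimum_confidence):
  """Keeps labels that are in conditions_to_predict and meet the confidence threshold."""
  kept = []
  for name, conf in zip(names, confidences):
    if name in conditions_to_predict and conf >= minimum_confidence:
      kept.append(name)
  return kept


conditions = ['Eczema', 'Psoriasis', 'Urticaria']
# Eczema was rated 2 and Urticaria 4; with a threshold of 3 only Urticaria survives.
print(filter_by_confidence(['Eczema', 'Urticaria'], [2, 4], conditions, 3))  # ['Urticaria']
```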

Finally, the skin condition labels are imbalanced. For example, we have 156 examples with Drug Rash and 6295 without. Typically, you want to split your data so that each condition has a similar proportion of positive labels in the train and test sets (stratification). We print these distributions below, but we don't explicitly stratify the split in this notebook.
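
scikit-learn's `train_test_split(stratify=...)` expects one class label per sample; for a multi-label matrix it treats each distinct label row as a class, which fails when a rare combination appears only once. A common compromise is to stratify on a single rare condition. A sketch on synthetic data (the 4% positive rate is made up for illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
# Synthetic stand-ins: 1000 embeddings and one rare binary label with 40 positives.
X_demo = rng.normal(size=(1000, 8))
y_rare = np.zeros(1000, dtype=int)
y_rare[:40] = 1

# stratify keeps the positive rate of y_rare the same in both splits.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_rare, test_size=0.2, random_state=42, stratify=y_rare)
print(y_tr.sum(), y_te.sum())  # 32 8
```

In the notebook you could pass one column of the binarized labels, for example the Drug Rash column, as `stratify`; the other conditions are then only balanced by chance.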

In [None]:
import ast

from sklearn.preprocessing import MultiLabelBinarizer

CONDITIONS_TO_PREDICT = ['Eczema', 'Allergic Contact Dermatitis', 'Insect Bite', 'Urticaria', 'Psoriasis', 'Folliculitis', 'Irritant Contact Dermatitis', 'Tinea', 'Herpes Zoster', 'Drug Rash']


def prepare_data():
  """Returns image paths and, for each image, its list of condition labels.

  Labels are assigned per case, so every image in a case gets the same labels.
  """
  MINIMUM_CONFIDENCE = 0

  X = []
  y = []
  poor_image_quality_counter = 0
  not_in_condition_to_predict_counter = 0
  condition_confidence_low_counter = 0

  for row in scin_df.itertuples():
    # Skip cases the dermatologists marked as having insufficient image quality.
    if row.dermatologist_gradable_for_skin_condition_1 != 'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT':
      poor_image_quality_counter += 1
      continue

    # These columns store Python lists as strings; literal_eval parses them safely.
    labels = ast.literal_eval(row.dermatologist_skin_condition_on_label_name)
    confidences = ast.literal_eval(row.dermatologist_skin_condition_confidence)

    row_labels = []
    for label, confidence in zip(labels, confidences):
      if label not in CONDITIONS_TO_PREDICT:
        not_in_condition_to_predict_counter += 1
        continue
      if confidence < MINIMUM_CONFIDENCE:
        condition_confidence_low_counter += 1
        continue
      row_labels.append(label)

    # Emit one example per image that exists for this case.
    for image_path in [row.image_1_path, row.image_2_path, row.image_3_path]:
      if pd.isna(image_path):
        continue

      X.append(image_path)
      y.append(row_labels)

  print(f'Poor image quality: {poor_image_quality_counter}')
  print(f'Condition not in "CONDITIONS_TO_PREDICT": {not_in_condition_to_predict_counter}')
  print(f'Excluded label confidence too low: {condition_confidence_low_counter}')
  return X, y

X_image_paths, y = prepare_data()
# Convert y from [['Eczema'], ['Urticaria', 'Insect Bite']] to
# [[1 0 0 0 0 0 0 0 0 0], [0 0 1 1 0 0 0 0 0 0]]

mlb = MultiLabelBinarizer(classes=CONDITIONS_TO_PREDICT)
y = mlb.fit_transform(y)

## Load precomputed embeddings

Generating the embeddings can take hours, so by default we download precomputed embeddings from Google Cloud Storage. If you would like to generate your own, uncheck `USE_PRECOMPUTED_EMBEDDINGS` and see [Compute embeddings in batch](#scrollTo=iMu1zWYMPJWK).

In [None]:
if USE_PRECOMPUTED_EMBEDDINGS:
  file_path = 'gs://healthai-us/medsiglip/scin_medsiglip_embeddings_and_binarized_labels.npz'

  # tf.io.gfile reads gs:// paths directly. allow_pickle=True is needed because
  # the archive stores an object array. Loading errors are raised, not swallowed.
  with tf.io.gfile.GFile(file_path, 'rb') as f:
    d = np.load(f, allow_pickle=True)['data']

  # Column 1 holds the embeddings and column 2 the binarized labels.
  X = np.vstack(d[:, 1])
  Y = np.vstack(d[:, 2])

  print(f'Length of X: {len(X)}')
  print(f'Length of Y: {len(Y)}')
  print(f'Sample from X: {X[0].shape}')
  print(f'Sample from y: {Y[0]}')

## (Optional) Compute embeddings in batch


In [None]:
if not USE_PRECOMPUTED_EMBEDDINGS:
  # Authenticate user for HuggingFace if needed. Enter token below if requested.
  from huggingface_hub.utils import HfFolder
  from huggingface_hub import notebook_login

  if HfFolder.get_token() is None:
      notebook_login()

In [None]:
if not USE_PRECOMPUTED_EMBEDDINGS:
  import requests
  import torch
  from transformers import AutoModel, AutoProcessor

  batch_size = 128  # Adjust this based on memory. Start small and increase if possible.
  num_images = len(X_image_paths)

  model_id = "google/medsiglip-448"  # @param {type:"string"}

  model = AutoModel.from_pretrained(model_id, device_map="auto")
  processor = AutoProcessor.from_pretrained(model_id)

  embeddings = []  # One (batch_size, embedding_dim) array per batch.
  labels = []  # Labels for the images that were embedded successfully.

  for i in range(0, num_images, batch_size):
    print(f'Processing {i} to {min(i + batch_size, num_images)}')
    batch_paths = X_image_paths[i:i + batch_size]
    batch_labels = y[i:i + batch_size]

    X_images = []  # Images for the current batch.
    y_labels = []  # Labels kept in step with X_images.

    for image_path, row_labels in zip(batch_paths, batch_labels):
      try:
        response = requests.get(
            f'https://storage.googleapis.com/{SCIN_GCS_BUCKET_NAME}/{image_path}', timeout=60)
        response.raise_for_status()  # Raise an exception for 4xx or 5xx status codes.
        # Convert to RGB so every image has three channels, whatever its file format.
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        X_images.append(image)
        y_labels.append(row_labels)
      except requests.exceptions.RequestException as e:
        print(f"Error fetching image from {image_path}: {e}")
      except Exception as e:  # Catch other errors while decoding the image.
        print(f"Error processing image from {image_path}: {e}")
    if not X_images:  # Skip batches where every image failed.
      continue

    inputs = processor(images=X_images, return_tensors="pt").to(model.device)

    with torch.no_grad():
      image_features = model.get_image_features(**inputs)

    embeddings.append(image_features.cpu().numpy())  # Move to CPU and convert to NumPy.
    labels.append(np.array(y_labels))

  # Stack the per-batch arrays into one (num_embedded_images, embedding_dim) array.
  X = np.concatenate(embeddings, axis=0)
  Y = np.concatenate(labels, axis=0)

  # You can save the embeddings to a file, for example:
  # np.save("image_embeddings.npy", X)

  print(f'Length of X: {len(X)}')
  print(f'Length of Y: {len(Y)}')
  print(f'Sample from X: {X[0].shape}')
  print(f'Sample from y: {Y[0]}')
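
The embedding cell only hints at saving with `np.save`. A minimal sketch, assuming you want to keep embeddings and labels in one file so they stay aligned; `save_embeddings`, `load_embeddings`, and the `embeddings`/`labels` keys are hypothetical names, and this layout differs from the precomputed archive, which stores a single `data` object array:

```python
import os
import tempfile

import numpy as np


def save_embeddings(path, embeddings, labels):
  """Saves embeddings and binarized labels together in one compressed .npz file."""
  np.savez_compressed(path, embeddings=embeddings, labels=labels)


def load_embeddings(path):
  """Loads arrays written by save_embeddings; numeric arrays need no pickling."""
  with np.load(path) as archive:
    return archive['embeddings'], archive['labels']


# Demo with small stand-in arrays; in the notebook you would pass X and Y.
X_demo = np.random.default_rng(0).normal(size=(4, 1152)).astype(np.float32)
Y_demo = np.tile([1, 0, 0, 0, 1, 0, 0, 0, 0, 0], (4, 1))
demo_path = os.path.join(tempfile.mkdtemp(), 'scin_medsiglip_embeddings.npz')
save_embeddings(demo_path, X_demo, Y_demo)
X_loaded, Y_loaded = load_embeddings(demo_path)
print(X_loaded.shape, Y_loaded.shape)  # (4, 1152) (4, 10)
```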

## Create train and test splits

In [None]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)


def print_label_distribution(y, mlb):
    """Prints the distribution of labels."""

    y = np.array(y)

    # Count label occurrences
    label_counts = np.sum(y, axis=0)

    # Calculate and print the distribution
    label_percentages = label_counts / len(y)

    for i, condition in enumerate(mlb.classes_):
        print(f"{condition}: {label_percentages[i]:.4f}")


print("\nPercentage of positive labels by condition in train:")
print_label_distribution(y_train, mlb)
print("\nPercentage of positive labels by condition in test:")
print_label_distribution(y_test, mlb)

## Train a logistic regression classifier

With the `MultiOutputClassifier` wrapper, scikit-learn trains 10 independent logistic regression models, one for each label. Let's see how they do!
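
`MultiOutputClassifier.predict_proba` returns a list with one `(n_samples, 2)` array per label rather than a single matrix, which is why the cell below reshapes its output. A sketch on small synthetic data showing that shape and how to keep only the positive-class probabilities:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.default_rng(0)
# Synthetic data: 200 samples, 16 features, 3 binary labels.
X_demo = rng.normal(size=(200, 16))
Y_demo = (X_demo[:, :3] > 0).astype(int)

clf = MultiOutputClassifier(LogisticRegression(max_iter=250)).fit(X_demo, Y_demo)
proba = clf.predict_proba(X_demo)

# One array per label; column 0 is P(label=0) and column 1 is P(label=1).
print(len(proba), proba[0].shape)  # 3 (200, 2)

# Keep the positive-class column of each label and stack them side by side.
p_positive = np.column_stack([p[:, 1] for p in proba])
print(p_positive.shape)  # (200, 3)
```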

In [None]:
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression

lr_classifier = MultiOutputClassifier(LogisticRegression(max_iter=250)).fit(X_train, y_train)
y_pred = lr_classifier.predict_proba(X_test)

# predict_proba returns a list with one (n_test, 2) array per condition.
# Keep the positive-class column of each to get an (n_test, 10) array.
y_pred = np.column_stack([proba[:, 1] for proba in y_pred])
y_pred.shape

## Evaluate results

We will use [Hamming loss](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html) instead of accuracy to get a general sense of how our model is performing. For multi-label data, accuracy requires all predictions for an example to match its labels exactly, so if even one of the 10 predictions is wrong, the whole example counts as incorrect. Hamming loss is the fraction of individual labels that are predicted incorrectly (lower is better), which is more forgiving.
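
A small worked example of the difference, using made-up labels for two examples and three conditions:

```python
import numpy as np
from sklearn.metrics import accuracy_score, hamming_loss

y_true_demo = np.array([[1, 0, 1],
                        [0, 1, 0]])
y_pred_demo = np.array([[1, 0, 0],   # one of three labels wrong
                        [0, 1, 0]])  # all labels right

# For multi-label input, accuracy_score is exact-match (subset) accuracy.
print(accuracy_score(y_true_demo, y_pred_demo))  # 0.5
# Hamming loss is wrong labels over all labels: 1 / 6, about 0.167.
print(hamming_loss(y_true_demo, y_pred_demo))
```

One wrong label costs the first example its entire accuracy credit, but only 1/6 of the label-level score.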

In [None]:
from sklearn.metrics import multilabel_confusion_matrix, hamming_loss, ConfusionMatrixDisplay, RocCurveDisplay


def plot_confusion_matrix(y_test, y_pred, classes):
  y_bool = (y_pred >= 0.5).astype(int)
  cnf_matrix = multilabel_confusion_matrix(y_test, y_bool)
  _, axes = plt.subplots(2, 5, figsize=(14, 6), tight_layout=True)
  for cf, cl, ax in zip(cnf_matrix, classes, axes.flatten()):
    ax.set_title(cl)
    disp = ConfusionMatrixDisplay(confusion_matrix=cf)
    disp.plot(ax=ax)
  plt.show()


def print_hamming_loss(y_test, y_pred):
  y_bool = (y_pred >= 0.5).astype(int)
  print(f'\n### Hamming Loss: {hamming_loss(y_test, y_bool)} ###')


def plot_roc_curve(y_test, y_pred, classes):
  _, axes = plt.subplots(2, 5, figsize=(14, 6), tight_layout=True)
  for i, (cl, ax) in enumerate(zip(classes, axes.flatten())):
    ax.set_title(cl)
    RocCurveDisplay.from_predictions(y_test[:, i], y_pred[:, i], ax=ax)
  plt.show()

In [None]:
plot_confusion_matrix(y_test, y_pred, mlb.classes_)
plot_roc_curve(y_test, y_pred, mlb.classes_)
print_hamming_loss(y_test, y_pred)

## Train a neural net

Let's see if a simple neural network can do better! Since this is a multi-label classification problem and the conditions are not mutually exclusive, each output uses its own sigmoid activation instead of a softmax across all outputs.
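
A quick NumPy illustration of the difference, using made-up scores for three conditions: softmax makes the outputs compete and sum to 1, while sigmoid scores each output independently, which lets the model predict several conditions for one image. The model's `binary_crossentropy` loss treats each sigmoid output as its own binary problem.

```python
import numpy as np

logits = np.array([2.0, 1.5, -3.0])  # Made-up scores, one per condition.

# Softmax couples the scores: probabilities compete and sum to 1.
softmax = np.exp(logits) / np.exp(logits).sum()

# Sigmoid squashes each score on its own, so several can exceed 0.5 at once.
sigmoid = 1.0 / (1.0 + np.exp(-logits))

print(softmax.round(3), (softmax >= 0.5).sum())  # only one condition can pass 0.5
print(sigmoid.round(3), (sigmoid >= 0.5).sum())  # two conditions predicted present
```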

In [None]:
# Convert to tensorflow datasets
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test))

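# Peek at one example; distinct names avoid overwriting the notebook's y.
for x_sample, y_sample in train_ds.take(1):
    print("Input:", x_sample)
    print("Target:", y_sample)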

train_ds = train_ds.batch(32)
test_ds = test_ds.batch(32)

In [None]:
from tensorflow.keras import layers, regularizers

weight_decay = 1e-5

inputs = tf.keras.Input(shape=(1152,)) # embedding shape of 1152

hidden = layers.Dense(512,
                      kernel_regularizer=regularizers.l2(l2=weight_decay),
                      bias_regularizer=regularizers.l2(l2=weight_decay),
                      activation="relu"
                      )(inputs)
hidden = layers.Dropout(0.05)(hidden)
hidden = layers.Dense(256,
                      kernel_regularizer=regularizers.l2(l2=weight_decay),
                      bias_regularizer=regularizers.l2(l2=weight_decay),
                      activation="relu")(hidden)
hidden = layers.Dropout(0.1)(hidden)
output = layers.Dense(len(mlb.classes_), activation="sigmoid")(hidden)


model = tf.keras.Model(inputs, output)
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4)
)
model.summary()  # summary() prints itself and returns None.
history = model.fit(
    train_ds, validation_data=test_ds, epochs=25
)

In [None]:
def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")

In [None]:
y_pred = model.predict(test_ds)
plot_confusion_matrix(y_test, y_pred, mlb.classes_)
plot_roc_curve(y_test, y_pred, mlb.classes_)
print_hamming_loss(y_test, y_pred)

In our run, the neural net performed better than the logistic regression classifier. This shows up as a lower Hamming loss, which summarizes performance across all conditions. Several per-condition AUCs also improved; Eczema, Drug Rash and Psoriasis show higher AUC in this particular run. Your numbers may differ slightly because training is non-deterministic.
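
To compare the two models numerically rather than by eye, you can compute per-condition ROC AUC with `sklearn.metrics.roc_auc_score`. A sketch with hypothetical helper names; since the notebook overwrites `y_pred` with the neural net's predictions, you would first keep the logistic regression probabilities under another name (for example a hypothetical `y_pred_lr = y_pred.copy()` right after the logistic regression cell) and then call `compare_models(y_test, y_pred_lr, y_pred, mlb.classes_)`. The demo below uses synthetic scores:

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def per_class_auc(y_true, y_score, classes):
  """Returns {condition: ROC AUC}, scoring each label column separately."""
  return {cl: roc_auc_score(y_true[:, i], y_score[:, i])
          for i, cl in enumerate(classes)}


def compare_models(y_true, scores_a, scores_b, classes):
  """Prints per-condition AUC for model A and model B and the change (B - A)."""
  auc_a = per_class_auc(y_true, scores_a, classes)
  auc_b = per_class_auc(y_true, scores_b, classes)
  for cl in classes:
    print(f'{cl:>28}: {auc_a[cl]:.3f} -> {auc_b[cl]:.3f} ({auc_b[cl] - auc_a[cl]:+.3f})')
  return auc_a, auc_b


# Synthetic demo: model B's scores are much less noisy, so its AUCs should be higher.
rng = np.random.default_rng(0)
y_true_demo = rng.integers(0, 2, size=(200, 2))
scores_a = y_true_demo + rng.normal(scale=2.0, size=y_true_demo.shape)
scores_b = y_true_demo + rng.normal(scale=0.3, size=y_true_demo.shape)
auc_a, auc_b = compare_models(y_true_demo, scores_a, scores_b, ['Eczema', 'Psoriasis'])
```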

## Next steps

Explore the other [notebooks](https://github.com/google-health/medsiglip/blob/main/notebooks) to learn what else you can do with the model.