In [1]:
import os
os.chdir('../')

# Implement a zero-shot function for medclip

In [2]:
# implement a zero-shot function for medclip

import torch
import torchvision
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm

# Device configuration
from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip.modeling_medclip import MedCLIPVisionModel

# debuggin
from PIL import Image

# prepare for the demo image and texts
from build.lib.medclip.constants import BERT_TYPE, IMG_MEAN, IMG_STD, IMG_SIZE
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from data_utils import load_dataset, LESION_TYPE, load_ham10000_dataset

BATCH_SIZE = 64

# def medclip_zero_shot_inline(test_dataset, classes, batch_size=BATCH_SIZE):
#     # Device configuration
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     # Data loader for the dataset
#     data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
#     print(f"Device: {device}")

#     # Initialize MedClip Models
#     model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT).to(device)

#     # Prepare text prompts
#     text_prompts = [f"a photo of a {c}, a type of skin lesion." for c in classes]
#     # Initialize the tokenizer
#     tokenizer = AutoTokenizer.from_pretrained(BERT_TYPE)

#     # Tokenize text prompts and convert to tensors
#     text_tokens = [tokenizer(text, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True) for text in text_prompts]

#     # Encode text prompts using MedClip's text model
#     # Inside the medclip_zero_shot function
#     text_features = [
#         model.encode_text(
#             input_ids=tokens['input_ids'].to(device), 
#             attention_mask=tokens['attention_mask'].to(device)
#         ) 
#         for tokens in text_tokens
#     ]

#     # Initialize variables for accuracy calculation
#     correct = 0
#     total = 0

#     for images, labels in tqdm(data_loader):
#         images, labels = images.to(device), labels.to(device)

#         # TODO: Encode images using MedClip's vision model
#         image_features = model.encode_image(images)

#         # Flatten text_features into a single 2D tensor
#         text_features_tensor = torch.cat(text_features, dim=0)

#         # Calculate similarity and make predictions
#         similarity = torch.matmul(image_features, text_features_tensor.t())
#         _, predictions = similarity.max(dim=-1)

#         # Update correct and total counts
#         correct += (predictions == labels).sum().item()
#         total += labels.size(0)

#     return correct / total

# # Load HAM10000 dataset
# transform = torchvision.transforms.Compose([
#     torchvision.transforms.Resize((IMG_SIZE, IMG_SIZE)),
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize(mean=[IMG_MEAN], std=[IMG_STD])
# ])

# # train_dataset, test_dataset = load_ham10000_dataset(transform=transform, data_dir='data/ham10000')
# train_dataset, test_dataset = load_dataset("HAM10000", transform=transform, data_dir='data/ham10000/')
# classes = list(LESION_TYPE.values())  # From the data_utils.py file

# # Run zero-shot classification
# acc = medclip_zero_shot_inline(test_dataset, classes)
# print(f"Accuracy: {acc:.2f}")

In [3]:
def medclip_zero_shot(model, test_dataset, classes, batch_size=BATCH_SIZE):
    # Data loader for the dataset
    data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Prepare text prompts
    text_prompts = [f"a photo of a {c}, a type of skin lesion." for c in classes]
    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BERT_TYPE)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Tokenize text prompts and convert to tensors
    text_tokens = [tokenizer(text, return_tensors='pt', padding=True, truncation=False, add_special_tokens=True) for text in text_prompts]

    # Encode text prompts using MedClip's text model
    # Inside the medclip_zero_shot function
    text_features = [
        model.encode_text(
            input_ids=tokens['input_ids'].to(device), 
            attention_mask=tokens['attention_mask'].to(device)
        ) 
        for tokens in text_tokens
    ]

    # Initialize variables for accuracy calculation
    correct = 0
    total = 0

    for images, labels in tqdm(data_loader):
        images, labels = images.to(device), labels.to(device)
        # Encode images using MedClip's vision model
        # with torch.no_grad():
        image_features = model.encode_image(images)
        # Flatten text_features into a single 2D tensor
        text_features_tensor = torch.cat(text_features, dim=0)

        # Calculate similarity and make predictions
        similarity = torch.matmul(image_features, text_features_tensor.t())
        _, predictions = similarity.max(dim=-1)

        # Update correct and total counts
        correct += (predictions == labels).sum().item()
        total += len(labels)

    return correct / total

## Load HAM10000 dataset and test MedClip's zero-shot capabilities

In [4]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((IMG_SIZE, IMG_SIZE)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[IMG_MEAN], std=[IMG_STD])
])

ham_train, ham_test = load_dataset("HAM10000", transform=transform, data_dir='data/ham10000/')
classes = list(LESION_TYPE.values())  # From the data_utils.py file


Loading HAM10000 dataset...


MedCLIP_ResNet50_model

In [6]:
# load MedCLIP-ResNet50
MedCLIP_ResNet50_model = MedCLIPModel(vision_cls=MedCLIPVisionModel).to(device)
accuracy = medclip_zero_shot(MedCLIP_ResNet50_model, ham_train, classes)
print(f"\nAccuracy = {100*accuracy:.3f}%")

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Device: cuda


100%|██████████| 141/141 [00:22<00:00,  6.19it/s]


Accuracy = 46.500%





MedCLIP_ViT_model

In [15]:
# load MedCLIP-ViT
MedCLIP_ViT_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT).to(device)
accuracy = medclip_zero_shot(MedCLIP_ViT_model, ham_train, classes)
print(f"\nAccuracy = {100*accuracy:.3f}%")

Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.tra

Device: cuda


100%|██████████| 141/141 [00:24<00:00,  5.65it/s]


Accuracy = 32.653%





## Load NIH Chest X-ray dataset

In [None]:
import os
os.chdir('../')

In [None]:
import torch
import torchvision
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Device configuration
from data_utils import load_nih_dataset_split, NIH_CLASS_TYPES, load_dataset
from medclip import MedCLIPModel, MedCLIPVisionModelViT, MedCLIPVisionModel
from build.lib.medclip.constants import BERT_TYPE, IMG_MEAN, IMG_STD, IMG_SIZE

# debuggin
from PIL import Image

BATCH_SIZE = 128

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((IMG_SIZE, IMG_SIZE)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[IMG_MEAN], std=[IMG_STD])
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NIH_CLASS_TYPES
classes = list(NIH_CLASS_TYPES)  # From the data_utils.py file
classes

# nih_train, nih_test = load_nih_dataset_split(transform=transform)
nih_train, nih_test = load_dataset("NIH", transform=transform, data_dir='data/nih/')
nih_train, nih_test

In [None]:
def medclip_zero_shot(model, test_dataset, classes, batch_size=BATCH_SIZE):
    # Data loader for the dataset
    data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Prepare text prompts
    text_prompts = [f"a photo of a {c}, a type of skin lesion." for c in classes]
    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BERT_TYPE)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Tokenize text prompts and convert to tensors
    text_tokens = [tokenizer(text, return_tensors='pt', padding=True, truncation=False, add_special_tokens=True) for text in text_prompts]

    # Encode text prompts using MedClip's text model
    # Inside the medclip_zero_shot function
    text_features = [
        model.encode_text(
            input_ids=tokens['input_ids'].to(device), 
            attention_mask=tokens['attention_mask'].to(device)
        ) 
        for tokens in text_tokens
    ]

    # Initialize variables for accuracy calculation
    correct = 0
    total = 0

    for images, labels in tqdm(data_loader):
        images, labels = images.to(device), labels.to(device)
        image_features = model.encode_image(images)
        text_features_tensor = torch.cat(text_features, dim=0)

        similarity = torch.matmul(image_features, text_features_tensor.t())
        probabilities = F.sigmoid(similarity)  # Convert to probabilities
        predictions = (probabilities > 0.5).float()  # Apply threshold to get binary predictions

        # Update correct and total counts
        # Calculate correct predictions in a multi-label scenario
        correct_preds = (predictions == labels).all(dim=1).sum().item()
        correct += correct_preds
        total += len(labels)

    return correct / total

In [None]:
# load MedCLIP-ResNet50
MedCLIP_ResNet50_model = MedCLIPModel(vision_cls=MedCLIPVisionModel).to(device)
MedCLIP_ResNet50_model
accuracy = medclip_zero_shot(MedCLIP_ResNet50_model, nih_train, classes)
print(f"\nAccuracy = {100*accuracy:.3f}%")

In [None]:
# load MedCLIP-ViT
MedCLIP_ViT_model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT).to(device)
accuracy = medclip_zero_shot(MedCLIP_ViT_model, nih_train, classes)
print(f"\nAccuracy = {100*accuracy:.3f}%")