In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms as t
from open_clip import create_model_from_pretrained, get_tokenizer

# Expose only GPU index 2 to this process. torch is already imported above, so
# this relies on the variable being set before anything initialises CUDA.
# NOTE(review): confirm none of the earlier imports touches CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Make the parent directory importable (presumably where the local `datasets`
# and `classifiers` packages live -- verify against the repository layout).
root_dir = "../"
if root_dir not in sys.path:
    # Guarded so that re-running this cell does not stack duplicate entries.
    sys.path.append(root_dir)
from datasets import HAM
from classifiers import BiomedCLIP

# Fall back to CPU when no CUDA device is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_dir = os.path.join(root_dir, "data")

# Single source of truth for the hub identifier, so the model and the tokenizer
# are always requested under exactly the same name.
_BIOMEDCLIP_HUB_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"

# The second return value is bound to `preprocess`; the cells visible here build
# their own `transform` rather than using it.
model, preprocess = create_model_from_pretrained(_BIOMEDCLIP_HUB_ID)
tokenizer = get_tokenizer(_BIOMEDCLIP_HUB_ID)

# Image pipeline: resize to exactly 224x224, convert to a tensor, then normalise
# each channel with the fixed mean/std given below.
_resize = t.Resize((224, 224), antialias=True)
_normalize = t.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
transform = t.Compose([_resize, t.ToTensor(), _normalize])

dataset = HAM(root=data_dir, transform=transform)

# One text prompt per class, in the order given by `dataset.classes`.
prompts = [f"This is a photo of {class_name}" for class_name in dataset.classes]

# Show the prompts, one per line.
print(*prompts, sep="\n")

In [None]:
# Zero-shot classification of a small random sample of the dataset: every
# sampled image is scored against all class prompts, and the class indices are
# printed from most to least likely next to the label the dataset returned.
NUM_SAMPLES = 10
SEED = 0

model = model.to(device)
model.eval()

text = tokenizer(prompts).to(device)
print(text.size())

# Seeded generator -> the same images are drawn after "Restart & Run All".
# replace=False prevents the same image from being drawn twice, and the min()
# keeps datasets smaller than NUM_SAMPLES from raising.
rng = np.random.default_rng(SEED)
idx = rng.choice(len(dataset), size=min(NUM_SAMPLES, len(dataset)), replace=False)

samples = [dataset[int(_idx)] for _idx in idx]
labels = [label for _, label in samples]
# Stacking relies on `transform` giving every image the same shape.
images = torch.stack([image for image, _ in samples]).to(device)

with torch.no_grad():
    # One forward pass for the whole batch, so the (loop-invariant) prompts are
    # encoded once instead of once per image.
    image_features, text_features, logit_scale = model(images, text)
    # Softmax over the class axis turns the scaled similarities into
    # per-image class probabilities.
    probs = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)

# Row i holds the class indices for sample i, best match first.
ranking = torch.argsort(probs, dim=-1, descending=True)
for label, ranked in zip(labels, ranking):
    print(label, ranked.tolist())