In [None]:
from pathlib import Path
from urllib.request import urlretrieve

import clip
import torch
from PIL import Image

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Load the ViT-B/32 CLIP checkpoint onto the selected device; `preprocess` is the
# image transform returned alongside it and is applied to the input image later.
(model, preprocess) = clip.load(
    "ViT-B/32",
    device=device,
)

In [None]:
# Fetch the COCO val2017 sample image into inputs/, the path the preprocessing cell
# reads from (the previous `!wget` saved it to the CWD, so that cell failed).
# Pure Python instead of `wget` so this also works where wget is not installed;
# the download is skipped on re-runs once the file exists.
IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
IMAGE_PATH = Path("inputs") / "twocats.jpg"

IMAGE_PATH.parent.mkdir(parents=True, exist_ok=True)
if not IMAGE_PATH.exists():
    urlretrieve(IMAGE_URL, IMAGE_PATH)

In [None]:
# Open the image in a context manager so the underlying file handle is closed as soon
# as `preprocess` has read and transformed the pixels.
with Image.open("inputs/twocats.jpg") as pil_image:
    # unsqueeze(0) adds the leading batch axis expected by the model (batch of 1).
    image = preprocess(pil_image).unsqueeze(0).to(device)

In [None]:
# Candidate captions to score against the image, tokenized as one sequence per prompt.
text = clip.tokenize([
    "two cats",
    "dog",
    "one cat",
    "two running cats",
    "two sleeping cats",
]).to(device)

In [None]:
with torch.no_grad():
    # Joint forward pass: image 1 x 3 x 224 x 224 and 5 tokenized prompts
    # -> logits_per_image 1 x 5, logits_per_text 5 x 1.
    logits_per_image, logits_per_text = model(image, text)
    # Softmax across the prompts turns the image's logits into label probabilities (1 x 5).
    probs = torch.softmax(logits_per_image, dim=-1).cpu().numpy()

In [None]:
# Show the probability of each prompt for the single image, rounded to 2 decimals.
print("Label probs:", list(map("{:.2f}".format, probs[0])))

In [None]:
with torch.no_grad():
    # Embed the tokenized prompts: 5 x text -> 5 x 512.
    text_features = model.encode_text(text)
    # Rescale every row to unit L2 norm so dot products become cosine similarities.
    text_features /= text_features.norm(dim=-1, keepdim=True)

In [None]:
with torch.no_grad():
    # Embed the preprocessed image: 1 x 3 x 224 x 224 -> 1 x 512.
    image_features = model.encode_image(image)
    # Rescale to unit L2 norm so dot products with text embeddings are cosine similarities.
    image_features /= image_features.norm(dim=-1, keepdim=True)

In [None]:
# Reproduce the model's logits by hand: cosine similarity of the unit-normalized
# embeddings, multiplied by the model's logit scale.
with torch.no_grad():
    # Read the scale from the model rather than hard-coding exp(4.6052) ≈ 100, so the
    # comparison below matches the forward pass exactly; no_grad avoids building an
    # autograd graph on the parameter.
    logit_scale = model.logit_scale.exp()
    logits_per_image2 = logit_scale * image_features @ text_features.t()  # 1 x 5
    # Per-text logits are the transpose of the recomputed per-image logits
    # (previously this transposed the model's logits_per_image by mistake).
    logits_per_text2 = logits_per_image2.t()  # 5 x 1

# The hand-computed logits must agree with what model(image, text) returned.
assert torch.allclose(logits_per_image2, logits_per_image, rtol=1e-3, atol=1e-3)
assert torch.allclose(logits_per_text2, logits_per_text, rtol=1e-3, atol=1e-3)