see https://github.com/openai/CLIP

In [1]:
! pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /private/var/folders/9g/fz6w6gks1dq2j6bw3pg865kr0000gn/T/pip-req-build-r108z586
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /private/var/folders/9g/fz6w6gks1dq2j6bw3pg865kr0000gn/T/pip-req-build-r108z586
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting ftfy (from clip==1.0)
  Using cached ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Using cached ftfy-6.3.1-py3-none-any.whl (44 kB)
Building wheels for collected packages: clip
  Building wheel for clip (pyproject.toml) ... [?25ldone
[?25h  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369549 sha256=f84593c69defbffc1c8528345fb3e100692c4b2b63cf03dc694

In [2]:
import torch
import clip
from torchvision import datasets, transforms
from tqdm import tqdm  # for progress bar

In [3]:
if torch.cuda.is_available():
    dev = "cuda:0"
elif torch.backends.mps.is_available():
    dev = "mps"
else:
    dev = "cpu"
device = torch.device(dev)
device

device(type='mps')

load the CLIP model

In [4]:
model, preprocess = clip.load("ViT-B/32", device=device)

The class labels are converted into text prompts like "a photo of a cat", which are then tokenized.

In [5]:
# CIFAR-10 class labels
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

text_inputs = clip.tokenize([f"a photo of a {c}" for c in cifar10_classes]).to(device)
# text_inputs = clip.tokenize([f"a blurry image of a {c}" for c in cifar10_classes]).to(device)

load CIFAR-10 test dataset (no training, as we go for zero-shot)

In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to CLIP input size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

testset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

pick just use one sample to show zero-shot classification

In [7]:
image, label = next(iter(testloader))
image = image.to(device)

CLIP encodes both the image and text into a shared embedding space.

In [8]:
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text_inputs)

normalize features

In [9]:
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

compute cosine similarity between image and text embeddings

In [10]:
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

In [11]:
values, indices = similarity[0].topk(3)

for value, index in zip(values, indices):
    print(f"{cifar10_classes[index]:>16s}: {100 * value.item():.2f}%")

             dog: 69.04%
             cat: 14.25%
           horse: 3.89%


In [12]:
cifar10_classes[label]

'dog'

full test set evaluation

In [13]:
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

correct = 0
total = 0

with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Evaluating", unit="batch"):
        images = images.to(device)
        labels = labels.to(device)

        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        similarities = image_features @ text_features.T
        predictions = similarities.argmax(dim=1)

        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total * 100
print(f"zero-shot classification accuracy on CIFAR-10: {accuracy:.2f}%")

Evaluating: 100%|██████████| 313/313 [00:47<00:00,  6.63batch/s]

zero-shot classification accuracy on CIFAR-10: 84.72%



