In [1]:
import clip
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from tqdm import tqdm


In [2]:

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Fetch the "RN50" CLIP checkpoint onto `device`. `clip.load` hands back the
# model together with the image preprocessing pipeline that goes with it.
model, preprocess = clip.load(
    "RN50",
    device=device,
    download_root=None,
)
# Switch to evaluation mode; this notebook only runs inference.
model.eval()

# Keep the pipeline as the cell's last expression so the notebook displays it.
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x00000237721670A0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [4]:
# Short CIFAR10 class names used to build the text prompts. A name's position in
# this tuple is the class index that predictions are later compared against, so
# the order has to follow the dataset's integer labels.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# CIFAR10 test split (train=False), stored under ./data and downloaded on demand.
# Every image is passed through CLIP's `preprocess` pipeline.
cifar10 = CIFAR10(
    root='./data',
    train=False,
    transform=preprocess,
    download=True,
)

# Batches of 64 in a fixed (unshuffled) order, loaded with two workers.
dl = DataLoader(
    cifar10,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:27<00:00, 6222836.82it/s] 


Extracting ./data\cifar-10-python.tar.gz to ./data


In [5]:
# Prompt template: `{}` is replaced by each class name. Change this string to
# experiment with different prompts.
template = 'An image of a {}'

# 1. Insert each classname into the template to create one prompt per class
prompts = [template.format(cls) for cls in classes]
# 2. Tokenize the prompts with `clip.tokenize`
tokens = clip.tokenize(prompts).to(device)

# we don't want to calculate gradients during evaluation
with torch.no_grad():
    # 3. Forward the tokens through the text encoder
    text_embedding = model.encode_text(tokens)
    # 4. Normalize each class embedding to unit length, then transpose so that
    #    every class is one column (ready for `image_features @ text_embedding`)
    text_embedding = F.normalize(text_embedding, dim=-1).T

# Sanity check: the result should have shape (1024, 10), since we have 10 classes
# and the feature dimension of the text encoder is 1024, and each column should
# have norm 1.
print(text_embedding.shape, text_embedding.norm(dim=0))

torch.Size([1024, 10]) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       dtype=torch.float16)


In [6]:
# Running count of correctly classified images, and the size of the test set.
correct, total = 0, len(cifar10)

# Inference only, so gradient tracking is switched off for the whole loop.
with torch.no_grad():
    # 1. Walk over the test set batch by batch
    for inputs, targets in tqdm(dl):
        inputs, targets = inputs.to(device), targets.to(device)

        # 2. Embed the images with the image encoder and scale each row to unit length
        features = F.normalize(model.encode_image(inputs), dim=-1)

        # 3. Image rows and text columns are both normalized, so this matrix
        #    product is the cosine similarity of every image to every class prompt
        logits = torch.matmul(features, text_embedding)

        # 4. The most similar prompt is the prediction; add this batch's hits
        preds = logits.argmax(dim=1)
        correct += preds.eq(targets).sum()

print(f"Accuracy: {correct / total: .5f}")

100%|██████████| 157/157 [00:15<00:00, 10.02it/s]

Accuracy:  0.71230





| Prompt | Accuracy | Comment            |
|--------| --- |--------------------|
| 'This is a {}.'  | 0.696 | not so good |
| 'A {}'  | 0.672 | Poor result |
| '{}'  | 0.708 | Could be better... |
| 'This image contains {}' | 0.715 | Improvement, but not significantly better |
| 'There is {}' | 0.715 | Improvement, but not significantly better |
| 'a photo of a {}'  | 0.716 | better |
| 'you can see a {}' | 0.716 | better |
| 'A {} is in the picture' | 0.726 | better |
| 'In this image you can see a {}' | 0.727 | Best result so far |
| 'You can see a {} in this image' | 0.727 | Best result so far |
| 'An {} is waiting' | 0.733 | best |