In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

In [15]:
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN101', device)

100%|███████████████████████████████████████| 278M/278M [01:19<00:00, 3.68MiB/s]


In [17]:
from torchvision.datasets import MNIST

dataset = MNIST(root="./dataset", download=True)

In [18]:
MNIST.classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [19]:
all_indices = torch.arange(len(dataset))
all_indices

tensor([    0,     1,     2,  ..., 59997, 59998, 59999])

In [20]:
def batch_acc(predicted, actual):
    total = (predicted == actual).count_nonzero().sum()
    return total / predicted.size(0)

batch_acc(torch.arange(10), torch.arange(10))

tensor(1.)

In [21]:
from tqdm.auto import tqdm

def infer(text, batch_size=512, all_indices=all_indices) -> None:
    batched_indices = torch.split(all_indices, batch_size)
    n_batches = len(batched_indices)
    batch_wise_acc = torch.zeros(n_batches, )
    
    for batch_idx, indices in tqdm(enumerate(batched_indices), total=n_batches):
        images = [dataset[int(i)][0] for i in indices]
        labels = [dataset[int(i)][1] for i in indices]
        labels = torch.tensor(labels).float()
    
        images = torch.stack([
            preprocess(image) for image in images
        ], dim=0)
    
    
        images = images.to(device)
        text = text.to(device)
    
        with torch.no_grad():
            logits_per_image, _ = model(images, text)
        
        probs = logits_per_image.softmax(dim=-1).cpu()
        pred = torch.argmax(probs, dim=-1)
    
    
        batch_wise_acc[batch_idx] = batch_acc(pred, labels)
        
    # over all the batches
    mean_acc =  batch_wise_acc.mean(dim=-1)
    print(f"Mean Accuracy :: {mean_acc.item()} over :: {n_batches} batches.")   

In [22]:
# https://github.com/openai/CLIP/issues/164
label_prompts = [
    f"a photo of the number: '{i}'."
    for i in range(10)
]

text = clip.tokenize(label_prompts)
infer(text=text)

  0%|          | 0/118 [00:00<?, ?it/s]

Mean Accuracy :: 0.5203092098236084 over :: 118 batches.


In [23]:
label_prompts = [
    f"an image of the handwritten form of the number: '{i}'."
    for i in range(10)
]

text = clip.tokenize(label_prompts)
infer(text=text)

  0%|          | 0/118 [00:00<?, ?it/s]

Mean Accuracy :: 0.47272247076034546 over :: 118 batches.
