In [1]:
import clip 
import torch 
import torchvision
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader



In [2]:
clip_preprocess_tensor = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711]
    )
])

In [3]:
test_dataset=OxfordIIITPet(root='../data/',
                            split='test',
                            transform=clip_preprocess_tensor,
                            download=False)
cls_names=test_dataset.classes

In [4]:
for (img,lbl) in test_dataset:
    break

In [5]:
batch_size=8
test_dataloader=DataLoader(dataset=test_dataset,
                           batch_size=8)

In [6]:
for imgs,lbls in test_dataloader:
    break

### Importing Pretrained model for zero shot classification 

In [7]:
device='cuda:1' if torch.cuda.is_available() else 'cpu'
model,preprocess=clip.load('ViT-L/14',device=device)
model=model.eval()

## Creating text features by using different forms of prompts; 

see the list `prompt_templates` for reference

And for the final text_features, we will take mean of all the prompt for each class 


In [30]:


prompt_templates = [
    '{}',
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]
# prompt_templates=['A photo of {}']

def get_text_prompts(classes):
    text_prompts=[]
    for templete in prompt_templates:
        templete_prompts=[]
        for c in classes:
            prompt=templete.replace('{}',c)
            templete_prompts.append(prompt)
        text_prompts.append(templete_prompts)

    return text_prompts

In [31]:
text_prompts=get_text_prompts(cls_names)

In [32]:
text_prompts[1]

['a bad photo of a Abyssinian.',
 'a bad photo of a American Bulldog.',
 'a bad photo of a American Pit Bull Terrier.',
 'a bad photo of a Basset Hound.',
 'a bad photo of a Beagle.',
 'a bad photo of a Bengal.',
 'a bad photo of a Birman.',
 'a bad photo of a Bombay.',
 'a bad photo of a Boxer.',
 'a bad photo of a British Shorthair.',
 'a bad photo of a Chihuahua.',
 'a bad photo of a Egyptian Mau.',
 'a bad photo of a English Cocker Spaniel.',
 'a bad photo of a English Setter.',
 'a bad photo of a German Shorthaired.',
 'a bad photo of a Great Pyrenees.',
 'a bad photo of a Havanese.',
 'a bad photo of a Japanese Chin.',
 'a bad photo of a Keeshond.',
 'a bad photo of a Leonberger.',
 'a bad photo of a Maine Coon.',
 'a bad photo of a Miniature Pinscher.',
 'a bad photo of a Newfoundland.',
 'a bad photo of a Persian.',
 'a bad photo of a Pomeranian.',
 'a bad photo of a Pug.',
 'a bad photo of a Ragdoll.',
 'a bad photo of a Russian Blue.',
 'a bad photo of a Saint Bernard.',
 'a 

In [33]:
text_tokens=[]

with torch.no_grad():
    for text_prompt in text_prompts:
        text_token=clip.tokenize(text_prompt).to(device)
        text_tokens.append(text_token)

In [34]:
print('text_tokens[0].shape ', text_tokens[0].shape)
print(text_tokens[0])

text_tokens[0].shape  torch.Size([37, 77])
tensor([[49406, 16342,   690,  ...,     0,     0,     0],
        [49406,  2151, 15611,  ...,     0,     0,     0],
        [49406,  2151,  7476,  ...,     0,     0,     0],
        ...,
        [49406, 29198,  6029,  ...,     0,     0,     0],
        [49406, 20505,   576,  ...,     0,     0,     0],
        [49406,  8633, 14455,  ...,     0,     0,     0]], device='cuda:1',
       dtype=torch.int32)


In [35]:
text_features=[]
with torch.no_grad():
    for token in text_tokens:
        text_feature=model.encode_text(token)
        text_features.append(text_feature)

text_features=torch.mean(torch.stack(text_features),dim=0)

## How to do Zero-Shot Classification
- First, we will extract image features for each image
- Second, we will measure, how each image allign with the given labels; we will do this by calculating the ``similarity`` between image features and text features, and will take the argmax for predicting the final class label

In [57]:
from tqdm import tqdm
import numpy as np 
batch_accuracy=[]

for img,lbls in tqdm(test_dataloader,desc='Zero-Shot Classification in Progress'):
    img_features=model.encode_image(img.to(device))
    similarity=img_features @ text_features.T
    pred_cls=torch.argmax(similarity,dim=1).cpu()
    acc_bool=(pred_cls==lbls).to(dtype=torch.float32)
    acc=(torch.sum(acc_bool)/len(pred_cls)).cpu().item()
    batch_accuracy.append(acc)



    

Zero-Shot Classification in Progress: 100%|██████████| 459/459 [00:23<00:00, 19.72it/s]


In [61]:
accuracy=round(np.mean(batch_accuracy,axis=0)*100,3)
print('Zero Shot Image Classification ', accuracy)

Zero Shot Image Classification  91.285


In this jupyter notebook we witnessed the capacity of zero shot classification of `91.26%` in ``OxfordIIITPet`` containing `37` classes. 