# Contrastive Models – Analyzing CLIP (ViT-B/32) and Multimodal Biases

In [2]:
#!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-rme8eufq
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-rme8eufq
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting 

In [3]:
# IMPORTS

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import clip

import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

## Prepping Model and Dataset

In [4]:
model, preprocess = clip.load("ViT-B/32", device=device)

100%|███████████████████████████████████████| 338M/338M [00:06<00:00, 57.2MiB/s]


In [5]:
transform = preprocess
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

100%|██████████| 170M/170M [00:05<00:00, 29.7MB/s] 


In [6]:
# CIFAR-10 class names
class_names = ["airplane", "automobile", "bird" , "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

In [7]:
plain_labels = class_names
prompted_text = [f"A Photo of a {c}" for c in class_names]
pt_2 = [f"Here lies a Picture of {c}" for c in class_names]
sketch_prompt = [f"A drawing of a {c}" for c in class_names]

## Zero-Shot Classification Test

In [8]:
zeroshot_ws = []
with torch.no_grad():
    for i in range(len(class_names)):
        texts = [plain_labels[i], prompted_text[i], pt_2[i], sketch_prompt[i]]
        tokenized_texts = clip.tokenize(texts).to(device)
        class_embeds = model.encode_text(tokenized_texts)
        print(class_embeds.shape)
        print(class_embeds.norm(dim=-1, keepdim=True))
        class_embeds /= class_embeds.norm(dim=-1, keepdim=True) # [-1, 1] norm for cosine sim term in clip
        
        class_embeds = class_embeds.mean(dim=0)
        class_embeds /= class_embeds.norm(dim=-1, keepdim=True)
        zeroshot_ws.append(class_embeds)
    zeroshot_ws = torch.stack(zeroshot_ws, dim=1).to(device) # [dim, num_classes]

torch.Size([4, 512])
tensor([[10.1058],
        [ 9.7643],
        [10.2154],
        [ 9.3906]])
torch.Size([4, 512])
tensor([[10.6856],
        [10.4549],
        [10.8002],
        [ 9.4328]])
torch.Size([4, 512])
tensor([[10.8247],
        [10.3093],
        [10.6148],
        [ 9.6954]])
torch.Size([4, 512])
tensor([[10.4445],
        [10.4764],
        [10.2702],
        [ 9.5396]])
torch.Size([4, 512])
tensor([[9.6999],
        [9.5237],
        [9.9410],
        [8.8227]])
torch.Size([4, 512])
tensor([[10.9251],
        [10.7019],
        [10.6141],
        [ 9.7131]])
torch.Size([4, 512])
tensor([[10.1615],
        [ 9.6822],
        [10.2106],
        [ 8.9358]])
torch.Size([4, 512])
tensor([[10.5699],
        [10.3246],
        [10.4016],
        [ 9.3490]])
torch.Size([4, 512])
tensor([[11.2740],
        [10.4792],
        [10.8434],
        [ 9.5741]])
torch.Size([4, 512])
tensor([[10.7632],
        [10.0096],
        [10.3840],
        [ 8.9663]])


In [None]:
# Eval

correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)

        image_feats = model.encode_image(images) # [batch, dim]
        image_feats /= image_feats.norm(dim=-1, keepdim=True)

        logits = 100.0 * image_feats @ zeroshot_ws # [batch, num_classes]
        preds = torch.argmax(logits, dim=-1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

acc = correct/total
print(f"Zero-Shot Classification Accuracy on various styles of prompts: {(acc*100):.2f}%")

## Image Text Retrieval