# Contrastive Models – Analyzing CLIP (ViT-B/32) and Multimodal Biases

In [None]:
# !pip install git+https://github.com/openai/CLIP.git

In [None]:
# IMPORTS

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import clip

import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

import numpy as np

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Prepping Model and Dataset

In [None]:
# Load the "ViT-B/32" CLIP checkpoint onto `device`. `preprocess` is the second
# value returned by clip.load and is used below as the CIFAR-10 image transform.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# Reuse the preprocessing returned alongside the CLIP model as the dataset
# transform.
transform = preprocess

# CIFAR-10 test split (train=False), stored under ./data; download=True lets
# torchvision fetch it.
testset = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Fixed iteration order (shuffle=False), 64 images per batch.
testloader = DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
)

In [None]:
# CIFAR-10 class names. Position in this list is the class index: the
# evaluation loop compares the argmax over these classes with the dataset's
# integer labels, so the order must not change.
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [None]:
# Four text variants per class; every list is index-aligned with class_names.
plain_labels = class_names  # alias of the same list: the bare class names
prompted_text = [f"A Photo of a {name}" for name in class_names]
pt_2 = [f"Here lies a Picture of {name}" for name in class_names]
sketch_prompt = [f"A drawing of a {name}" for name in class_names]

## Zero-Shot Classification Test

In [None]:
# Build one text-side classifier weight per class by ensembling that class's
# four prompt variants (plain label, two photo prompts, one drawing prompt).
class_weights = []
with torch.no_grad():
    # zip walks the four index-aligned prompt lists in lockstep, so no manual
    # indexing is needed.
    for variants in zip(plain_labels, prompted_text, pt_2, sketch_prompt):
        texts = list(variants)
        tokenized_texts = clip.tokenize(texts).to(device)
        class_embeds = model.encode_text(tokenized_texts)
        # L2-normalise each prompt embedding to unit length so that every
        # variant contributes equally to the average below.
        class_embeds /= class_embeds.norm(dim=-1, keepdim=True)
        # Average over the prompt variants, then renormalise so the dot product
        # with a unit-length image embedding is a cosine similarity.
        class_embeds = class_embeds.mean(dim=0)
        class_embeds /= class_embeds.norm(dim=-1, keepdim=True)
        class_weights.append(class_embeds)
    # One column per class, in class_names order.
    zeroshot_ws = torch.stack(class_weights, dim=1).to(device) # [dim, num_classes]

In [None]:
# Zero-shot evaluation: count the test images whose highest-scoring class
# column in zeroshot_ws matches the dataset label.

correct, total = 0, 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)

        # Unit-length image embeddings; zeroshot_ws columns are unit-length
        # too, so the matmul below yields scaled cosine similarities.
        image_feats = model.encode_image(images) # [batch, dim]
        image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)

        logits = (100.0 * image_feats) @ zeroshot_ws # [batch, num_classes]
        preds = logits.argmax(dim=-1)

        correct += preds.eq(labels).sum().item()
        total += labels.shape[0]

acc = correct / total
print(f"Zero-Shot Classification Accuracy on various styles of prompts: {(acc*100):.2f}%")

## Image Text Retrieval