<a href="https://colab.research.google.com/github/Tasnima158/clip/blob/main/CLIP_cifar_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torchvision.datasets import CIFAR10
from torchvision import transforms


In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git


Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-_98pqmhs
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-_98pqmhs
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490

In [None]:
import torch
import clip
from PIL import Image
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `clip.load` returns the network together with its matching image transform.
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()  # inference mode for the whole notebook; CLIP itself is never trained here

print("CLIP loaded on:", device)


100%|████████████████████████████████████████| 338M/338M [00:03<00:00, 112MiB/s]


CLIP loaded on: cuda


In [None]:
# CIFAR-10 test split, downloaded into ./data on first use.
# No transform is attached here, so indexing yields the dataset's raw
# (image, label) pairs; preprocessing is applied later, per image.
cifar_test = CIFAR10(root="./data", train=False, download=True)


100%|██████████| 170M/170M [00:14<00:00, 11.8MB/s]


In [None]:
# class_names = [
#     "airplane",
#     "automobile",
#     "bird",
#     "cat",
#     "deer",
#     "dog",
#     "frog",
#     "horse",
#     "ship",
#     "truck"
# ]


In [None]:
# text_prompts = [
#     "a photo of an airplane flying in the sky",
#     "a photo of a car on the road",
#     "a photo of a bird",
#     "a photo of a cat",
#     "a photo of a deer",
#     "a photo of a dog",
#     "a photo of a frog",
#     "a photo of a horse",
#     "a photo of a ship in the water",
#     "a photo of a truck on the road"
# ]


In [None]:
# One prompt per CIFAR-10 class, in label-index order (0 = airplane ... 9 = truck).
# Each entry carries its own article so the sentence reads naturally.
_class_phrases = (
    "an airplane",
    "an automobile",
    "a bird",
    "a cat",
    "a deer",
    "a dog",
    "a frog",
    "a horse",
    "a ship",
    "a truck",
)
text_prompts = [f"a low resolution photo of {phrase}" for phrase in _class_phrases]

In [None]:
# Tokenize every prompt, then move the resulting tensor onto the model's device.
text_tokens = clip.tokenize(text_prompts)
text_tokens = text_tokens.to(device)


In [None]:
# Encode the class prompts once and scale each embedding to unit length, so a
# dot product with a unit-length image embedding is a cosine similarity.
with torch.no_grad():
    raw_text_features = model.encode_text(text_tokens)
    text_feature_norms = raw_text_features.norm(dim=-1, keepdim=True)
    text_features = raw_text_features / text_feature_norms


In [None]:
# Zero-shot evaluation: each test image is assigned the class whose prompt
# embedding has the highest cosine similarity with the image embedding.
correct = 0
total = 0

# Images are encoded in batches instead of one forward pass per image.
zero_shot_batch_size = 256

# Compare in float32. The argmax is taken on the similarities directly: a
# softmax is monotone, so it cannot change the winner, it can only round
# near-equal scores into ties when computed in half precision.
text_features_fp32 = text_features.float()

with torch.no_grad():
    for start in range(0, len(cifar_test), zero_shot_batch_size):
        stop = min(start + zero_shot_batch_size, len(cifar_test))
        samples = [cifar_test[i] for i in range(start, stop)]

        image_input = torch.stack([preprocess(img) for img, _ in samples]).to(device)
        targets = torch.tensor([lbl for _, lbl in samples], device=device)

        image_features = model.encode_image(image_input)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Rows index images, columns index prompts.
        similarity = image_features.float() @ text_features_fp32.T
        predicted_class = similarity.argmax(dim=-1)

        correct += (predicted_class == targets).sum().item()
        total += targets.size(0)


In [None]:
# Fraction of evaluated images whose predicted class matched the label;
# the bare name on the last line makes the notebook display the value.
zero_shot_accuracy = correct / total
zero_shot_accuracy


0.8877

In [None]:
# Raw CIFAR-10 splits for the linear-probe experiment. Both share ./data, so
# files already fetched by an earlier cell are reused rather than re-downloaded.
train_dataset = CIFAR10(root="./data", train=True, download=True)
test_dataset = CIFAR10(root="./data", train=False, download=True)


In [None]:
class CIFAR10_CLIP(torch.utils.data.Dataset):
    """Wrap an ``(image, label)`` dataset so every image is transformed on access.

    Args:
        dataset: Indexable, sized dataset whose items are ``(image, label)`` pairs.
        transform: Callable applied to each image. When omitted, the
            notebook-level ``preprocess`` (returned by ``clip.load``) is looked
            up at access time, which keeps existing ``CIFAR10_CLIP(dataset)``
            calls working unchanged.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for item ``idx``."""
        image, label = self.dataset[idx]
        # Resolve the fallback lazily so the class can be defined (and used with
        # an explicit transform) even when no global `preprocess` exists.
        transform = preprocess if self.transform is None else self.transform
        return transform(image), label


In [None]:
# Mini-batches of 64 preprocessed images. Only the training stream is
# reshuffled; the test stream keeps dataset order.
train_loader = DataLoader(CIFAR10_CLIP(train_dataset), batch_size=64, shuffle=True)
test_loader = DataLoader(CIFAR10_CLIP(test_dataset), batch_size=64, shuffle=False)


In [None]:
# Freeze CLIP: no parameter of the backbone should receive gradients.
for clip_param in model.parameters():
    clip_param.requires_grad_(False)


In [None]:
from torch import nn
# Linear head mapping a 512-dimensional feature vector to 10 class logits.
classifier = nn.Linear(in_features=512, out_features=10)
classifier = classifier.to(device)


In [None]:
# Multi-class objective for the 10-way classification head.
loss_fn = nn.CrossEntropyLoss()

# Only the linear head's parameters are handed to the optimizer.
# higher LR is OK here
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)


In [None]:
# Train the linear head on frozen CLIP image features.
model.eval()
classifier.train()

epochs = 10

# The encoder is never updated in this loop (it runs under no_grad) and CLIP's
# standard `preprocess` applies no random augmentation, so an image's embedding
# is the same in every epoch. Encode the training set once and reuse the result
# instead of re-running the encoder on every image in every epoch.
cached_features = []
cached_labels = []
with torch.no_grad():
    for images, labels in train_loader:
        image_features = model.encode_image(images.to(device))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Cast to float32 to match the classifier's parameter dtype.
        cached_features.append(image_features.float())
        cached_labels.append(labels.to(device))

train_features = torch.cat(cached_features)
train_labels = torch.cat(cached_labels)

num_samples = train_features.size(0)
batch_size = train_loader.batch_size

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    # Fresh random order each epoch, mirroring the loader's shuffle=True.
    order = torch.randperm(num_samples, device=device)

    for start in range(0, num_samples, batch_size):
        batch_idx = order[start:start + batch_size]

        logits = classifier(train_features[batch_idx])
        loss = loss_fn(logits, train_labels[batch_idx])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Mean of per-batch losses, as reported per epoch.
    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/num_batches:.4f}")


Epoch 1/10 | Loss: 1.3822
Epoch 2/10 | Loss: 0.5887
Epoch 3/10 | Loss: 0.3839
Epoch 4/10 | Loss: 0.3021
Epoch 5/10 | Loss: 0.2591
Epoch 6/10 | Loss: 0.2331
Epoch 7/10 | Loss: 0.2162
Epoch 8/10 | Loss: 0.2044
Epoch 9/10 | Loss: 0.1956
Epoch 10/10 | Loss: 0.1887


In [None]:
# Score the trained linear head on the held-out test batches.
model.eval()
classifier.eval()

correct, total = 0, 0

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        # Unit-normalise the embeddings, then cast to float32 for the head.
        embeddings = model.encode_image(batch_images)
        embeddings = (embeddings / embeddings.norm(dim=-1, keepdim=True)).float()

        preds = classifier(embeddings).argmax(dim=1)

        correct += (preds == batch_labels).sum().item()
        total += batch_labels.size(0)

# Fraction of correctly classified test images; the bare name displays it.
accuracy = correct / total
accuracy


0.9403