# Exercise 03: CLIP zero-shot prediction
In this exercise, you will perform zero-shot prediction using CLIP.

### Basic Imports

In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

### Hyperparameters

In [2]:
# # random seed
SEED = 1 
NUM_CLASS = 10

# Training
BATCH_SIZE = 128
NUM_EPOCHS = 30
EVAL_INTERVAL=1
SAVE_DIR = './log'

# # Optimizer
LEARNING_RATE = 1e-1
MOMENTUM = 0.9
STEP=5
GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'RN50' # RN50, ViT-B/32, ViT-B/16

### Device

In [3]:
device = device = "cuda:0" if torch.cuda.is_available() else "cpu"


### Dataset


In [4]:
transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_set = torchvision.datasets.CIFAR10(root='/shareddata', train=False,
                                       download=True, transform=transform_cifar10_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dataset_name = 'CIFAR10'

Files already downloaded and verified


### Model

In [5]:
# Load the model
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device)

CLIP(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
     

### Task 1: Prompt Gereration
---

Please denfine a function named ``prompt_encode`` to encode the text using CLIP text encoder.


In [6]:
prompt = 'a photo of a' # you can try different prompt
def prompt_encode(prompt):
    text = [f"{prompt} {class_name}" for class_name in class_names]
    text_inputs = clip.tokenize(text).to(device)
    return text_inputs

### Task 2: Zero-shot inference
---

Please denfine a function named ``model_inference``. The function is essential for training and evaluating machine learning models using batched data from dataloaders.

**To do**: 
1. Encode the image.
2. Encode the text.
3. Calculate the logits.

In [7]:
def model_inference(model, image, text_inputs):
    ##################### Write your answer here ##################
    image_features = model.encode_image(image)
    image_features = model.encode_text(text_inputs)
    logits_per_image = model(image,text_inputs)
    ###############################################################
    return logits_per_image

In [8]:
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

### Task 3: Zero-shot accuracy calculation
---

In [9]:
testing_loss = []
testing_acc = []

with torch.no_grad():
    model.eval()
    for batch_idx,(images, labels) in enumerate(test_dataloader):
        images = images.to(device)
        labels = labels.to(device)
        # Encode the text using the prompt
        text_inputs =prompt_encode(prompt)
        # Model inference
        logits,_ = model_inference(model,images,text_inputs)
        # Calculate loss
        loss = F.cross_entropy(logits, labels)
        testing_loss.append(loss.item())
        # Calculate accuracy
        _, predicted = logits.max(1)
        correct = predicted.eq(labels).sum().item()
        acc = correct / BATCH_SIZE
        print(acc)
        testing_acc.append(acc)
# Calculate average accuracy
val_acc = np.mean(testing_acc)


0.53125
0.53125
0.4921875
0.609375
0.578125
0.5234375
0.5703125
0.59375
0.546875
0.53125
0.609375
0.5859375
0.578125
0.609375
0.5703125
0.46875
0.5703125
0.5078125
0.5859375
0.5546875
0.5703125
0.59375
0.5546875
0.609375
0.6015625
0.6328125
0.5703125
0.4921875
0.5234375
0.5703125
0.5546875
0.5703125
0.6171875
0.546875
0.5546875
0.546875
0.4765625
0.5546875
0.59375
0.5546875
0.6171875
0.5234375
0.4921875
0.5546875
0.59375
0.5859375
0.484375
0.546875
0.4921875
0.609375
0.4609375
0.5703125
0.5859375
0.5546875
0.578125
0.5859375
0.671875
0.5390625
0.5390625
0.5859375
0.5390625
0.53125
0.5703125
0.5703125
0.5390625
0.5078125
0.5
0.5625
0.5625
0.5078125
0.4921875
0.578125
0.5625
0.53125
0.578125
0.6484375
0.5625
0.59375
0.0625


In [11]:
print(f'Testing Loss: {np.mean(testing_loss):.4f}, Testing Accuracy: {val_acc:.4f}')

Testing Loss: 1.2282, Testing Accuracy: 0.5520
