# Exercise 03: CLIP zero-shot prediction
In this exercise, you will perform zero-shot prediction using CLIP.

### Basic Imports

In [7]:
pip install ftfy

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m438.3 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.1
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [34]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip
import tqdm

### Hyperparameters

In [37]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 128
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'ViT-B/32' # RN50, ViT-B/32, ViT-B/16


### Device

In [38]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


### Dataset


In [39]:
transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_set = torchvision.datasets.CIFAR10(root='/shareddata', train=False,
                                       download=True, transform=transform_cifar10_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dataset_name = 'CIFAR10'

Files already downloaded and verified


### Model

In [40]:
# Load the model
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

### Task 1: Prompt Gereration
---

Please denfine a function named ``prompt_encode`` to encode the text using CLIP text encoder.


In [46]:
prompt = 'a photo of a' # you can try different prompt

def prompt_encode(prompt):
    
    
    
    """
    Args:
        prompt (str): the text prefix before the class

    Returns:
        text_inputs(torch.Tensor)

    """
    ##################### Write your answer here ##################
    text_inputs_list = []
    for i in class_names:
        text_inputs_list.append(clip.tokenize(prompt + i).to(device))
    
    text_inputs = torch.cat(text_inputs_list)

    ###############################################################
    
    return text_inputs


### Task 2: Zero-shot inference
---

Please denfine a function named ``model_inference``. The function is essential for training and evaluating machine learning models using batched data from dataloaders.

**To do**: 
1. Encode the image.
2. Encode the text.
3. Calculate the logits.

In [47]:

def model_inference(model, image, text_inputs):
    

    ##################### Write your answer here ##################
    model.eval()
    image = image.to(device)
    text_inputs = text_inputs.to(device)

    image_encoding = model.encode_image(image)
    text_encoding = model.encode_text(text_inputs)

    logits = model(image, text_inputs)

    ###############################################################

    return logits

### Task 3: Zero-shot accuracy calculation
---

In [48]:

testing_acc = []

with torch.no_grad():
    model.eval()

    ##################### Write your answer here ##################
    for image,label in test_dataloader:
        
        image = image.to(device)
        label = label.to(device)
        
        text_inputs = prompt_encode(prompt)
    
        logits = model_inference(model, image, text_inputs)

        predictions = logits[0].argmax(dim=1).to(device)
        accuracy = (predictions == label).sum().item() / label.size(0)
        testing_acc.append(accuracy)
        val_acc = sum(testing_acc) / len(testing_acc)
    



     ###############################################################



    print(f"the zero-shot performance on {dataset_name} is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")




the zero-shot performance on CIFAR10 is 75.17%, visual encoder is ViT-B/32.
