# Exercise 03: CLIP zero-shot prediction
In this exercise, you will perform zero-shot prediction using CLIP.

### Basic Imports

In [2]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

### Hyperparameters

In [3]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 128
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'RN50' # RN50, ViT-B/32, ViT-B/16


### Device

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


### Dataset


In [5]:
transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_set = torchvision.datasets.CIFAR10(root='/shareddata', train=False,
                                       download=True, transform=transform_cifar10_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dataset_name = 'CIFAR10'

Files already downloaded and verified


### Model

In [6]:
# Load the model
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device)

CLIP(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
     

### Task 1: Prompt Gereration
---

Please denfine a function named ``prompt_encode`` to encode the text using CLIP text encoder.


In [12]:

def prompt_encode(prompt):
    """
    Args:
        prompt (str): the text prefix before the class

    Returns:
        text_inputs(torch.Tensor)

    """
    ##################### Write your answer here ##################
    text_inputs = clip.tokenize(prompt).to(device)
    ###############################################################
    
    return text_inputs
prompt = 'a photo of a' # you can try different prompt
text_inputs = prompt_encode(prompt)
text_inputs

tensor([[49406,   320,  1125,   539,   320, 49407,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], device='cuda:0',
       dtype=torch.int32)

### Task 2: Zero-shot inference
---

Please denfine a function named ``model_inference``. The function is essential for training and evaluating machine learning models using batched data from dataloaders.

**To do**: 
1. Encode the image.
2. Encode the text.
3. Calculate the logits.

In [18]:
def model_inference(model, image, text_inputs):
    """
    Perform model inference to calculate logits for zero-shot learning.

    Args:
        model (CLIPModel): The CLIP model for inference.
        image (torch.Tensor): The batch of images to encode.
        text_inputs (torch.Tensor): The batch of text inputs to encode.

    Returns:
        torch.Tensor: Logits representing the similarity between images and texts.
    """
    # Ensure the model is in evaluation mode
    model.eval()

    # Encode the image using the CLIP model
    with torch.no_grad():
        image_features = model.encode_image(image)

    # Normalize the image features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Encode the text using the CLIP model's text encoder
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)

    # Normalize the text features
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Calculate the similarity (logits) between each image and text
    logits = image_features @ text_features.t()

    return logits


### Task 3: Zero-shot accuracy calculation
---

In [19]:

testing_loss = []
testing_acc = []

with torch.no_grad():
    model.eval()

    ##################### Write your answer here ##################
# Assuming 'test_dataloader' is your DataLoader for the CIFAR10 test dataset
# and 'class_names' is a list of class names

# Encode the text prompts for all classes
text_prompts = [f"a photo of a {classname}" for classname in class_names]
text_inputs = clip.tokenize(text_prompts).to(device)

# Zero-shot accuracy calculation
correct_predictions = 0
total_predictions = 0

with torch.no_grad():
    model.eval()

    for images, labels in test_dataloader:
        # Move images and labels to the same device as the model
        images = images.to(device)
        labels = labels.to(device)

        # Perform model inference to get logits
        logits = model_inference(model, images, text_inputs)

        # Predictions are the indices of the max logit values
        predictions = logits.argmax(dim=-1)

        # Update correct predictions and total predictions
        correct_predictions += (predictions == labels).sum().item()
        total_predictions += labels.size(0)

# Calculate accuracy
val_acc = correct_predictions / total_predictions


     ###############################################################
print(f"the zero-shot performance on {dataset_name} is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")




the zero-shot performance on CIFAR10 is 55.90%, visual encoder is RN50.
