# Assignment02: CLIP

### Basic Imports

In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

### Hyperparameters

In [2]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 128
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

### Device

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


### Dataset


In [4]:
transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_set = torchvision.datasets.CIFAR10(root='/shareddata', train=False,
                                       download=True, transform=transform_cifar10_test)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dataset_name = 'CIFAR10'

Files already downloaded and verified


In [5]:
test_set[0]

(tensor([[[ 0.6338,  0.6338,  0.6338,  ..., -0.1804, -0.1804, -0.1804],
          [ 0.6338,  0.6338,  0.6338,  ..., -0.1804, -0.1804, -0.1804],
          [ 0.6338,  0.6338,  0.6338,  ..., -0.1804, -0.1804, -0.1804],
          ...,
          [-1.3823, -1.3823, -1.3823,  ..., -2.0220, -2.0220, -2.0220],
          [-1.3823, -1.3823, -1.3823,  ..., -2.0220, -2.0220, -2.0220],
          [-1.3823, -1.3823, -1.3823,  ..., -2.0220, -2.0220, -2.0220]],
 
         [[-0.2156, -0.2156, -0.2156,  ..., -0.7466, -0.7466, -0.7466],
          [-0.2156, -0.2156, -0.2156,  ..., -0.7466, -0.7466, -0.7466],
          [-0.2156, -0.2156, -0.2156,  ..., -0.7466, -0.7466, -0.7466],
          ...,
          [-0.3139, -0.3139, -0.3139,  ..., -1.1006, -1.1006, -1.1006],
          [-0.3139, -0.3139, -0.3139,  ..., -1.1006, -1.1006, -1.1006],
          [-0.3139, -0.3139, -0.3139,  ..., -1.1006, -1.1006, -1.1006]],
 
         [[-1.2654, -1.2654, -1.2654,  ..., -1.5776, -1.5776, -1.5776],
          [-1.2654, -1.2654,

### Model

In [6]:
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [7]:
# CLIP
VISUAL_BACKBONE = 'RN50' # RN50, ViT-B/32, ViT-B/16

In [8]:
# Load the model
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device);

RuntimeError: NVML_SUCCESS == r INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/c10/cuda/CUDACachingAllocator.cpp":1150, please report a bug to PyTorch. 

### Task 1: Prompt Gereration
---

Please denfine a function named ``prompt_encode`` to encode the text using CLIP text encoder.


In [None]:
def prompt_encode(prompt):
    """
    Args:
        prompt (str): the text prefix before the class

    Returns:
        text_inputs(torch.Tensor)

    """
    ##################### Write your answer here ##################
    text_inputs =  torch.cat([clip.tokenize(f"{prompt} {c}") for c in class_names]).to(device)
#     print(text_inputs)
    ###############################################################

    return text_inputs


### Task 2: Zero-shot inference
---

Please denfine a function named ``model_inference``. The function is essential for training and evaluating machine learning models using batched data from dataloaders.

**To do**: 
1. Encode the image.
2. Encode the text.
3. Calculate the logits.

In [None]:

def model_inference(model, image, text_inputs):

    ##################### Write your answer here ##################
#     image_input = preprocess(image).unsqueeze(0).to(device)
    
#     image_features = model.encode_image(image_input)
#     text_features = model.encode_text(text_inputs)
    
    logits , _ = model(image, text_inputs)
    ###############################################################

    return logits

### Task 3: Zero-shot accuracy calculation
---

In [None]:
#try different prompt
prompt1 = 'a photo of a' 
prompt2 = 'The prominent features in this image indicate that it falls under the category of '
prompt3 = 'This image depicts '
prompt4 = 'it is considered to be '
prompt5 = 'An object in this image may be small or fuzzy, please identify it as accurately as possible. This is '
prompt6 = ''
prompt7 = "asdqwedqwdwda"

In [None]:
def run_clip(prompt):
    with torch.no_grad():
        model.eval()


        ##################### Write your answer here ##################
        text_inputs = prompt_encode(prompt)
        val_corrects = 0

        for (image, target) in test_dataloader:
            image = image.to(device)
            target = target.to(device)

            logits = model_inference(model,image,text_inputs)
            _, preds = torch.max(logits, 1)

            val_corrects += torch.sum(preds == target.data)
            
        val_acc = val_corrects/ len(test_set)
        ###############################################################


        print(f'Used prompt: \'{prompt}\'')
        print(f"The zero-shot performance on {dataset_name} is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")
        print("---------------------------------------------")

In [None]:
run_clip(prompt1)
run_clip(prompt2)
run_clip(prompt3)
run_clip(prompt4)
run_clip(prompt5)
run_clip(prompt6)
run_clip(prompt7)