# Exercise 03: CLIP zero-shot prediction
In this exercise, you will perform zero-shot prediction using CLIP.

### Basic Imports

In [39]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, roc_auc_score, average_precision_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
from torchvision import models
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

### Hyperparameters

In [40]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 32
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'ViT-B/16' # RN50, ViT-B/32, ViT-B/16


### Device

In [41]:
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu" if torch.cuda.is_available() else "cpu")


### Dataset


In [79]:
def pothole_dataset():# define transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224),antialias=True),
    ])

    # loading dataset
    dataset = ImageFolder(root='/data/lab/STA303-Assignment02/data/data/', transform=transform)
    test_dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)
    class_names = ['normal', 'pothole']
    dataset_name = 'Pothole identification'
    return dataset,test_dataloader,class_names,dataset_name

In [80]:
def CIFAR10_dataset():
    transform_cifar10_test = transforms.Compose([
        transforms.Resize(size=224),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_set = torchvision.datasets.CIFAR10(root='/shareddata', train=False,
                                       download=True, transform=transform_cifar10_test)
    test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    dataset_name = 'CIFAR10'
    return test_set,test_dataloader,class_names,dataset_name

In [81]:
def DTD_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224),antialias=True),
    ])
    test_set = torchvision.datasets.DTD(root='/shareddata',
                                       download=True, transform=transform)
    test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=47)
    class_names = ["banded", "blotchy", "braided", "bubbly", "bumpy", "chequered", "cobwebbed", "cracked", "crosshatched", "crystalline", "dotted", "fibrous", "flecked", "freckled", "frilly", "gauzy", "grid", "grooved", "honeycombed", "interlaced", "knitted", "lacelike", "lined", "marbled", "matted", "meshed", "paisley", "perforated", "pitted", "pleated", "polka-dotted", "porous", "potholed", "scaly", "smeared", "spiralled", "sprinkled", "stained", "stratified", "striped", "studded", "swirly", "veined", "waffled", "woven", "wrinkled", "zigzagged"]
    dataset_name = 'DTD'
    return test_set, test_dataloader,class_names,dataset_name

In [83]:
test_dataset,test_dataloader,class_names,dataset_name = CIFAR10_dataset() # pothole_dataset(), CIFAR10_dataset(), DTD_dataset()

Files already downloaded and verified


### Model

In [84]:
# Load the model
model, preprocess = clip.load(name=VISUAL_BACKBONE, device=device, download_root='/shareddata/clip/')
model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

### Task 1: Prompt Gereration
---

Please denfine a function named ``prompt_encode`` to encode the text using CLIP text encoder.


In [85]:

prompt = 'a photo of a' # you can try different prompt

def prompt_encode(prompt):
    """
    Args:
        prompt (str): the text prefix before the class

    Returns:
        text_inputs(torch.Tensor)

    """
    ##################### Write your answer here ##################
    # 使用CLIP模型编码提示
    text_inputs = torch.cat([clip.tokenize(f"{prompt}{c}") for c in class_names]).to(device)
    ###############################################################
    
    return text_inputs


### Task 2: Zero-shot inference
---

Please denfine a function named ``model_inference``. The function is essential for training and evaluating machine learning models using batched data from dataloaders.

**To do**: 
1. Encode the image.
2. Encode the text.
3. Calculate the logits.

In [86]:

def model_inference(model, image, text_inputs):

    ##################### Write your answer here ##################
    image_features = model.encode_image(image)
    text_features = model.encode_text(text_inputs)
    
    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)
    
    logit_scale = model.logit_scale.exp()
    
    logits = logit_scale * image_features @ text_features.t()
    ###############################################################

    return logits

### Task 3: Zero-shot accuracy calculation
---

In [None]:
target_clip_list = []
testing_loss = []
testing_acc = []
all_clip_predictions = []
with torch.no_grad():
    model.eval()

    ##################### Write your answer here ##################
    val_loss = 0.0
    val_corrects = 0
    
    for batch_idx, (image, target) in enumerate(test_dataloader):
        image = image.to(device)
        target = target.to(device)
        for targ in target:
            target_clip_list.append(targ.item())
        text_inputs = prompt_encode(prompt)
        #test model
        logits = model_inference(model, image, text_inputs)
        _, preds = torch.max(logits, 1)
        for pred in preds:
            all_clip_predictions.append(pred.item())
        val_corrects += torch.sum(preds == target.data)
    val_acc = val_corrects.double() / len(test_dataset)
     ###############################################################
    print(f"the zero-shot performance on {dataset_name} is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")




### self-defined baseline model

***WARNING: *** Change output dimensions to correct categories, if you use Pothole dataset, set 2; if you use CIFAR10 dataset, set 10; if you use DTD dataset, set 47.

In [None]:
# Load the pre-trained ResNet50,VGG model
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) # IMAGENET1K_V1, IMAGENET1K_V2

# Replace the last classification layer with a custom linear layer (e.g., changing output dimensions to two categories)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)  # Change output dimensions to correct categories
model.eval()  # Set to evaluation mode

all_baseline_predictions = []
correct = 0
total = 0

# Assuming you have a DataLoader for image data, iterate through it for prediction
for images, targets in test_dataloader:
    # Preprocess each image
    with torch.no_grad():
        # Assuming images represent a batch of image data
        outputs = model(images)
        # For binary classification, typically use the Sigmoid or Softmax function on model outputs
        probabilities = torch.sigmoid(outputs)  # Use Sigmoid to convert outputs to probabilities
        predicted_classes = torch.argmax(probabilities, dim=1).int()
        for pred in predicted_classes:
            all_baseline_predictions.append(pred.item())
        # Calculate accuracy
    
        correct += (predicted_classes == targets).sum().item()
        total += targets.size(0)

# Calculate accuracy
accuracy = correct / total
print(f"Baseline model Accuracy: {accuracy * 100:.2f}%")

### visualize confusion matrix

In [None]:
clip_conf_matrix = confusion_matrix(target_clip_list, all_clip_predictions)
baseline_conf_matrix = confusion_matrix(target_clip_list, all_baseline_predictions)



plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
sns.heatmap(clip_conf_matrix, annot=True, fmt="d", cmap="Blues")
plt.title('CLIP model confusion matrix')

plt.subplot(1, 2, 2)
sns.heatmap(baseline_conf_matrix, annot=True, fmt="d", cmap="Blues")
plt.title('baseline model confusion matrix')

plt.tight_layout()
plt.show()

### Draw ROC curve

***Warning: *** Just work for binary classification

In [16]:
fpr_baseline, tpr_baseline, _ = roc_curve(target_clip_list, all_baseline_predictions)
roc_auc_baseline = auc(fpr_baseline, tpr_baseline)
fpr_clip, tpr_clip, _ = roc_curve(target_clip_list, all_clip_predictions)
roc_auc_clip = auc(fpr_clip, tpr_clip)

plt.figure()
plt.plot(fpr_baseline, tpr_baseline, color='blue', lw=2, label=f'Baseline model ROC curve (AUC = {roc_auc_baseline})')
plt.plot(fpr_clip, tpr_clip, color='red', lw=2, label=f'CLIP model ROC curve (AUC = {roc_auc_clip})')

plt.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=2, label='random prediction')

plt.xlabel('False Positive Rate (FPR)')
plt.ylabel('True Positive Rate (TPR)')
plt.title('ROC curve')
plt.legend()
plt.show()


NameError: name 'roc_auc_curve' is not defined