# CLIP zero-shot prediction
## Basic import

In [1]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

## Hyperparameters

In [2]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 128
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'RN50' # RN50, ViT-B/32, ViT-B/16


## Device

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Model

### Clip

In [4]:
# Load the model
model, preprocess= clip.load(name=VISUAL_BACKBONE, device=device, download_root='/clip/')
model.to(device)
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [5]:
# 数据转换管道
transform_mnist_test = transforms.Compose([
    transforms.Resize(size=224),  # 可以根据需要调整大小
    transforms.Grayscale(num_output_channels=3),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))  # 使用MNIST的均值和标准差
])

# 测试数据集
# test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform_mnist_test)

test_set = torchvision.datasets.MNIST(root='/data/dataset/', train=False,
                                       download=True, transform=transform_mnist_test)  # 修改此行

# 数据加载器
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 类别标签（MNIST有10个类别，数字0到9）
class_names = [str(i) for i in range(10)]

# 数据集名称
dataset_name = 'MNIST'
print(class_names)

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


In [6]:
prompt = 'a number: ' # you can try different prompt

text_inputs = torch.cat([clip.tokenize(f"{prompt} {c}") for c in class_names]).to(device)


In [7]:
def model_inference(model, image, text_inputs):
    """
    Args:
        model (torch.nn.Module): 训练好的机器学习模型
        image (torch.Tensor): 输入图像张量
        text_inputs (torch.Tensor): 输入文本张量

    Returns:
        logits (torch.Tensor): 模型的预测logits
    """
    image_features = model.encode_image(image)
    text_features = model.encode_text(text_inputs)

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    logit_scale = model.logit_scale.exp()

    logits = logit_scale * image_features @ text_features.t()
    
    return logits



In [8]:
testing_loss = []
testing_acc = []

with torch.no_grad():
    model.eval()  
    ##################### Write your answer here ##################
    correct_predictions = 0
    
    for batch,(image, class_id) in enumerate(test_dataloader):
       
        image=image.to(device)
        class_id=class_id.to(device)
        
        logits = model_inference(model,image,text_inputs)
        
        _ , predictions = torch.max(logits,1)
#         print(predictions,class_id)
        correct_predictions += torch.sum(predictions == class_id.data)
    
    val_acc = correct_predictions.double() / len(test_set)
    
    ###############################################################

    print(f"the zero-shot performance on {dataset_name} is {val_acc*100:.2f}%, visual encoder is {VISUAL_BACKBONE}.")

the zero-shot performance on MNIST is 40.63%, visual encoder is RN50.
