In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the network together with the image transform that
# is applied to every input image before it is passed to the model.
model, preprocess = clip.load('ViT-B/32', device=device)

In [3]:
from torchvision.datasets import MNIST

# No transform is passed here; the images are converted later by CLIP's own
# preprocess step. train=True is torchvision's default split, spelled out
# so the reader can see which split the evaluation runs on.
dataset = MNIST(
    root="./dataset",
    train=True,
    download=True,
)

In [7]:
# One natural-language prompt per digit class; the position of a prompt in
# the list is the digit it describes (index 0 -> digit 0, ..., index 9 -> 9).
label_prompts = []
for digit in range(10):
    label_prompts.append(f"This is a handwritten image of the digit {digit}.")

In [9]:
from tqdm.auto import tqdm

# Zero-shot evaluation: every image is scored against all label prompts and
# the index of the best-scoring prompt is recorded as the predicted class.
predicted = []  # predicted class index per sample, as plain Python ints
actual = []     # ground-truth label per sample, in dataset order

# The prompts are identical for every image, so tokenize them and move them
# to the target device once instead of repeating that work on each iteration.
text = clip.tokenize(label_prompts).to(device)

for image, label in tqdm(dataset, total=len(dataset)):
    actual.append(label)

    # unsqueeze(0) adds a leading batch dimension of size 1 for the model call.
    image = preprocess(image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits_per_image, _ = model(image, text)

    probs = logits_per_image.softmax(dim=-1).cpu()
    # Store a plain int rather than a one-element tensor so the list can be
    # handed to downstream tools directly; a later int() pass stays valid.
    predicted.append(int(torch.argmax(probs, dim=-1)))



  0%|          | 0/60000 [00:00<?, ?it/s]

In [13]:
# Coerce every stored prediction to a plain Python int before it is passed
# to scikit-learn as a label list.
predicted = list(map(int, predicted))

In [15]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 plus overall accuracy.
# Positional order is (y_true, y_pred): ground truth first, predictions second.
report = classification_report(actual, predicted)
print(report)

              precision    recall  f1-score   support

           0       0.56      0.97      0.71      5923
           1       0.44      0.93      0.60      6742
           2       0.33      0.00      0.00      5958
           3       0.32      0.64      0.43      6131
           4       0.64      0.06      0.11      5842
           5       0.90      0.39      0.55      5421
           6       0.30      0.14      0.20      5918
           7       0.38      0.91      0.54      6265
           8       0.46      0.19      0.27      5851
           9       1.00      0.00      0.01      5949

    accuracy                           0.44     60000
   macro avg       0.53      0.42      0.34     60000
weighted avg       0.53      0.44      0.34     60000

