In [17]:
import torch
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from sklearn.metrics import accuracy_score
import os

In [16]:
# Single source of truth for the checkpoint, so the model weights and the matching
# preprocessing (tokenizer + image processor) are always loaded from the same place.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [18]:
def load_dataset(dataset_name, batch_size):
    """Build an unshuffled DataLoader over an evaluation split and return it with its class names.

    "CIFAR10" / "CIFAR100" load the torchvision test split (``train=False``), downloading
    into ``./`` if needed. Any other name is read as an ImageFolder rooted at
    ``data/<dataset_name>``. Every image is resized to 224x224 and converted with
    ``transforms.ToTensor()``.

    Args:
        dataset_name: "CIFAR10", "CIFAR100", or the name of a folder under ``data/``.
        batch_size: number of samples per batch.

    Returns:
        (dataloader, class_names): the DataLoader (``shuffle=False``) and ``dataset.classes``.
    """
    # All sources share identical preprocessing, so build the pipeline once.
    to_clip_size = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

    cifar_builders = {"CIFAR10": CIFAR10, "CIFAR100": CIFAR100}
    if dataset_name in cifar_builders:
        eval_set = cifar_builders[dataset_name](
            root="./", download=True, train=False, transform=to_clip_size
        )
    else:
        eval_set = ImageFolder(root=f"data/{dataset_name}", transform=to_clip_size)

    loader = DataLoader(eval_set, batch_size=batch_size, shuffle=False)
    return loader, eval_set.classes


In [19]:
def evaluate_clf(model, dataloader, class_labels, clip_processor=None, prompt_template="{}"):
    """Zero-shot classify every batch in ``dataloader`` with a CLIP model.

    Each image is assigned the class whose text prompt has the highest cosine
    similarity to it.

    Args:
        model: CLIP model exposing ``get_text_features`` / ``get_image_features``.
        dataloader: iterable of ``(images, labels)`` batches.
        class_labels: class names in label-index order (e.g. ``dataset.classes``).
            Prediction ``i`` refers to ``class_labels[i]``.
        clip_processor: processor with ``.tokenizer`` and ``.image_processor``.
            Defaults to the notebook-level ``processor``.
        prompt_template: format string applied to each class name (underscores
            replaced by spaces), e.g. ``"a photo of a {}."``. The default ``"{}"``
            uses the bare class name.

    Returns:
        (all_labels, all_preds): lists of ground-truth and predicted class indices.
    """
    clip_processor = processor if clip_processor is None else clip_processor
    device = next(model.parameters()).device
    model.eval()

    # Turn dataset names such as "aquarium_fish" into natural-language text.
    prompts = [prompt_template.format(label.replace("_", " ")) for label in class_labels]

    all_preds = []
    all_labels = []
    with torch.no_grad():
        # The class prompts are identical for every batch, so encode them only once.
        text_inputs = clip_processor.tokenizer(prompts, padding=True, return_tensors="pt")
        text_inputs = {key: value.to(device) for key, value in text_inputs.items()}
        text_embeds = model.get_text_features(**text_inputs)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        for images, labels in dataloader:
            # transforms.ToTensor() already maps pixels to [0, 1]. Letting the image
            # processor rescale by 1/255 again would make every image nearly black.
            already_scaled = torch.is_tensor(images) and images.is_floating_point()
            pixel_values = clip_processor.image_processor(
                images, do_rescale=not already_scaled, return_tensors="pt"
            )["pixel_values"].to(device)

            image_embeds = model.get_image_features(pixel_values=pixel_values)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

            # Argmax of cosine similarity equals argmax of CLIP's scaled, softmaxed logits.
            preds = (image_embeds @ text_embeds.T).argmax(dim=1)
            all_preds.extend(preds.tolist())
            all_labels.extend(torch.as_tensor(labels).tolist())

    return all_labels, all_preds

In [None]:
# Evaluation settings for this run.
batch_size = 32
dataset_name = "CIFAR10"

# Zero-shot predictions over the whole split, scored as plain top-1 accuracy.
dataloader, class_labels = load_dataset(dataset_name, batch_size)
true_labels, predicted_labels = evaluate_clf(model, dataloader, class_labels)
accuracy = accuracy_score(y_true=true_labels, y_pred=predicted_labels)

print("Accuracy on {}: {}".format(dataset_name, accuracy))