In [None]:
import clip
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score

In [None]:
# Load the CLIP model together with the preprocessing pipeline that ships with it.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Use the pipeline returned by clip.load instead of a hand-built one.
# A plain transforms.Resize((224, 224)) defaults to bilinear interpolation,
# whereas CLIP's reference preprocessing resizes with bicubic interpolation
# (followed by a centre crop, RGB conversion, ToTensor and the CLIP
# normalisation statistics). Feeding the model the same preprocessing it was
# released with keeps the zero-shot numbers comparable to published ones.
# The name `transform` is kept so any other cell that refers to it still works.
transform = preprocess

# Load the CIFAR-100 test split.
# download=True is a no-op when the files are already present under ./data,
# and lets the notebook run end-to-end on a fresh machine.
# train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Evaluation order does not matter, so the loader is not shuffled.
# train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [None]:
# Class names as exposed by the CIFAR-100 dataset object.
cifar100_classes = test_set.classes

# Build one text prompt per class and tokenize the whole list in a single call;
# the result has one row of tokens per class and is moved to the model's device.
text_inputs = clip.tokenize(
    [f"a photo of a {c}" for c in cifar100_classes]
).to(device)

In [None]:
def evaluate(model, test_loader, text_inputs):
    """Compute zero-shot top-1 accuracy of a CLIP-style model.

    Each image is assigned the index of the text prompt whose normalised
    embedding has the highest cosine similarity with the normalised image
    embedding; that index is compared against the ground-truth label.

    Args:
        model: Object providing ``eval()``, ``encode_image(images)`` and
            ``encode_text(tokens)``.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        text_inputs: Tokenized prompts, one row per class, already on the
            device the model lives on. Row ``i`` is treated as class ``i``.

    Returns:
        The accuracy value produced by ``sklearn.metrics.accuracy_score``
        over all samples seen in ``test_loader``.

    Note:
        ``model.eval()`` is called and the model is left in eval mode.
    """
    model.eval()

    # Take the device from the prompts rather than from a notebook global so
    # the function does not depend on hidden kernel state.
    target_device = text_inputs.device

    all_preds = []
    all_labels = []
    with torch.no_grad():
        # The prompts are identical for every batch, so encode and normalise
        # them once instead of once per batch.
        text_features = model.encode_text(text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for images, labels in tqdm(test_loader):
            images = images.to(target_device)

            # Image embeddings, normalised to unit length.
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Scaled cosine similarity against every class prompt. Softmax is
            # monotonic, so taking argmax of the raw logits selects the same
            # class without the extra computation.
            logits = 100.0 * image_features @ text_features.T
            predictions = logits.argmax(dim=-1)

            all_preds.extend(predictions.cpu().tolist())
            # Labels are only compared on the host; no device round trip needed.
            all_labels.extend(labels.tolist())

    return accuracy_score(all_labels, all_preds)

In [None]:
# Run the zero-shot evaluation on the test loader and report the accuracy.
acc = evaluate(model=model, test_loader=test_loader, text_inputs=text_inputs)
print('accuracy:', acc)