In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
import random
import data_utils 

In [2]:
# Fix every random number generator used in this notebook so that repeated
# runs produce the same results.
seed = 42
random.seed(seed)                 # Python standard-library RNG
np.random.seed(seed)              # NumPy global RNG
torch.manual_seed(seed)           # PyTorch RNG
torch.cuda.manual_seed_all(seed)  # all CUDA devices (relevant when CUDA is used)

In [3]:
# Prefer the second GPU when the host has one, otherwise fall back to the
# first GPU and finally to the CPU. Unconditionally selecting "cuda:1" fails
# on single-GPU machines, where CUDA is available but device index 1 is not.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model_name = "clip"
# get_model hands back the model plus a `preprocess` object that is later
# passed to data_utils.get_data.
model, preprocess = data_utils.get_model(model_name, device)

Successfully load CLIP-ViT-L/14


In [9]:
import importlib

# Re-execute data_utils so that edits made to the module since it was first
# imported are picked up without restarting the kernel.
importlib.reload(data_utils)

dataset_name = "awa2"
data = data_utils.get_data(dataset_name, preprocess, get_attr=False)

# shuffle=False keeps the dataset order, so collected predictions line up
# with the dataset's own sample order.
dataloader = DataLoader(
    dataset=data,
    batch_size=512,
    shuffle=False,
)

In [5]:
# Sanity check: report how many samples the loaded dataset exposes.
print(len(data))

37322


In [6]:
# Turn the raw dataset class identifiers into human-readable names; these
# strings are what gets inserted into the text prompts later on.
class_names = data.classes['class_name'].tolist()
if dataset_name == "cub":
    # Keep the part after the first "." and replace "_" with spaces so that
    # multi-word names read as natural language (resolves the former TODO
    # about leftover underscores).
    # NOTE(review): assumes CUB names look like "<prefix>.<name>" - confirm.
    clean_cls = [name.split(".")[1].replace("_", " ") for name in class_names]
elif dataset_name == "awa2":
    # AwA2 joins multi-word names with "+"; replace it with a space.
    clean_cls = [name.replace("+", " ") for name in class_names]
else:
    # Other datasets: use the class names unchanged.
    clean_cls = class_names
print(clean_cls)

['antelope', 'grizzly bear', 'killer whale', 'beaver', 'dalmatian', 'persian cat', 'horse', 'german shepherd', 'blue whale', 'siamese cat', 'skunk', 'mole', 'tiger', 'hippopotamus', 'leopard', 'moose', 'spider monkey', 'humpback whale', 'elephant', 'gorilla', 'ox', 'fox', 'sheep', 'seal', 'chimpanzee', 'hamster', 'squirrel', 'rhinoceros', 'rabbit', 'bat', 'giraffe', 'wolf', 'chihuahua', 'rat', 'weasel', 'otter', 'buffalo', 'zebra', 'giant panda', 'deer', 'bobcat', 'pig', 'lion', 'mouse', 'polar bear', 'collie', 'walrus', 'raccoon', 'cow', 'dolphin']


In [7]:
from tqdm import tqdm
import clip

def zs_predict(model_name, model, dataloader, class_names, device):
    """
    Run zero-shot classification over ``dataloader`` and collect the results.

    One text embedding is built per entry of ``class_names``. Every image
    batch is then scored against those embeddings (scaled dot product of
    L2-normalised features followed by a softmax) and assigned the index of
    the best-matching class. No accuracy is computed here; callers derive
    metrics from the returned lists.

    Args:
        model_name: Selects the text pipeline. Names containing "cvcl" use
            ``model.tokenize`` and ``model.encode_text(tokens, lengths)``;
            names containing "clip" use ``clip.tokenize`` with the prompt
            "a photo of {class}". "cvcl" takes precedence if both match.
        model: Object exposing ``eval()``, ``encode_text`` and
            ``encode_image``.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        class_names: Class names; predictions are indices into this sequence.
        device: Device that text tokens and images are moved to.

    Returns:
        Tuple ``(all_values, all_preds, all_labels)`` of equal-length lists
        holding, per sample, the softmax score of the winning class, the
        predicted class index and the ground-truth label.

    Raises:
        ValueError: If ``model_name`` contains neither "cvcl" nor "clip".
    """
    # Fail fast with a clear message instead of hitting an unbound
    # `txt_feature` further down.
    if "cvcl" not in model_name and "clip" not in model_name:
        raise ValueError(
            f"Unsupported model_name {model_name!r}: "
            "expected a name containing 'cvcl' or 'clip'"
        )

    model.eval()
    all_values = []  # softmax score of the predicted class
    all_preds = []   # predicted class indices
    all_labels = []  # ground-truth labels

    with torch.no_grad():
        if "cvcl" in model_name:
            # model.tokenize yields (tokens, lengths) pairs for this model.
            txt_tokens = [model.tokenize(f"{c}") for c in class_names]
            txt_input = torch.cat([txt[0] for txt in txt_tokens]).to(device)
            txt_len = torch.cat([txt[1] for txt in txt_tokens]).to(device)
            txt_feature = model.encode_text(txt_input, txt_len)
        else:  # "clip" in model_name, guaranteed by the check above
            txt_input = torch.cat([clip.tokenize(f"a photo of {c}") for c in class_names]).to(device)
            txt_feature = model.encode_text(txt_input)

        # Text features do not depend on the batch: normalise them once here
        # rather than again on every loop iteration.
        txt_feature = txt_feature / txt_feature.norm(dim=-1, keepdim=True)

        for img, label in tqdm(dataloader, desc="Evaluating"):
            img_feature = model.encode_image(img.to(device))
            img_feature = img_feature / img_feature.norm(dim=-1, keepdim=True)

            # Scaled dot product of unit vectors -> per-class probabilities.
            similarity = (100.0 * img_feature @ txt_feature.T).softmax(dim=-1)
            values, indices = similarity.max(dim=-1)

            all_values.extend(values.cpu().numpy())
            all_preds.extend(indices.cpu().numpy())
            # Labels are only collected, so they never need to visit `device`.
            all_labels.extend(label.cpu().numpy())

    return all_values, all_preds, all_labels

In [10]:
# Run zero-shot inference over the whole dataset, using the cleaned class
# names as the candidate labels.
values, predictions, gt_labels = zs_predict(
    model_name=model_name,
    model=model,
    dataloader=dataloader,
    class_names=clean_cls,
    device=device,
)

Evaluating: 100%|██████████| 73/73 [26:38<00:00, 21.89s/it]


In [11]:
# Eyeball the first 100 predicted class indices (first line) against the
# matching ground-truth labels (second line).
print(predictions[:100], gt_labels[:100], sep="\n")

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [13]:
# Summarise zero-shot performance: overall accuracy plus macro-averaged
# precision, recall and F1.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

accuracy = accuracy_score(y_true=gt_labels, y_pred=predictions)
precision = precision_score(y_true=gt_labels, y_pred=predictions, average='macro')
recall = recall_score(y_true=gt_labels, y_pred=predictions, average='macro')
f1 = f1_score(y_true=gt_labels, y_pred=predictions, average='macro')

# One line per metric, four decimal places each.
print(
    f"Accuracy: {accuracy:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
    sep="\n",
)

Accuracy: 0.9400
Precision: 0.9328
Recall: 0.9226
F1 Score: 0.9122
