In [85]:
from PIL import Image
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel,CLIPModel,CLIPConfig
import os
import torch
from tqdm import tqdm
import numpy as np
import random

In [86]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [87]:
# Pretrained OpenCLIP ViT-H/14; used for the processor and as the fallback model weights.
base_model = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
# Weights to evaluate: the fine-tuned checkpoint. Set `checkpoint = base_model`
# to evaluate the original pretrained weights instead.
checkpoint = "output_model/epoch_1"
# float16 halves GPU memory; on CPU half-precision kernels may be slow or unsupported,
# so fall back to float32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32
model = CLIPModel.from_pretrained(checkpoint, torch_dtype=model_dtype).to(device)
# The processor only preprocesses images / tokenizes text and has no weights,
# so it takes no dtype argument.
processor = CLIPProcessor.from_pretrained(base_model)
model.eval()

# Handles to the two towers and their projection heads into the shared embedding space.
model_vision = model.vision_model
visual_projection = model.visual_projection
model_text = model.text_model
text_projection = model.text_projection

# Cosine similarity along dim=1 (the embedding dimension for 2-D inputs).
cos = torch.nn.CosineSimilarity()


In [88]:
path = "topic2_release/test"
# Class names are the sub-directory names. Sorting fixes the class order, and with it the
# row order of `database`: each class occupies one contiguous block of rows.
animal_type = sorted(os.listdir(path))
animal_type_prefix = [f"a photo of {animal}" for animal in animal_type]

# Images per forward pass; kept small to limit VRAM usage.
BATCH_SIZE = 10

images_test = []      # per-batch projected image embeddings (on CPU)
images_pil_list = []  # every test image as a PIL object, in the same order as `database`
# Inference only: without no_grad autograd would keep activations alive for a backward
# pass that never happens, wasting exactly the VRAM the batching is meant to save.
with torch.no_grad():
    for animal in tqdm(animal_type):
        image_list = sorted(os.listdir(f"{path}/{animal}"))
        images = [Image.open(f"{path}/{animal}/{file}") for file in image_list]
        images_pil_list.extend(images)
        # Only pixel values are needed here; text prompts are not encoded in this cell.
        inputs = processor(images=images, return_tensors="pt")
        pixel_values_all = inputs["pixel_values"]
        for start in range(0, pixel_values_all.shape[0], BATCH_SIZE):
            pixel_values = pixel_values_all[start:start + BATCH_SIZE].to(device)
            vision_outputs = model_vision(pixel_values=pixel_values)
            images_test.append(
                visual_projection(vision_outputs.pooler_output).detach().cpu()
            )
# One row per test image, grouped by class in `animal_type` order.
database = torch.cat(images_test)

100%|██████████| 10/10 [00:24<00:00,  2.49s/it]


In [89]:
# Retrieval precision@k on the test gallery. Every image is used as a query against all
# *other* images; a retrieved image is a hit when it belongs to the query's class.
top_k = [10, 20, 50, 100]

# Per-class image counts, in the same sorted order used to fill `database`.
# Derived from disk instead of assuming 10 classes x 100 images.
class_sizes = [len(os.listdir(f"{path}/{animal}")) for animal in animal_type]
num_classes = len(class_sizes)
labels = torch.repeat_interleave(torch.arange(num_classes), torch.tensor(class_sizes))
assert labels.numel() == database.shape[0], "database rows do not match test images on disk"

# All pairwise cosine similarities at once, in float32 (avoids float16 math on CPU).
embeddings = database.float()
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-8)
similarity = embeddings @ embeddings.T
# A query must never retrieve itself; mask it out explicitly rather than assuming
# it always ranks inside its own top-k.
similarity.fill_diagonal_(float("-inf"))
num_images = similarity.shape[0]

# Flat list, k-major then class order: precision_list[k_idx * num_classes + class_idx].
precision_list = []
for top_num in top_k:
    # top_num includes the query itself, so top_num - 1 real neighbours are scored
    # (the same denominator as (true_num - 1) / (top_num - 1)).
    n_neighbors = top_num - 1
    assert 1 <= n_neighbors < num_images, f"top_k={top_num} is out of range for {num_images} images"
    neighbor_idx = similarity.topk(n_neighbors, dim=1).indices
    hits = (labels[neighbor_idx] == labels.unsqueeze(1)).sum(dim=1).float()
    per_query = hits / n_neighbors
    for class_idx in range(num_classes):
        precision_list.append(round(per_query[labels == class_idx].mean().item(), 5))

for k_idx, top_num in enumerate(top_k):
    per_class = precision_list[k_idx * num_classes:(k_idx + 1) * num_classes]
    print(
        f"topk={top_num}: {per_class} avg: {round(sum(per_class) / num_classes, 5)}"
    )

topk=10: [1.0, 0.99778, 0.99, 0.99111, 1.0, 0.99889, 1.0, 0.99778, 0.99, 1.0] avg: 0.99656
topk=20: [0.99947, 0.99789, 0.99053, 0.99263, 1.0, 0.99947, 1.0, 0.99842, 0.99, 0.99947] avg: 0.99679
topk=50: [0.99122, 0.99755, 0.99143, 0.99327, 0.9998, 0.99878, 1.0, 0.99898, 0.99, 0.99959] avg: 0.99606
topk=100: [0.97939, 0.98545, 0.98283, 0.98242, 0.9896, 0.95626, 0.99303, 0.97061, 0.97596, 0.99798] avg: 0.98135
