In [1]:
from PIL import Image
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel,CLIPModel,CLIPConfig
import os
import torch
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device once: the CUDA GPU if PyTorch can see one,
# otherwise the CPU. Later cells move the model and batches with `.to(device)`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Load the CLIP checkpoint under evaluation plus the base model's preprocessing.
base_model = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
# Checkpoint to evaluate. Set this to `base_model` to score the pretrained
# (not fine-tuned) model instead.
checkpoint = "output_model/epoch_1"
# float16 halves GPU memory, but half precision is poorly supported / slow for
# several CPU kernels, so fall back to float32 when running on the CPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32

model = CLIPModel.from_pretrained(checkpoint, torch_dtype=model_dtype).to(device)
# `torch_dtype` is a model-loading option; the processor (image resizing /
# normalization and text tokenization) does not take it.
processor = CLIPProcessor.from_pretrained(base_model)
model.eval()  # inference only: disable dropout and other train-time behavior

# Direct handles to each tower and its projection head, used by later cells.
model_vision = model.vision_model
visual_projection = model.visual_projection
model_text = model.text_model
text_projection = model.text_projection

# Default dim=1: compares row vectors pairwise (with broadcasting).
cos = torch.nn.CosineSimilarity()


In [4]:
# Embed every test image with the vision tower + visual projection head.
# Layout read below: <path>/<class_name>/<image files>.
path = "topic2_release/test"
BATCH_SIZE = 10  # images per forward pass through the vision tower

# Class names are the sub-directory names, sorted so class order is stable.
# Non-directory entries (e.g. .DS_Store) are skipped instead of crashing.
animal_type = sorted(
    entry for entry in os.listdir(path) if os.path.isdir(os.path.join(path, entry))
)
animal_type_prefix = [f"a photo of {animal}" for animal in animal_type]

images_pil_list = []
embedding_batches = []
# Feature extraction only: no autograd graph is needed.
with torch.no_grad():
    for animal in tqdm(animal_type):
        class_dir = os.path.join(path, animal)
        # Sorted so embeddings keep a deterministic, class-major order;
        # hidden files are not images.
        image_files = sorted(f for f in os.listdir(class_dir) if not f.startswith("."))
        images = [Image.open(os.path.join(class_dir, f)) for f in image_files]
        images_pil_list.extend(images)
        # Only pixel values are used here, so the text prompts are not tokenized.
        pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]
        for start in range(0, pixel_values.shape[0], BATCH_SIZE):
            batch = pixel_values[start : start + BATCH_SIZE].to(device)
            vision_out = model_vision(pixel_values=batch)
            embedding_batches.append(visual_projection(vision_out.pooler_output).cpu())

# One row per image, grouped by class in `animal_type` order.
images_test = torch.cat(embedding_batches)

100%|██████████| 10/10 [00:31<00:00,  3.16s/it]


In [6]:
# Split each class into a retrieval database (its first DB_PER_CLASS images)
# and a query set (the rest). `images_test` is class-major: the embedding cell
# walks classes in sorted order and appends each class's images in turn.
IMAGES_PER_CLASS = 100
DB_PER_CLASS = 30
n_images = images_test.shape[0]
# Assumes every class folder holds exactly IMAGES_PER_CLASS images. This total
# check catches most violations (not offsetting ones) instead of silently
# mixing classes across the split.
assert n_images == len(animal_type) * IMAGES_PER_CLASS, (
    f"expected {len(animal_type)} x {IMAGES_PER_CLASS} embeddings, got {n_images}"
)
class_starts = range(0, n_images, IMAGES_PER_CLASS)
database = torch.cat([images_test[s : s + DB_PER_CLASS] for s in class_starts])
query = torch.cat(
    [images_test[s + DB_PER_CLASS : s + IMAGES_PER_CLASS] for s in class_starts]
)

In [56]:
# Retrieval evaluation: for every query, rank all database embeddings by cosine
# similarity, keep the top `top_num`, and count how many share the query's class.
#
# `database` and `query` are class-major with equal counts per class, so a row's
# class label follows from its position. Because `top_num` equals the number of
# relevant database items per class, precision@top_num and recall@top_num are
# the same number (R-precision); the two result lists are expected to match.
n_classes = len(animal_type)
db_per_class, db_rem = divmod(database.shape[0], n_classes)
q_per_class, q_rem = divmod(query.shape[0], n_classes)
assert db_rem == 0 and q_rem == 0, "database/query must be split evenly per class"
top_num = db_per_class

db_labels = torch.arange(n_classes).repeat_interleave(db_per_class)
query_labels = torch.arange(n_classes).repeat_interleave(q_per_class)

# Similarities in float32: the embeddings may be float16, which costs precision.
query_unit = torch.nn.functional.normalize(query.float(), dim=1)
db_unit = torch.nn.functional.normalize(database.float(), dim=1)
similarity = query_unit @ db_unit.T  # (n_query, n_database) cosine similarities
top_indices = similarity.topk(top_num, dim=1).indices
hits = (db_labels[top_indices] == query_labels.unsqueeze(1)).sum(dim=1)

# One score per query, in query order (class-major), as plain Python floats.
precision_list = (hits / top_num).tolist()
recall_list = (hits / db_per_class).tolist()

In [57]:
# Mean precision per class, one value per entry of `animal_type` (in order).
# Reshaping by the class count replaces the hard-coded 70/700 slicing and raises
# if the number of scores is not a multiple of the class count.
np.array(precision_list).reshape(len(animal_type), -1).mean(axis=1).tolist()

[0.9628571186746869,
 0.9928571428571429,
 0.9895238097224917,
 0.9899999993188041,
 0.9995238091264452,
 0.9980952365057809,
 0.9933333277702332,
 0.9895238058907645,
 0.9919047577040536,
 0.997619047335216]

In [58]:
# Mean recall per class, one value per entry of `animal_type` (in order).
# Reshaping by the class count replaces the hard-coded 70/700 slicing and raises
# if the number of scores is not a multiple of the class count.
np.array(recall_list).reshape(len(animal_type), -1).mean(axis=1).tolist()

[0.9628571186746869,
 0.9928571428571429,
 0.9895238097224917,
 0.9899999993188041,
 0.9995238091264452,
 0.9980952365057809,
 0.9933333277702332,
 0.9895238058907645,
 0.9919047577040536,
 0.997619047335216]

In [24]:
# Display the loaded CLIPModel's module tree (text_model, vision_model and
# their projection heads) to inspect the architecture.
model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 1024)
      (position_embedding): Embedding(77, 1024)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          )
          (layer_norm2): LayerNorm((1024,), e