In [1]:
import sys
sys.path.append("..")

import torch
import pandas as pd
from PIL import Image
from tqdm import tqdm

from src.inference import ClipInference

In [2]:
# Backbone checkpoints and the size of the shared embedding space.
IMAGE_ENCODER_NAME = "microsoft/swin-tiny-patch4-window7-224"
TEXT_ENCODER_NAME = "cointegrated/rubert-tiny2"
EMBEDDING_DIM = 128

# Arguments stay positional: the keyword names of ClipInference are not visible here.
clip = ClipInference(IMAGE_ENCODER_NAME, TEXT_ENCODER_NAME, EMBEDDING_DIM)

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([128]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([128, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cointegrated/rubert-tiny2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.

In [3]:
# Trained weights saved by the training run (flagged as a Lightning checkpoint via `lightning=True`).
CHECKPOINT_PATH = (
    "/data/sporkhun/TinyClip/logs/train-swin-rubert-distil/runs/"
    "2023-10-21_04-25-16/checkpoints/last.ckpt"
)
clip.load(CHECKPOINT_PATH, lightning=True)

<All keys matched successfully>

In [4]:
# Device on which the wrapped model runs for the rest of the notebook.
DEVICE = "cuda:0"
clip.clip.to(DEVICE)

Clip(
  (image_encoder): SwinForImageClassification(
    (swin): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0-1): 2 x SwinLayer(
                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=96, out_features=96, bias=True)
                    (key): Linear(in_features=96, out_features=96, bias=True)
                    (value): Linear(in_features=96, out_features=96, bias=True)
                    (dropout): Dropout(p=0.0, inplace=Fals

In [5]:
def _has_three_char_mode(image_path):
    """Return True when the image's PIL mode string has exactly three characters (e.g. "RGB").

    `Image.open` is lazy and keeps the underlying file open; the context manager
    releases the handle right away instead of leaving one open file per dataset row
    for the garbage collector to clean up.
    """
    with Image.open(image_path) as img:
        return len(img.mode) == 3


eval_df = pd.read_csv("/data/sporkhun/TinyClip/data/eval/caltech_101.tsv", sep="\t", index_col=0)
# The raw `label` column is a "/"-separated path; its second-to-last component is used as the class name.
eval_df.label = eval_df.label.apply(lambda path: path.split("/")[-2])
# Keep only images whose mode string is three characters long (drops e.g. "L" and "RGBA" files).
index = eval_df.image.apply(_has_three_char_mode)
eval_df = eval_df[index]
eval_df.head()

Unnamed: 0,image,label
0,/data/sporkhun/clip/caltech_101/caltech101/101...,Faces
1,/data/sporkhun/clip/caltech_101/caltech101/101...,Faces
2,/data/sporkhun/clip/caltech_101/caltech101/101...,Faces
3,/data/sporkhun/clip/caltech_101/caltech101/101...,Faces
4,/data/sporkhun/clip/caltech_101/caltech101/101...,Faces


In [6]:
batch_size = 120


def _chunks(items, size):
    """Yield successive slices of ``items`` containing at most ``size`` elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


image_paths = eval_df.image.tolist()
# Unique class names in order of first appearance; one text embedding is computed per name.
class_names = eval_df.label.unique().tolist()

batch_images = []
embedding_images = []
batch_texts = []
embedding_texts = []

# Pure inference: disabling autograd avoids building (and holding on the GPU) a graph
# for every batch that would only be thrown away by the subsequent `.detach()`.
with torch.no_grad():
    # Image embeddings, exactly `batch_size` images per forward pass (the last batch may be smaller).
    with tqdm(total=len(image_paths), desc="images") as progress:
        for path_chunk in _chunks(image_paths, batch_size):
            batch_images = [Image.open(path) for path in path_chunk]
            embedding_images.append(clip.forward_image(images=batch_images).detach().cpu())
            progress.update(len(batch_images))

    # Text embeddings. Each batch is a fresh slice, so no label is ever encoded twice.
    # After the loop `batch_texts` holds the last batch, i.e. the full label list
    # whenever there are at most `batch_size` classes.
    for batch_texts in tqdm(list(_chunks(class_names, batch_size)), desc="label batches"):
        embedding_texts.append(clip.forward_text(texts=batch_texts).detach().cpu())

8296it [01:14, 111.86it/s]
100%|██████████| 100/100 [00:00<00:00, 1001027.21it/s]


In [7]:
# Class names in order of first appearance; this is the order in which the text
# embeddings were computed, so position i in the list matches row i of the text matrix.
class_names = eval_df.label.unique().tolist()
label2index = {name: position for position, name in enumerate(class_names)}

# Stack every text batch instead of assuming there is exactly one.
all_text_emb = torch.cat(embedding_texts, dim=0)
# One embedding row per class; a mismatch means labels were encoded more than once or skipped.
assert all_text_emb.shape[0] == len(label2index), (
    f"expected {len(label2index)} text embeddings, got {all_text_emb.shape[0]}"
)
all_text_emb.shape

torch.Size([100, 128])

In [8]:
# Stack the per-batch image embeddings row-wise into a single matrix.
all_img_emb = torch.cat(tuple(embedding_images), 0)
all_img_emb.shape

torch.Size([8296, 128])

In [9]:
# Score every image embedding against every class-name embedding, then bring the result to the CPU.
raw_predicts = clip.predict(all_img_emb, all_text_emb)
predicts = raw_predicts.detach().cpu()

In [10]:
# Display the predictions; they are compared element-wise against the label indices further down.
predicts

tensor([93, 58, 93,  ..., 19, 18, 13])

In [11]:
# Ground-truth class index for every row, in the same row order as the image embeddings.
label = [label2index[class_name] for class_name in eval_df.label]

In [12]:
# Top-1 accuracy: fraction of images whose predicted index equals the ground-truth index.
predicted_indices = predicts.tolist()
n_correct = sum(1 for guess, truth in zip(predicted_indices, label) if guess == truth)
n_correct / len(label)

0.47866441658630665