# Import Parts

In [2]:
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
import torch # numpy on steroid
import os
from PIL import Image # working with image
import faiss
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


# Code

In [3]:
# Prepare image dataset
img_dir = "img_dataset"
imgs_from_paths = [os.path.join(img_dir, fname) for fname in os.listdir(img_dir)]

In [4]:
# CLIP model
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14") # for prepare data
model = AutoModelForZeroShotImageClassification.from_pretrained("openai/clip-vit-large-patch14") # for compute

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [5]:
# Image Encoding
def encode_image(imgs_from_paths):
    img = Image.open(imgs_from_paths).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad(): # get fetures from image
        img_features = model.get_image_features(**inputs) # array of vector
    return img_features / img_features.norm(p=2) # normalization -> have to study more

In [6]:
# List of vector images
img_features_vectors = [encode_image(p).squeeze().cpu().numpy() for p in imgs_from_paths]

In [7]:
# Vector database
dimention = len(img_features_vectors[0])
index = faiss.IndexFlatL2(dimention)
print(index.is_trained)
index.add(np.array(img_features_vectors))
print(index.ntotal)

True
652


In [None]:
# Embeding input
query = ""
input = processor(text=query, return_tensors="pt")
with torch.no_grad():
    text_features = model.get_text_features(**input)
text_features /= text_features.norm(dim=-1, keepdim=True)

In [9]:
# Searching
D, I = index.search(text_features.cpu().numpy(), k=3)
for i, idx in enumerate(I[0]):
    print(f"Rank {i+1}: {img_features_vectors[idx]} (score: {D[0][i]:.4f})")

Rank 1: [ 6.97435960e-02 -1.55559191e-02 -1.11473454e-02 -3.15190218e-02
 -2.05254517e-02  2.55786162e-02  2.54616160e-02 -6.70427782e-03
  5.80891734e-03  2.53205895e-02  1.94734838e-02 -3.09916753e-02
  3.36356438e-03  3.48338559e-02 -2.51733717e-02 -8.61425418e-03
 -3.87007967e-02  4.57625911e-02  3.83288338e-04 -2.75883060e-02
  2.50448240e-03  7.84661807e-03 -9.04924818e-04 -1.68502871e-02
  5.39425109e-03  1.28472680e-02  7.00249942e-03  3.32904384e-02
  1.59037244e-02  2.86970884e-02 -4.01621982e-02  2.63126083e-02
  4.85955067e-02 -9.74398386e-03 -4.07596957e-03 -4.55740141e-03
  3.01709101e-02  5.69471112e-03 -3.06277582e-03  4.32693884e-02
  1.22619616e-02 -4.21024449e-02  1.68086949e-03  1.37382848e-02
  2.45383866e-02  1.70262121e-02 -2.99997311e-02 -8.54252651e-03
 -3.59209739e-02  1.82354916e-02  3.72818147e-04 -1.64147373e-02
 -1.12815702e-03  2.16382146e-02 -4.31480929e-02  3.21022351e-03
 -1.38069969e-02  1.00049824e-02 -2.94923894e-02  1.91061981e-02
  2.85476446e-02 