In [9]:
import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
from transformers import CLIPModel, AutoProcessor
import faiss, pickle
import os
import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor_dino = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model_dino = AutoModel.from_pretrained('facebook/dinov2-base').to(device)

processor_clip = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

images = []
for root, dirs, files in os.walk('/home/gunubansal129/CS/yolov8/data/images/train'):
    for file in files:
        if file.endswith('.jpg') or file.endswith('.png'):
            images.append(os.path.join(root, file))

In [10]:
def add_vector_to_index(embedding, index):
    #convert embedding to numpy
    vector = embedding.detach().cpu().numpy()
    #Convert to float32 numpy
    vector = np.float32(vector)
    #Normalize vector: important to avoid wrong results when searching
    faiss.normalize_L2(vector)
    #Add to index
    index.add(vector)

def extract_features_dino(image):
    with torch.no_grad():
        inputs = processor_dino(images=image, return_tensors="pt").to(device)
        outputs = model_dino(**inputs)
        image_features = outputs.last_hidden_state
        return image_features.mean(dim=1)
    
def extract_features_clip(image):
    with torch.no_grad():
        inputs = processor_clip(images=image, return_tensors="pt").to(device)
        image_features = model_clip.get_image_features(**inputs)
        return image_features

index_dino = faiss.IndexFlatL2(768)
index_clip = faiss.IndexFlatL2(512)

for image in images:
    img = Image.open(image).convert('RGB')
    dino_features = extract_features_dino(img)
    add_vector_to_index(dino_features, index_dino)
    clip_features = extract_features_clip(img)
    add_vector_to_index(clip_features, index_clip)


faiss.write_index(index_dino, 'index_dino.index')
faiss.write_index(index_clip, 'index_clip.index')

In [20]:
source = "/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_41_P_439.jpg"
img = Image.open(source).convert('RGB')

with torch.no_grad():
    inputs = processor_dino(images=img, return_tensors="pt").to(device)
    outputs = model_dino(**inputs)
    image_features = outputs.last_hidden_state
    image_features_dino = image_features.mean(dim=1)

with torch.no_grad():
    inputs = processor_clip(images=img, return_tensors="pt").to(device)
    image_features_clip = model_clip.get_image_features(**inputs)

def normalizeL2(embeddings):
    vector = embeddings.detach().cpu().numpy()
    vector = np.float32(vector)
    faiss.normalize_L2(vector)
    return vector

input_features_dino = normalizeL2(image_features_dino)
index_dino = faiss.read_index('index_dino.index')

input_features_clip = normalizeL2(image_features_clip)
index_clip = faiss.read_index('index_clip.index')

D_dino, I_dino = index_dino.search(input_features_dino, 10)
D_clip, I_clip = index_clip.search(input_features_clip, 5)

In [21]:
# show results
for i in range(5):
    print(f"Image {I_dino[0][i]} with distance {D_dino[0][i]}")
    print(images[I_dino[0][i]])
    # img = Image.open(images[I_dino[0][i]]).convert('RGB')
    # img.show()

Image 666 with distance 0.0
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_41_P_439.jpg
Image 230 with distance 0.07882018387317657
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_40_P_438.jpg
Image 189 with distance 0.08717025816440582
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_39_P_437.jpg
Image 164 with distance 0.1222972497344017
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_27_P_433.jpg
Image 361 with distance 0.12920145690441132
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_34_P_434.jpg


In [22]:
for i in range(5):
    print(f"Image {I_clip[0][i]} with distance {D_clip[0][i]}")
    print(images[I_clip[0][i]])
    # img = Image.open(images[I_clip[0][i]]).convert('RGB')
    # img.show()

Image 666 with distance 0.0
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_41_P_439.jpg
Image 230 with distance 0.03782675787806511
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_40_P_438.jpg
Image 361 with distance 0.06143813207745552
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_34_P_434.jpg
Image 164 with distance 0.07154545933008194
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_27_P_433.jpg
Image 1084 with distance 0.07210376113653183
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_36_P_436.jpg
