In [1]:
import torch
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
from transformers import CLIPModel, AutoProcessor
import faiss
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor_dino = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
model_dino = AutoModel.from_pretrained('facebook/dinov2-base').to(device)

processor_clip = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

images = []
for root, dirs, files in os.walk('/home/gunubansal129/CS/yolov8/data/images/train'):
    for file in files:
        if file.endswith('.jpg') or file.endswith('.png'):
            images.append(os.path.join(root, file))

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()


In [2]:
torch.cuda.is_available()

True

In [3]:
# Define a dataset class
class ImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, img_path

In [4]:
transform = transforms.Compose([
    # transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = ImageDataset(images, transform=transform)
data_loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

In [5]:
def add_vector_to_index(embedding, index):
    #convert embedding to numpy
    vector = embedding.detach().cpu().numpy()
    #Convert to float32 numpy
    vector = np.float32(vector)
    #Normalize vector: important to avoid wrong results when searching
    faiss.normalize_L2(vector)
    #Add to index
    index.add(vector)

def extract_features_dino(image):
    with torch.no_grad():
        inputs = processor_dino(images=image, return_tensors="pt").to(device)
        outputs = model_dino(**inputs)
        image_features = outputs.last_hidden_state
        return image_features.mean(dim=1)
    
def extract_features_clip(image):
    with torch.no_grad():
        inputs = processor_clip(images=image, return_tensors="pt").to(device)
        image_features = model_clip.get_image_features(**inputs)
        return image_features

index_dino = faiss.IndexFlatL2(768)
index_clip = faiss.IndexFlatL2(512)

for image in images:
    img = Image.open(image).convert('RGB')
    dino_features = extract_features_dino(img)
    add_vector_to_index(dino_features, index_dino)
    clip_features = extract_features_clip(img)
    add_vector_to_index(clip_features, index_clip)


faiss.write_index(index_dino, 'index_dino.index')
faiss.write_index(index_clip, 'index_clip.index')

In [6]:
source = "/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_41_P_439.jpg"
img = Image.open(source).convert('RGB')

with torch.no_grad():
    inputs = processor_dino(images=img, return_tensors="pt").to(device)
    outputs = model_dino(**inputs)
    image_features = outputs.last_hidden_state
    image_features_dino = image_features.mean(dim=1)

with torch.no_grad():
    inputs = processor_clip(images=img, return_tensors="pt").to(device)
    image_features_clip = model_clip.get_image_features(**inputs)

def normalizeL2(embeddings):
    vector = embeddings.detach().cpu().numpy()
    vector = np.float32(vector)
    faiss.normalize_L2(vector)
    return vector

input_features_dino = normalizeL2(image_features_dino)
index_dino = faiss.read_index('index_dino.index')

input_features_clip = normalizeL2(image_features_clip)
index_clip = faiss.read_index('index_clip.index')

D_dino, I_dino = index_dino.search(input_features_dino, 10)
D_clip, I_clip = index_clip.search(input_features_clip, 5)

In [7]:
# show results
for i in range(5):
    print(f"Image {I_dino[0][i]} with distance {D_dino[0][i]}")
    print(images[I_dino[0][i]])
    # img = Image.open(images[I_dino[0][i]]).convert('RGB')
    # img.show()

Image 666 with distance 0.0
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_41_P_439.jpg
Image 230 with distance 0.07882020622491837
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_40_P_438.jpg
Image 189 with distance 0.08717029541730881
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_39_P_437.jpg
Image 164 with distance 0.12229707092046738
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_27_P_433.jpg
Image 361 with distance 0.1292014867067337
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_34_P_434.jpg


In [8]:
for i in range(5):
    print(f"Image {I_clip[0][i]} with distance {D_clip[0][i]}")
    print(images[I_clip[0][i]])
    # img = Image.open(images[I_clip[0][i]]).convert('RGB')
    # img.show()

Image 666 with distance 0.0
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_41_P_439.jpg
Image 230 with distance 0.03782706335186958
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_40_P_438.jpg
Image 361 with distance 0.06143786013126373
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_34_P_434.jpg
Image 164 with distance 0.07154563814401627
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_27_P_433.jpg
Image 1084 with distance 0.07210373878479004
/home/gunubansal129/CS/yolov8/data/images/train/ind_raja_12_2013_001_P004_78285944_29961528_C004C_16_01_2014_14_12_36_P_436.jpg
