### Imports and Initialization

In [None]:
!pip install chromadb
!pip install numpy pandas transformers
!pip install pillow tqdm torch annoy datasets

In [None]:
!pip install -U huggingface_hub

In [28]:
from huggingface_hub import snapshot_download, login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [111]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
import torch
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
from PIL import Image
from tqdm import tqdm
from chromadb.config import Settings
import chromadb
import requests
from annoy import AnnoyIndex
import json
import numpy as np
from imageinterminal import display_image

In [5]:
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_name)

### Embeddings Related Functions

In [7]:
def get_image_embedding(image_path):
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
    image_embedding = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_embedding.cpu().numpy().squeeze()

def get_text_embedding(text, model=model, tokenizer=tokenizer, device=device):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
    text_embedding = text_features / text_features.norm(dim=-1, keepdim=True)
    
    return text_embedding.cpu().numpy().squeeze()

def cosine_similarity(a, b):
    dot_product = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    return dot_product / (norm_a * norm_b)

### Sample check of CLIP's embeddings capability

In [8]:
embeddings = get_image_embedding('horse.png')
text_embeddings = get_text_embedding("a black and white image of horse", model, tokenizer, device)
cosine_similarity(embeddings, text_embeddings)

np.float32(0.30010346)

In [10]:
len(embeddings)

512

### Forming Search Index

In [10]:
with open('./data/examples.jsonl', 'r') as file:
    data = json.load(file)

combined_dict = {}
for i in range(400):
    first_image = data[i]['image_0']
    first_key = str(i) + '0'
    second_image = data[i]['image_1']
    second_key = str(i) + '1'
    combined_dict[int(first_key)] =  get_image_embedding('./data/images/images_' + first_image + '.png')
    combined_dict[int(second_key)] =  get_image_embedding('./data/images/images_' + second_image + '.png')

In [11]:
all_combined_dict = {}
for i in range(400):
    first_image = data[i]['image_0']
    first_key = str(i) + '0'
    second_image = data[i]['image_1']
    second_key = str(i) + '1'
    first_caption = data[i]['caption_0']
    second_caption = data[i]['caption_1']
    all_combined_dict[int(first_key)] =  ['./data/images/images_' + first_image + '.png', first_caption]
    all_combined_dict[int(second_key)] =  ['./data/images/images_' + second_image + '.png', second_caption]

In [95]:
f = 512 # Number of Dimensions
t = AnnoyIndex(f)
for i, j in combined_dict.items():
    t.add_item(i, j) # Adding the key-value pair of the AnnoyIndex

t.build(f) # Building 512 trees for AnnoyIndex, more the number of trees, more the memory consumed, better are the results of ANN algorithm
t.save('image-search-tree.ann') # Saving the AnnoyIndex for faster reading

  t = AnnoyIndex(f)


True

### Image Search Results

In [3]:
search_space = AnnoyIndex(512)
search_space.load('./image-search-tree.ann')

  search_space = AnnoyIndex(512)


True

In [4]:
def text_image_search(query: str, num : int = 5):
    query_vector = get_text_embedding(query) 
    ans = search_space.get_nns_by_vector(query_vector, num)
    for i in ans:
        image_path = all_combined_dict[i][0]
        caption = all_combined_dict[i][1]
        im = Image.open(image_path)
        im.show()
        print(caption, cosine_similarity(query_vector, combined_dict[i]))

In [5]:
text_image_search('donkey hearing some secrets')