<a href="https://colab.research.google.com/github/KhangTheKangaroo/Image-Retrieval/blob/main/Image_Retrieval_with_CLIP_(Vector_Database_Collection).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!gdown --id 1msLVo0g0LFmL9-qZ73vq9YEVZwbzOePF # Download the dataset
!unzip -q data.zip

Downloading...
From (original): https://drive.google.com/uc?id=1msLVo0g0LFmL9-qZ73vq9YEVZwbzOePF
From (redirected): https://drive.google.com/uc?id=1msLVo0g0LFmL9-qZ73vq9YEVZwbzOePF&confirm=t&uuid=bcb2475c-c50c-4a79-8847-1d0d6d523194
To: /content/data.zip
100% 76.1M/76.1M [00:00<00:00, 110MB/s]


In [None]:
%pip install chromadb
%pip install open-clip-torch

In [None]:
import os
import chromadb
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction

In [None]:
ROOT = 'data'
# Class labels are the sub-directory names of the training split. Only directories are kept,
# so stray files (e.g. a README or .DS_Store) in data/train are not mistaken for classes;
# the names are sorted so the class order is deterministic.
CLASS_NAME = sorted(
    entry for entry in os.listdir(f"{ROOT}/train")
    if os.path.isdir(os.path.join(ROOT, "train", entry))
)
# Chroma collection-metadata key that selects the HNSW distance space ("l2", "cosine", ...).
HNSW_SPACE = 'hnsw:space'

In [None]:
def plot_results(image_path, files_path, results, ncols=3, thumb_size=(448, 448)):
    """Plot the query image followed by its retrieved neighbours in a grid.

    Args:
        image_path: Path of the query image. Its class label is the name of its parent directory.
        files_path: Indexed image paths; a result id ``id_<k>`` refers to ``files_path[k]``.
        results: Chroma ``collection.query`` response for a single query; ``results['ids'][0]``
            lists the retrieved ids in rank order.
        ncols: Number of panels per grid row (rows are added as needed).
        thumb_size: Size each image is resized to before plotting.
    """
    def _class_of(path):
        # Paths look like <ROOT>/<split>/<class>/<file>; the parent directory is the label.
        # Using dirname/basename keeps this correct even if ROOT itself contains slashes.
        return os.path.basename(os.path.dirname(path))

    def _load(path):
        # resize() returns an in-memory copy, so the source file can be closed immediately.
        with Image.open(path) as img:
            return img.resize(thumb_size)

    panels = [(_load(image_path), f"Query Image: {_class_of(image_path)}")]
    # Ranks start at 1: the first retrieved image is the "Top 1" match.
    for rank, id_img in enumerate(results['ids'][0], start=1):
        img_path = files_path[int(id_img.split('_')[-1])]
        panels.append((_load(img_path), f"Top {rank}: {_class_of(img_path)}"))

    # Size the grid from the number of panels so any n_results works (6 panels -> 2x3).
    nrows = (len(panels) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # Hide axes (and leave any unused panels blank)
    for ax, (img, title) in zip(axes.flat, panels):
        ax.imshow(img)
        ax.set_title(title)
    # Display the plot
    plt.show()

In [None]:
# Rather than extracting the features of every image each time we run a search,
# we compute them once and store them in a vector database to save time.

def get_files_path(path, class_names=None):
    """Collect the image file paths of one dataset split, grouped by class.

    Args:
        path: Split directory (e.g. ``f"{ROOT}/train"``) holding one sub-directory per class.
        class_names: Class sub-directories to visit, in order. Defaults to the module-level
            ``CLASS_NAME`` (the training classes), so every split is walked in the same class order.

    Returns:
        List of ``"<path>/<class>/<filename>"`` strings. Filenames are sorted within each class
        because ``os.listdir`` order is arbitrary; this keeps the list, and the ``id_<index>``
        ids derived from it, reproducible across runs.

    Raises:
        FileNotFoundError: if a class directory does not exist under ``path``.
    """
    if class_names is None:
        class_names = CLASS_NAME
    files_path = []
    for label in class_names:
        label_path = f"{path}/{label}"
        for filename in sorted(os.listdir(label_path)):
            filepath = f"{label_path}/{filename}"
            # Skip hidden files (e.g. .DS_Store) and nested folders: Image.open would fail on them.
            if not filename.startswith('.') and os.path.isfile(filepath):
                files_path.append(filepath)
    return files_path

# Training split, laid out as ROOT/train/<class>/<image>; every file here gets indexed.
data_path = "/".join((ROOT, "train"))
files_path = get_files_path(data_path)

In [None]:
embedding_function = OpenCLIPEmbeddingFunction()

def get_single_image_embedding(image):
    """Return the OpenCLIP embedding of a single PIL image.

    The image is converted to RGB before being turned into a numpy array. Otherwise
    ``np.array`` on a palette ('P') image yields a 2-D array of palette indices instead of
    colours, and the model would silently embed the wrong pixels.

    NOTE(review): ``_encode_image`` is a private chromadb method (called here with a numpy
    array); its signature may change between chromadb versions. Confirm against the installed one.
    """
    rgb_array = np.array(image.convert("RGB"))
    return embedding_function._encode_image(image=rgb_array)

In [None]:
# Help store the vector of the features inside a collection

def add_embedding(collection, files_path):
 ids = []
 embeddings = []
 for id_filepath, filepath in tqdm(enumerate(files_path)):
  ids.append(f"id_{id_filepath}")
  image = Image.open(filepath)
  embedding = get_single_image_embedding(image=image)
  embeddings.append(embedding)
  collection.add(
  embeddings=embeddings,
  ids=ids) # Add the image embedding and its vector of feature

In [None]:

# Chroma client (presumably the default ephemeral/in-memory one; confirm).
chroma_client = chromadb.Client()

# Collection whose HNSW index uses the "l2" distance space, filled with every training image.
l2_collection = chroma_client.get_or_create_collection(
    name="l2_collection",
    metadata={HNSW_SPACE: "l2"},
)
add_embedding(collection=l2_collection, files_path=files_path)

In [None]:
def search(image_path, collection, n_results):
    """Query ``collection`` for the ``n_results`` images closest to the one at ``image_path``.

    Returns the raw Chroma ``query`` response for a single query vector.
    """
    embedding = get_single_image_embedding(Image.open(image_path))
    # query() takes a list of query vectors; we send exactly one.
    return collection.query(query_embeddings=[embedding], n_results=n_results)

In [None]:
# Use one test image as the query. Keep the directory and the query file in separately
# named variables instead of reusing `test_path` for both.
test_dir = f'{ROOT}/test'
test_files_path = get_files_path(path=test_dir)
query_path = test_files_path[1]  # an arbitrary test image; change the index to try others
l2_results = search(image_path=query_path, collection=l2_collection, n_results=5)
plot_results(image_path=query_path, files_path=files_path, results=l2_results)

In [None]:
# Same images as the L2 collection, but indexed in the "cosine" distance space.
cosine_collection = chroma_client.get_or_create_collection(
    name="Cosine_collection",
    metadata={HNSW_SPACE: "cosine"},
)
add_embedding(collection=cosine_collection, files_path=files_path)

In [None]:
# Same query image as the L2 search, now ranked by cosine distance. The directory and the
# query file are kept in separately named variables instead of reusing `test_path` for both.
test_dir = f'{ROOT}/test'
test_files_path = get_files_path(path=test_dir)
query_path = test_files_path[1]  # must match the index used for the L2 query to compare results
cosine_results = search(image_path=query_path, collection=cosine_collection, n_results=5)
plot_results(image_path=query_path, files_path=files_path, results=cosine_results)