**Import required packages**

In [None]:
!pip install -q chromadb
!pip show chromadb

In [None]:
import os
import chromadb
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction

**Get Data**

In [None]:
# Dataset root; the notebook reads the "train" and "test" sub-directories of this folder.
ROOT = "data"
# Class labels are the names of the sub-directories of the training split, sorted.
# Plain files and hidden entries (e.g. .DS_Store, .ipynb_checkpoints) are skipped so that
# every label can safely be listed as a class directory later on.
CLASS_NAME = sorted(
    entry
    for entry in os.listdir(f"{ROOT}/train")
    if not entry.startswith(".") and os.path.isdir(f"{ROOT}/train/{entry}")
)
# Metadata key under which the distance metric ("l2", "cosine") is passed when a
# collection is created.
HNSW_SPACE = "hnsw:space"

In [None]:
def get_files_path(path, class_names=None):
    """Collect the paths of all files stored under ``path/<label>/``.

    Args:
        path: Directory of one dataset split holding one sub-directory per class label.
        class_names: Optional iterable of label (sub-directory) names to scan.
            Defaults to the module-level ``CLASS_NAME`` list.

    Returns:
        list[str]: ``path/label/filename`` strings, grouped by label in the given
        label order and sorted by file name within each label.

    Raises:
        OSError: If ``path/<label>`` cannot be listed (propagated from ``os.listdir``).
    """
    labels = CLASS_NAME if class_names is None else class_names
    files_path = []
    for label in labels:
        label_path = path + '/' + label
        # os.listdir returns entries in arbitrary order; sorting keeps the position of
        # every file stable between runs, which matters when positions are used as ids.
        for filename in sorted(os.listdir(label_path)):
            files_path.append(label_path + '/' + filename)
    return files_path

In [None]:
# Build the list of training image paths. An image's position in this list is what
# identifies it in the collections, so the list must not be reordered afterwards.
files_path = get_files_path(f"{ROOT}/train")
# Show the size and a short preview instead of dumping every path into the output.
print(f"{len(files_path)} training images")
files_path[:5]

In [None]:
def plot_results(image_path, files_path, results):
    """Show the query image next to the best matches returned by a collection query.

    The figure is a fixed 2x3 grid: the query image followed by at most five matches.
    Cells stay blank when the query returned fewer matches than the grid can hold.

    Args:
        image_path: Path of the query image; the name of its parent directory is
            shown as the class label in the title.
        files_path: List of indexed image paths. The numeric suffix of each result id
            (``id_<n>``) is the position of the matching image in this list.
        results: Query result dict whose ``'ids'`` entry holds one list of ids per
            query; only the ids of the first query are plotted.
    """
    size = (448, 448)
    n_rows, n_cols = 2, 3
    max_matches = n_rows * n_cols - 1  # one cell is reserved for the query image

    def class_of(path):
        # Images are stored as <split>/<class>/<file>, so the parent folder is the class.
        return os.path.basename(os.path.dirname(path))

    images = [Image.open(image_path).resize(size)]
    titles = [f"Query Image: {class_of(image_path)}"]
    # Query results expose the matched ids under the plural key 'ids'. Only the matches
    # that fit into the grid are opened; rank starts at 1 so the best match is "Top 1".
    for rank, id_img in enumerate(results['ids'][0][:max_matches], start=1):
        img_path = files_path[int(id_img.split('_')[-1])]
        images.append(Image.open(img_path).resize(size))
        titles.append(f"Top {rank}: {class_of(img_path)}")

    _, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8))
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        # Leave the cell empty when there is no image for it.
        if i < len(images):
            ax.imshow(images[i])
            ax.set_title(titles[i])
    plt.show()

**Image Embedding**

In [None]:
# Single OpenCLIP embedding function instance, created with the library defaults and
# shared by every image-embedding call in this notebook.
embedding_function = OpenCLIPEmbeddingFunction()

def get_single_image_embedding(image):
    """Return the embedding of one image.

    Args:
        image: Image object that ``np.array`` can convert to a pixel array
            (the notebook passes PIL images).

    Returns:
        Whatever ``embedding_function._encode_image`` produces for the pixel array.
    """
    pixels = np.array(image)
    # NOTE(review): _encode_image is an underscore-prefixed method of the embedding
    # function — confirm it stays available across chromadb versions.
    return embedding_function._encode_image(image=pixels)

In [None]:
# Sanity check: embed one known training image and display the resulting vector.
# The path is built from ROOT so it follows the configured dataset location, and the
# context manager releases the file handle once the embedding has been computed.
with Image.open(f"{ROOT}/train/African_crocodile/n01697457_260.JPEG") as img:
    sample_embedding = get_single_image_embedding(img)
sample_embedding

**Chromadb L2 Embedding Collection**

In [None]:
def add_embedding(collection, files_path):
    """Embed every image in ``files_path`` and add the vectors to ``collection``.

    Args:
        collection: Object exposing ``add(ids=..., embeddings=...)``; the notebook
            passes Chroma collections.
        files_path: Sequence of image file paths. The image at position ``n`` is
            stored under the id ``id_<n>``, so the same list is needed later to map
            ids back to files.
    """
    ids = []
    embeddings = []
    # Wrap the sequence itself rather than enumerate(...): enumerate has no length, so
    # tqdm could not show a total or an ETA.
    for id_filepath, filepath in enumerate(tqdm(files_path)):
        # Close each image file as soon as its embedding has been computed.
        with Image.open(filepath) as image:
            embeddings.append(get_single_image_embedding(image))
        ids.append(f'id_{id_filepath}')
    # All vectors are sent to the collection in a single call.
    collection.add(ids=ids, embeddings=embeddings)

In [None]:
# Client that owns the collections created in this notebook.
chroma_client = chromadb.Client()

# Collection configured with the "l2" distance; get_or_create reuses the collection
# if one with this name already exists on the client.
L2_collection = chroma_client.get_or_create_collection(
    name="l2_collection",
    metadata={HNSW_SPACE: "l2"},
)
# Embed all training images and store them in the collection.
add_embedding(files_path=files_path, collection=L2_collection)

**Search Image With L2 Collection**

In [None]:
def search(image_path, collection, n_results):
    """Query ``collection`` for the images most similar to the image at ``image_path``.

    Args:
        image_path: Path of the query image file.
        collection: Object exposing ``query(query_embeddings=..., n_results=...)``;
            the notebook passes Chroma collections.
        n_results: Number of nearest neighbours to request.

    Returns:
        The result object returned by ``collection.query`` unchanged.
    """
    # Close the query image file once its embedding has been computed.
    with Image.open(image_path) as query_image:
        query_embedding = get_single_image_embedding(query_image)
    # The keyword is plural (`query_embeddings`) and takes one embedding per query,
    # hence the single-element list.
    return collection.query(query_embeddings=[query_embedding],
                            n_results=n_results)

In [None]:
# List every test image, then use the second one as the query image.
test_file_path = get_files_path(f"{ROOT}/test")
test_path = test_file_path[1]
# Ask the L2 collection for the 10 nearest training images and print the raw result.
l2_results = search(
    image_path=test_path,
    collection=L2_collection,
    n_results=10,
)
print(l2_results)

In [None]:
# Display the query image together with the matches found in the L2 collection.
plot_results(image_path=test_path, files_path=files_path, results=l2_results)

**Search Image With Cosine similarity Collection**

In [None]:
# Second collection over the same training images, configured with the "cosine"
# distance; get_or_create reuses the collection if this name already exists.
cosine_collection = chroma_client.get_or_create_collection(
    name="cosine_collection",
    metadata={HNSW_SPACE: "cosine"},
)
# Embed all training images and store them in the collection.
add_embedding(files_path=files_path, collection=cosine_collection)

In [None]:
# List every test image, then use the second one as the query image.
test_file_path = get_files_path(f"{ROOT}/test")
test_path = test_file_path[1]
# Ask the cosine collection for the 10 nearest training images.
cosine_results = search(
    image_path=test_path,
    collection=cosine_collection,
    n_results=10,
)

In [None]:
# Inspect the raw query result returned by the cosine collection.
cosine_results

In [None]:
# Display the query image together with the matches found in the cosine collection.
plot_results(image_path=test_path, files_path=files_path, results=cosine_results)