In [9]:
import os
import torch
import numpy as np
import open_clip
from PIL import Image
import matplotlib.pyplot as plt
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
import chromadb
import sys
from tqdm.notebook import tqdm
import json 

In [10]:
def GET_PROJECT_ROOT(project_dir_name='Image-Retrieval-Simple-System-With-Streamlit'):
    """Walk upward from the current working directory to the project root.

    Args:
        project_dir_name: final path component that identifies the project
            root directory.

    Returns:
        Absolute path of the nearest directory (the cwd itself or one of its
        ancestors) whose last component equals ``project_dir_name``.

    Raises:
        FileNotFoundError: if the filesystem root is reached without a match.
    """
    current_abspath = os.path.abspath('./')
    while True:
        if os.path.basename(current_abspath) == project_dir_name:
            return current_abspath
        parent = os.path.dirname(current_abspath)
        # dirname() of a filesystem root is the root itself; without this
        # check the loop would spin forever when the project is not found.
        if parent == current_abspath:
            raise FileNotFoundError(
                f"No ancestor of {os.path.abspath('./')!r} is named {project_dir_name!r}"
            )
        current_abspath = parent

PROJECT_ROOT = GET_PROJECT_ROOT()
# Relative data paths below (./data/processed/...) are resolved from here.
os.chdir(PROJECT_ROOT)
# Guarded so re-running this cell does not stack duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
print(f"Current working directory: {os.getcwd()}")

Current working directory: d:\BachKhoa\AIO-Projects\OwnProject\Image-Retrieval-Simple-System-With-Streamlit


In [11]:
# Shared chromadb OpenCLIP embedder, constructed once with the library's default
# arguments; get_single_image_embedding reads this module-level instance.
embedding_function = OpenCLIPEmbeddingFunction()

In [34]:
def get_single_image_embedding(image):
    """Embed a single image with the module-level OpenCLIP embedder.

    Args:
        image: pixel data of one image, e.g. the RGB ``np.ndarray`` returned
            by ``read_image_from_path`` (presumably HxWx3 uint8 — TODO confirm
            what ``_encode_image`` accepts).

    Returns:
        The image embedding converted to a NumPy array.
    """
    # NOTE(review): `_encode_image` is a private chromadb method and may
    # change or disappear between chromadb releases.
    return np.array(embedding_function._encode_image(image=image))

def read_image_from_path(path, size=(224, 224)):
    """Load an image file as an RGB pixel array resized to ``size``.

    Args:
        path: path to an image file that PIL can open.
        size: target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        ``np.ndarray`` holding the resized RGB image.
    """
    # The context manager closes the file handle deterministically. PIL may
    # otherwise keep it open (e.g. for multi-frame images), which on Windows
    # also keeps the file locked.
    with Image.open(path) as im:
        rgb = im.convert('RGB').resize(size)
    return np.array(rgb)

def precompute_embeddings(root_path, class_names):
    """Embed every image found under ``root_path/<class_name>/``.

    Classes are visited in the order given and files within a class in sorted
    filename order, so row ``i`` of the returned matrix corresponds to
    ``global_index2img_path[i]``.

    Entries that are not regular files, or whose extension PIL does not
    register as an image format (e.g. ``Thumbs.db`` or ``desktop.ini`` left by
    Windows Explorer), are skipped instead of aborting the run inside
    ``Image.open``. Files with no extension are therefore skipped as well.

    Args:
        root_path: directory containing one sub-directory per class.
        class_names: iterable of sub-directory names to index.

    Returns:
        ``(embeddings, global_index2img_path)``: an ``np.array`` built from the
        per-image embeddings, and the index-aligned list of image paths.
    """
    # Extensions of every format PIL can decode (keys are lowercase, e.g. '.jpg').
    image_exts = {ext.lower() for ext in Image.registered_extensions()}
    all_embeddings = []
    global_index2img_path = []

    for class_name in tqdm(class_names, desc="Computing embeddings"):
        class_path = os.path.join(root_path, class_name)
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            ext = os.path.splitext(img_name)[1].lower()
            if not os.path.isfile(img_path) or ext not in image_exts:
                continue
            image = read_image_from_path(img_path)
            embedding = get_single_image_embedding(image)
            all_embeddings.append(embedding)
            global_index2img_path.append(img_path)

    return np.array(all_embeddings), global_index2img_path


def save_embeddings(embeddings, global_index2img_path, output_dir):
    """Persist an embedding matrix and its index-to-path list.

    Creates ``output_dir`` if needed, then writes ``global_embeddings.npy``
    (via ``np.save``) and ``global_index2img_path.json`` (via ``json.dump``).

    Args:
        embeddings: array to store.
        global_index2img_path: JSON-serialisable list of image paths.
        output_dir: destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)
    npy_file = os.path.join(output_dir, "global_embeddings.npy")
    json_file = os.path.join(output_dir, "global_index2img_path.json")
    np.save(npy_file, embeddings)
    with open(json_file, 'w') as handle:
        json.dump(global_index2img_path, handle)

def load_embeddings(input_dir):
    """Read back the files written by ``save_embeddings``.

    Args:
        input_dir: directory holding ``global_embeddings.npy`` and
            ``global_index2img_path.json``.

    Returns:
        ``(embeddings, global_index2img_path)`` as loaded from disk.
    """
    json_file = os.path.join(input_dir, "global_index2img_path.json")
    embeddings = np.load(os.path.join(input_dir, "global_embeddings.npy"))
    with open(json_file, 'r') as handle:
        paths = json.load(handle)
    return embeddings, paths

def cosine_similarity(query, data):
    """Cosine similarity between one query vector and every row of ``data``.

    Args:
        query: ndarray with D elements in total (any shape; flattened to 1xD).
        data: ndarray whose elements are reshaped into rows of length D.

    Returns:
        1-D array with one similarity per row of ``data``. A machine-epsilon
        term in the denominator keeps zero-norm vectors from dividing by zero
        (they score 0 instead of NaN).
    """
    q = query.reshape(1, -1)
    dim = q.shape[1]
    matrix = data.reshape(-1, dim)

    numerator = np.dot(q, matrix.T)
    row_norms = np.linalg.norm(matrix, axis=1)
    denominator = np.linalg.norm(q) * row_norms + np.finfo(float).eps
    return (numerator / denominator).flatten()

def search_similar_images(query_embedding, embeddings, global_index2img_path, top_k=20):
    """Rank stored images by cosine similarity to a query embedding.

    Args:
        query_embedding: embedding of the query image.
        embeddings: stored embeddings, row ``i`` belonging to
            ``global_index2img_path[i]``.
        global_index2img_path: image path for each stored embedding.
        top_k: maximum number of matches to return; ``<= 0`` yields ``[]``.

    Returns:
        Up to ``top_k`` ``(img_path, similarity)`` tuples, most similar first.
    """
    if top_k <= 0:
        # `arr[-0:]` is the whole array (and negative k slices from the front),
        # so non-positive k must be handled before slicing.
        return []
    similarities = cosine_similarity(query_embedding, embeddings)
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    return [(global_index2img_path[i], similarities[i]) for i in top_indices]

In [16]:
ROOT = './data/processed'
TRAIN_DIR = f'{ROOT}/train'
TEST_DIR = f'{ROOT}/test'
EMBEDDINGS_DIR = f'{ROOT}/embeddings'

# One class per sub-directory of TRAIN_DIR, sorted so the embedding index order
# is reproducible. Stray files (e.g. Thumbs.db) are excluded because
# precompute_embeddings calls os.listdir on each class entry.
# NOTE(review): originally described as ImageNetV1 class names; the data also
# contains custom classes such as Orange_easy — confirm provenance.
CLASS_NAME = sorted(
    name for name in os.listdir(TRAIN_DIR)
    if os.path.isdir(os.path.join(TRAIN_DIR, name))
)
QUERY_PATH = f'{TEST_DIR}/Orange_easy/0_100.jpg'

In [14]:
# Embedding the full training set is the slow step, so reuse the cache written
# by save_embeddings when it exists. Set FORCE_RECOMPUTE = True after changing
# the training images or CLASS_NAME; otherwise stale embeddings are loaded.
FORCE_RECOMPUTE = False
EMBEDDINGS_CACHE_FILES = [
    os.path.join(EMBEDDINGS_DIR, "global_embeddings.npy"),
    os.path.join(EMBEDDINGS_DIR, "global_index2img_path.json"),
]
if not FORCE_RECOMPUTE and all(os.path.exists(p) for p in EMBEDDINGS_CACHE_FILES):
    all_embeddings, global_index2img_path = load_embeddings(EMBEDDINGS_DIR)
else:
    all_embeddings, global_index2img_path = precompute_embeddings(
        TRAIN_DIR, CLASS_NAME
    )
    save_embeddings(all_embeddings, global_index2img_path, EMBEDDINGS_DIR)

Computing embeddings:   0%|          | 0/60 [00:00<?, ?it/s]

In [35]:
# Embed the query image from the test split and rank all training images against it.
query_img = read_image_from_path(QUERY_PATH)
query_embedding = np.reshape(get_single_image_embedding(query_img), (1, -1))
similar_images = search_similar_images(
    query_embedding,
    all_embeddings,
    global_index2img_path,
)


In [37]:
# Preview the six best matches as (image path, cosine similarity) pairs.
similar_images[0:6]

[('./data/processed/train\\Orange_easy\\r_305_100.jpg', 0.9708879446852944),
 ('./data/processed/train\\Orange_easy\\r_193_100.jpg', 0.9505567539680412),
 ('./data/processed/train\\Orange_easy\\r_137_100.jpg', 0.9431781392710117),
 ('./data/processed/train\\Orange_easy\\r_170_100.jpg', 0.9415943297391881),
 ('./data/processed/train\\Orange_easy\\dark.png', 0.8643594598443793),
 ('./data/processed/train\\goldfish\\n01443537_1415.JPEG', 0.4325751026037816)]