In [1]:
# !pip install -U datasets chromadb

In [13]:
from datasets import load_dataset
from tqdm import tqdm
import traceback
import re
import pandas as pd
from PIL import Image
import numpy as np
import chromadb
import time
import pickle
import os
from chromadb.utils.batch_utils import create_batches


# Load embeddings and build the ChromaDB index

In [20]:
# Locate the pickled embeddings relative to the notebook's working directory.
current_path = os.getcwd()

file_name = "image_text_embeddings_clip-vit-large-patch14_0_11759.pkl"

# The file name appears to follow "image_text_embeddings_<model>_<n>_<n>.pkl".
# Parse the model part with a regex rather than fixed slice offsets: offsets
# silently yield a wrong model name as soon as the numeric suffix has a
# different number of digits, whereas a failed match raises immediately.
name_match = re.fullmatch(r"image_text_embeddings_(.+)_\d+_\d+\.pkl", file_name)
if name_match is None:
  raise ValueError(f"Unexpected embeddings file name: {file_name!r}")
model_name = name_match.group(1).replace('_', '-')

file_path = os.path.join(current_path, 'embeddings', file_name)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load embedding dumps that come from a trusted source.
with open(file_path, 'rb') as f:
  embeddings = pickle.load(f)

In [21]:
# Peek at the first record to see which fields each entry carries.
embeddings[0]  # drill-down example: ['text_embeddings']['caption_embedding_th'][0]

{'concept_id': 0,
 'image_id': 'argentina-0-1-eng',
 'image_key': 'argentina-0-1-eng',
 'index_in_dataset': 0,
 'concept': 'train',
 'concept_country': 'Tren Roca (commuter)',
 'country': 'Argentina',
 'title': "A Guide to Touring With Argentina's Trains | Travel Argentina",
 'concept_in_native': 'Tren',
 'image_embedding': array([[ 1.53695326e-03, -2.79524103e-02,  4.24493216e-02,
          3.10010673e-03,  9.62058734e-03, -2.83857062e-02,
          1.76329017e-02,  2.86542270e-02,  1.47622973e-02,
          2.29399297e-02, -1.46284904e-02,  1.72697958e-02,
         -1.32912090e-02,  3.86032127e-02, -2.61994228e-02,
          2.89968029e-02, -3.44796828e-03,  2.52284035e-02,
          1.93002634e-02,  1.86498713e-04, -6.43779263e-02,
         -9.34193563e-03, -1.08885206e-02,  1.75335575e-02,
         -4.96927835e-03, -9.98830050e-03, -2.26698611e-02,
          1.88610759e-02, -5.11858203e-02,  3.84308025e-03,
          1.38834817e-02,  4.89708818e-02, -2.53755376e-02,
          7.610

In [22]:
# Open (or create) the on-disk Chroma store.
client = chromadb.PersistentClient('/content/chromadb')

# The timestamp suffix gives each run its own collection.
# NOTE(review): the prefix says "gme" while model_name is derived from the
# loaded file — confirm the intended naming.
collection_name = f"embeddings_gme_{int(time.time())}"

# Depending on the installed chromadb version, list_collections() yields either
# Collection objects or plain name strings. Normalise to names so the
# membership test is meaningful in both cases (comparing a str against
# Collection objects would always be False).
existing_names = {
    getattr(entry, "name", entry) for entry in client.list_collections()
}
if collection_name in existing_names:
  client.delete_collection(name=collection_name)

# Cosine space, so query distances can be read as 1 - cosine similarity.
collection = client.get_or_create_collection(
    name=collection_name,
    metadata={"hnsw:space": "cosine"}
)

In [23]:
# Flatten every text embedding into three parallel lists: one id, one vector
# and one metadata dict per entry.
ids = []
embeddings_list = []
metadatas_list = []

for record in tqdm(embeddings, total=len(embeddings)):
    text_embeddings = record['text_embeddings']
    dataset_index = record["index_in_dataset"]
    for field_name in text_embeddings:
        # The language tag is whatever follows the last underscore of the field name.
        language = field_name.split('_')[-1]
        for position, vector in enumerate(text_embeddings[field_name]):
            ids.append(f'{dataset_index}_{position}_{field_name}')
            embeddings_list.append(vector)  # stored as-is, no conversion
            metadatas_list.append({
              'index': dataset_index,
              'lang': language,
              'num_text': position + 1
            })

100%|██████████| 11759/11759 [00:00<00:00, 59003.75it/s]


In [24]:
# Let chromadb split the flattened lists into batches for insertion.
batches = create_batches(
    api=client,
    ids=ids,
    embeddings=embeddings_list,
    metadatas=metadatas_list,
)

# Each batch is indexed as (ids, embeddings, metadatas, documents).
for batch in tqdm(batches, total=len(batches)):
    batch_ids, batch_embeddings, batch_metadatas, batch_documents = batch[:4]
    collection.add(
        ids=batch_ids,
        embeddings=batch_embeddings,
        metadatas=batch_metadatas,
        documents=batch_documents,  # None here: no documents were given to create_batches
    )

100%|██████████| 9/9 [00:20<00:00,  2.26s/it]


# Retrieve

In [25]:
# Collect the image side of the data as two parallel lists.
image_ids = []
image_vectors = []
for record in embeddings:
    image_ids.append(record['index_in_dataset'])
    image_vectors.append(record['image_embedding'])

all_img = {'ids': image_ids, 'embeddings': image_vectors}

In [26]:
# Language tags come from the suffix of each text-embedding field name of the
# first record.
all_lang = [
    field_name.split('_')[-1]
    for field_name in embeddings[0]['text_embeddings']
]
print(len(all_lang))

4


In [27]:
# Image-to-text retrieval: query the index with every image embedding and
# flatten the hits into one record per (query image, retrieved entry).
num_top_results = 100
search_results = []

image_queries = zip(all_img['ids'], all_img['embeddings'])
# `query_id` (not `id`) so the builtin id() is not shadowed in the notebook namespace.
for query_id, query_embedding in tqdm(image_queries, total=len(all_img['ids'])):
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=num_top_results,
        include=['distances', 'metadatas'],
        # To restrict hits to one language, add: where={"lang": <tag>}
    )

    # Skip queries that came back empty instead of indexing into missing lists.
    if not (results and results['ids'] and results['ids'][0]):
        print(f"No results found for the given query at {query_id}.")
        continue

    hits = zip(results['ids'][0], results['distances'][0], results['metadatas'][0])
    for rank, (_doc_id, distance, metadata) in enumerate(hits, start=1):
        # The collection uses cosine space, so similarity = 1 - distance.
        similarity_score = 1 - distance

        search_results.append({
            'query_id': query_id,
            'result_rank': rank,
            'result_id': metadata.get('index', 'N/A'),
            'lang': metadata.get('lang', 'N/A'),
            'caption_idx': metadata.get('num_text', 'N/A'),
            'score': round(float(similarity_score), 4)  # 4 decimals for cleaner output
        })


100%|██████████| 11759/11759 [01:18<00:00, 150.63it/s]


In [28]:
# Persist the flattened hits to a CSV named after the model, then show the
# table's dimensions as the cell output.
output_csv = f'result_imgtotext_{model_name}.csv'
df = pd.DataFrame(data=search_results)
df.to_csv(output_csv, index=False)
print(f'saved {model_name}')
df.shape

saved clip-vit-large-patch14


(1175735, 6)