In [None]:
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
import os
import pandas as pd
import faiss
import display

from data_preprocessing_helper import data_preprocessing
from embeddings_helper import embeddings_generator_and_retreival

In [None]:
# Master metadata table for the training split.
# `image_id` is forced to str so zero-padded IDs (presumably like the example
# file name "000101.jpg" used later) are not turned into integers by read_csv.
meta_data_directory = r"D:\LLM_Project\Multimodel Chatbot\Data\train"
master_csv_file_path = os.path.join(meta_data_directory, "new_master_csv.csv")

column_dtypes = {"image_id": str}
master_df = pd.read_csv(master_csv_file_path, dtype=column_dtypes)

In [None]:
# Instantiate the preprocessing helper from data_preprocessing_helper (implementation not shown here).
dp_class = data_preprocessing()

In [None]:
image_directory = r"D:\LLM_Project\Multimodel Chatbot\Data\train\cropped_image_unique"

# Image side: the helper pairs the cropped images on disk with IDs from
# master_df (exact semantics of the mapping live in data_preprocessing_helper).
all_images, all_ids = dp_class.get_all_images_and_mapping(
    image_directory,
    master_df,
)

# Text side: the per-row descriptions, in master_df row order.
all_descriptions = master_df["description"]

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = 'google/siglip2-base-patch16-224'
model = AutoModel.from_pretrained(model_name, torch_dtype="auto").to(device)
model.eval()
model = AutoProcessor.from_pretrained(model_name)

In [None]:
siglip_emb = embeddings_generator_and_retreival(model, model)

In [None]:
# Build image embeddings and a FAISS index through the helper
# (behaviour and output location are defined in embeddings_helper).
batch_size = 4
siglip_emb.generate_image_embeddings_and_faiss_index(
    image_directory,
    all_images,
    batch_size,
)

In [None]:
# Same helper API, this time fed with the per-row text descriptions.
# NOTE(review): `image_directory` is passed here as well — presumably as the
# output location. If both calls write the same index file, this call replaces
# the image index; confirm against embeddings_helper.
batch_size = 4
siglip_emb.generate_text_embeddings_and_faiss_index(
    image_directory,
    all_descriptions,
    batch_size,
)

### Retrieval Test

In [None]:
example_directory = r"D:\LLM_Project\Multimodel Chatbot\Data\train\Example"
image_file = "000101.jpg"
example_path = os.path.join(example_directory, image_file)

# convert() forces the pixel data to load, so the file handle can be closed
# immediately by the context manager instead of lingering until GC.
with Image.open(example_path) as example_img:
    image = example_img.convert("RGB")

# Query embedding used for the FAISS search in the following cells.
img_emb = siglip_emb.embed_image(image)

In [None]:
# Load the FAISS index built earlier in the notebook.
# The original referenced an undefined name `directory`, which raised NameError
# on a fresh kernel. The index builders only receive `image_directory`, so the
# index is presumed to be written there.
# NOTE(review): confirm the output location against embeddings_helper.
faiss_index_dir = image_directory
faiss_index_path = os.path.join(faiss_index_dir, "faiss_siglip2_base.index")
if not os.path.isfile(faiss_index_path):
    raise FileNotFoundError(
        f"FAISS index not found at {faiss_index_path!r}; run the embedding cells "
        "first or point faiss_index_dir at the directory the helper writes to."
    )
faiss_index = faiss.read_index(faiss_index_path)

In [None]:
k = 5  # number of nearest neighbours to retrieve (top-5)

# faiss `search` returns (distances, ids) with one row per query.
# img_emb is assumed to hold a single query vector, so row 0 is its result.
# TODO confirm embed_image returns a (1, d) float32 numpy array.
dist, idx = faiss_index.search(img_emb, k)
nearest_ids = idx[0]
print(nearest_ids)

In [None]:
# IPython's rich-display function. Import it explicitly: a top-level
# `import display` binds a module of that name, not this callable.
from IPython.display import display

# Show each retrieved neighbour. Cropped images are stored as
# "<image_id>_<item_id>.jpg" under the cropped-image directory.
# NOTE(review): FAISS positions are mapped to master_df rows with iloc; this is
# only correct if the index was built in master_df row order — confirm against
# the helper (the image index may follow `all_images` / `all_ids` order instead).
image_directory = r"D:\LLM_Project\Multimodel Chatbot\Data\train\cropped_image_unique"

for id_ in idx[0]:
    # FAISS pads missing neighbours with -1; iloc[-1] would silently show the last row.
    if id_ < 0:
        continue

    row = master_df.iloc[id_]
    image_id = row['image_id']
    item_id = row['item_id']

    retreival_path = os.path.join(image_directory, f"{image_id}_{item_id}.jpg")
    # convert() loads the pixels, so the file can be closed right away.
    with Image.open(retreival_path) as retrieved_img:
        image = retrieved_img.convert("RGB")
    display(image)