# 05. Similarity Search with FAISS

This notebook brings everything together to build a similarity search engine using FAISS. We will:
- Load the final dataset with multimodal embeddings.
- Prepare the text and image embeddings and fuse them into a single vector for each item.
- Build a FAISS index for efficient similarity search.
- Perform a search using an existing item as a query.
- Perform a search using a text query.
- Create a simple Gradio interface to demonstrate the search functionality.

## 1. Setup and Data Loading

In [2]:
import sys
import os
import pandas as pd
import numpy as np
import faiss
import gradio as gr
import torch
import gc

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))

from src.embedding_utils import get_text_model, prepare_multimodal_embeddings
from src.faiss_utils import (
    build_faiss_index, 
    search_faiss_index, 
    create_multimodal_query, 
    get_search_results,
    save_faiss_index,
    load_faiss_index
)

# --- Configuration ---
CATEGORY = "CDs_and_Vinyl"
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname('__file__'), '..'))
FULL_DATA_FILE = os.path.join(DATA_DIR, f"reviews_with_img_text_emb_{CATEGORY}.parquet")
FAISS_INDEX_FILE = os.path.join(DATA_DIR, f"faiss_index_{CATEGORY}.bin")
ASIN_MAP_FILE = os.path.join(DATA_DIR, f"asin_to_idx_{CATEGORY}.json")


# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load data
print(f"Loading data from {FULL_DATA_FILE}...")
# To save memory, load only the columns we absolutely need
columns_to_load = ['parent_asin', 'title', 'clip_img_emb', 'image_url']
df_full = pd.read_parquet(FULL_DATA_FILE, columns=columns_to_load, engine='fastparquet')

# Create a copy to avoid fragmentation issues and ensure contiguous memory
df = df_full.copy()
del df_full
gc.collect() # Force garbage collection to free up memory

# Ensure all embeddings are numpy arrays
df['img_embedding'] = df['img_embedding'].apply(np.array)
df['fused_embedding'] = df['fused_embedding'].apply(np.array)
print(f"Loaded {len(df)} rows.")

# Load text model
text_model = get_text_model(device=device)

Using device: cuda
Loading data from c:\Users\minhk\OneDrive\Documents\HCMUTSUB\DACN\reviews_with_img_text_emb_CDs_and_Vinyl.parquet...


ValueError: Following columns were requested but are not available: {'title', 'image_url'}.
All requested columns: ['parent_asin', 'title', 'clip_img_emb', 'image_url']
Available columns: ['asin', 'parent_asin', 'user_id', 'rating', 'title_x', 'text', 'timestamp', 'helpful_vote', 'verified_purchase', 'images_x', 'main_category', 'title_y', 'average_rating', 'rating_number', 'features', 'description', 'price', 'categories', 'store', 'img_url', 'clip_img_emb', 'has_img_emb', 'review_text', 'meta_text', 'review_text_emb', 'meta_text_emb', 'text_emb', 'images_y.hi_res', 'images_y.large', 'images_y.thumb', 'images_y.variant']

## 2. Prepare Embeddings and Build FAISS Index

We'll extract the text and image embeddings, fuse them into a single multimodal embedding, and then build a FAISS index for fast similarity search.

In [None]:
# Prepare embeddings for FAISS
multimodal_embeddings, asin_to_idx = prepare_multimodal_embeddings(
    df, 
    text_emb_col='fused_embedding', 
    img_emb_col='img_embedding'
)

# Build and save FAISS index
faiss_index = build_faiss_index(multimodal_embeddings)
save_faiss_index(faiss_index, asin_to_idx, FAISS_INDEX_FILE, ASIN_MAP_FILE)

print(f"FAISS index built and saved for category '{CATEGORY}'.")
print(f"Index size: {faiss_index.ntotal} vectors")

TypeError: prepare_multimodal_embeddings() got an unexpected keyword argument 'text_embedding_col'

## 3. Search: Find Similar Items

Let's test the index by picking a random item and finding its nearest neighbors.

In [6]:
# Load the index and mapping
faiss_index, asin_to_idx = load_faiss_index(f'faiss_index_{CATEGORY}.bin', f'asin_to_idx_{CATEGORY}.json')
idx_to_asin = {v: k for k, v in asin_to_idx.items()}

# Pick a random item to use as a query
random_item_asin = df.sample(1)['parent_asin'].values[0]
query_vector_index = asin_to_idx[random_item_asin]
query_vector = faiss_index.reconstruct(query_vector_index)

# Search for similar items
distances, indices = search_faiss_index(faiss_index, query_vector.reshape(1, -1), k=5)

# Display results
print(f"Query Item: {random_item_asin}")
get_search_results(distances, indices, df, idx_to_asin, random_item_asin)

TypeError: load_faiss_index() takes 1 positional argument but 2 were given

## 4. Search: Using a Text Query

Now, let's perform a search using a natural language query. We'll generate a multimodal query vector from the text and use it to search the FAISS index.

In [None]:
# Define a text query
text_query = "a comfortable and stylish pair of shoes for everyday wear"

# Create a multimodal query vector
query_vector = create_multimodal_query(text_query, text_model, device)

# Search the index
distances, indices = search_faiss_index(faiss_index, query_vector.reshape(1, -1), k=5)

# Display results
print(f"Query Text: '{text_query}'")
get_search_results(distances, indices, df, idx_to_asin)

## 5. Gradio Demo

Finally, let's wrap our search functionality in a simple Gradio interface to create an interactive demo.

In [None]:
import gradio as gr
from PIL import Image
import requests
from io import BytesIO

def search_and_display(query_text, search_by, query_item_asin=None):
    if search_by == "Text":
        query_vector = create_multimodal_query(query_text, text_model, device)
    elif search_by == "Item ASIN":
        if not query_item_asin or query_item_asin not in asin_to_idx:
            return "Invalid or missing ASIN.", [], []
        query_vector_index = asin_to_idx[query_item_asin]
        query_vector = faiss_index.reconstruct(query_vector_index)
    else:
        return "Invalid search type.", [], []

    distances, indices = search_faiss_index(faiss_index, query_vector.reshape(1, -1), k=5)
    
    results_df = get_search_results(distances, indices, df, idx_to_asin, query_item_asin if search_by == "Item ASIN" else None, return_df=True)
    
    if results_df.empty:
        return "No results found.", [], []

    # Prepare output for Gradio
    output_text = results_df.to_string(index=False)
    
    image_gallery = []
    for url in results_df['image_url']:
        try:
            response = requests.get(url)
            img = Image.open(BytesIO(response.content))
            image_gallery.append(img)
        except Exception as e:
            print(f"Could not load image {url}: {e}")
            # Add a placeholder image if loading fails
            image_gallery.append(Image.new('RGB', (100, 100), color = 'gray'))

    return output_text, image_gallery

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multimodal Product Search")
    gr.Markdown(f"Search for products in the '{CATEGORY}' category.")
    
    with gr.Row():
        with gr.Column():
            search_by = gr.Radio(["Text", "Item ASIN"], label="Search By", value="Text")
            text_query_input = gr.Textbox(label="Text Query", placeholder="e.g., a comfortable and stylish pair of shoes")
            asin_query_input = gr.Textbox(label="Item ASIN", placeholder="e.g., B07816F551")
            search_button = gr.Button("Search")
        
        with gr.Column():
            results_output = gr.Textbox(label="Search Results", lines=10)
            gallery_output = gr.Gallery(label="Product Images", columns=5, height="auto")

    search_button.click(
        fn=search_and_display,
        inputs=[text_query_input, search_by, asin_query_input],
        outputs=[results_output, gallery_output]
    )

demo.launch(debug=True, share=False)