In [1]:
import numpy as np
import faiss
import pickle
from pathlib import Path
from sentence_transformers import SentenceTransformer, CrossEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
INDEX_DIR = Path("./indices/basic_all-MiniLM-L6-v2")

# Report which on-disk index directory is about to be read.
print(f"Loading index from {INDEX_DIR}")

# Vector index; the path is handed to FAISS as a plain string.
index = faiss.read_index(str(INDEX_DIR / "index.faiss"))

# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so INDEX_DIR must only ever point at index directories you built or trust.

# Product IDs, looked up by the positions the FAISS search returns.
with (INDEX_DIR / "product_ids.pkl").open('rb') as f:
    product_ids = pickle.load(f)

# Product texts, used later as the document side of the reranker input.
with (INDEX_DIR / "product_texts.pkl").open('rb') as f:
    product_texts = pickle.load(f)

print(f"✓ Loaded index with {index.ntotal} vectors")
print(f"✓ Loaded {len(product_texts)} product texts")

Loading index from indices\basic_all-MiniLM-L6-v2
✓ Loaded index with 482105 vectors
✓ Loaded 482105 product texts


In [3]:
def get_embedding_function(model='sentence-transformers/all-MiniLM-L6-v2'):
    """Load a SentenceTransformer once and return a closure that embeds text.

    Args:
        model: Model name or path handed to ``SentenceTransformer``.

    Returns:
        A function ``embed(texts)`` that encodes ``texts`` with normalized
        embeddings and the progress bar disabled.
    """
    st_model = SentenceTransformer(model)

    def embed(texts):
        vectors = st_model.encode(
            texts,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return vectors

    return embed

# Build the query embedder once at notebook level; search() calls `embed` for every query.
embed = get_embedding_function()
print("✓ Embedding function ready")

✓ Embedding function ready


In [4]:
def search_index(query_embedding, k=10):
    """Search the pre-built FAISS index for the nearest products.

    Args:
        query_embedding: Query vector(s). Either a single 1-D vector or a
            2-D batch; only the first row's results are returned.
        k: Number of neighbours to request from the index.

    Returns:
        A ``(scores, results)`` tuple where ``results`` is a list of product
        IDs and ``scores`` is the array of matching index scores. Both have
        the same length and order; that length can be less than ``k`` when
        the index cannot supply ``k`` valid neighbours.
    """
    # Work on a private float32 copy: normalize_L2 modifies its argument in place.
    query_embedding = np.array(query_embedding, dtype='float32', order='C')
    if query_embedding.ndim == 1:
        # index.search expects a 2-D batch, so promote a lone vector to one row.
        query_embedding = query_embedding.reshape(1, -1)
    faiss.normalize_L2(query_embedding)

    # Search with IVF index
    index.nprobe = 32
    scores, indices = index.search(query_embedding, k)

    # FAISS fills missing neighbours with label -1. A Python list would read
    # that as "last element", so negative labels must be rejected explicitly.
    row = indices[0]
    valid = (row >= 0) & (row < len(product_ids))

    # Convert positions to product IDs, keeping scores aligned with them.
    results = [product_ids[i] for i in row[valid]]

    return scores[0][valid], results

# Smoke test: embed one throwaway query and confirm the index search returns results.
test_embedding = embed(["test query"])
scores, results = search_index(test_embedding, k=5)
print(f"✓ Index search working: found {len(results)} results")

✓ Index search working: found 5 results


In [5]:
def get_reranker(model='cross-encoder/ms-marco-MiniLM-L-6-v2'):
    """Load a CrossEncoder once and return a closure that reranks candidates.

    Args:
        model: Model name or path handed to ``CrossEncoder``.

    Returns:
        A function ``rerank(query, candidate_ids, top_k=10)`` that returns up
        to ``top_k`` ``(product_id, score)`` pairs, highest score first.
    """
    scorer = CrossEncoder(model)

    def rerank(query, candidate_ids, top_k=10):
        # No candidates means nothing to score; skip the model call.
        if not candidate_ids:
            return []

        # One [query, product text] pair per candidate, in candidate order.
        pairs = []
        for pid in candidate_ids:
            pairs.append([query, product_texts[pid]])
        relevance = scorer.predict(pairs, show_progress_bar=False)

        # Order best-first, then keep only the leading top_k entries.
        ranked = sorted(
            zip(candidate_ids, relevance),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return ranked[:top_k]

    return rerank

# Build the cross-encoder reranker once at notebook level; search() calls `rerank` when reranking is enabled.
rerank = get_reranker()
print("✓ Reranker ready")

✓ Reranker ready


In [6]:
def search(query, k=10, use_reranker=True, rerank_candidates=50):
    """
    Complete semantic search

    1. Embed query
    2. Search index
    3. Rerank results

    Args:
        query: Query string.
        k: Number of results to return.
        use_reranker: When True, retrieve a larger candidate pool from the
            index and reorder it with the cross-encoder.
        rerank_candidates: Minimum size of the candidate pool fetched from
            the index when reranking (defaults to the previous fixed 50).

    Returns:
        A list of up to ``k`` ``(product_id, score)`` pairs. Scores come from
        the reranker when it ran, otherwise from the index search.
    """
    # Embed
    query_embedding = embed([query])

    # Search (get more candidates if reranking). The pool must never be
    # smaller than k, otherwise a request for k > rerank_candidates results
    # would be silently capped at the pool size.
    search_k = max(k, rerank_candidates) if use_reranker else k
    scores, candidate_ids = search_index(query_embedding, k=search_k)

    # Rerank
    if use_reranker and candidate_ids:
        return rerank(query, candidate_ids, top_k=k)

    # No reranking (or nothing to rerank): pair IDs with their index scores.
    return list(zip(candidate_ids[:k], scores[:k]))

In [7]:
# Sample queries used to eyeball the end-to-end search output.
test_queries = [
    "wireless bluetooth headphones",
    "stainless steel kitchen sink",
    "laptop computer dell",
    "running shoes nike",
    "coffee maker espresso",
]

print("Testing search function:\n")

for query in test_queries:
    print(f"Query: '{query}'")
    results = search(query, k=5, use_reranker=True)

    # Ranks are shown 1-based; each product text is cut to its first 80 characters.
    for i, (pid, score) in enumerate(results, start=1):
        text = f"{product_texts.get(pid, '')[:80]}..."
        print(f"  {i}. [{score:.2f}] [{pid}] {text}")
    print()

Testing search function:

Query: 'wireless bluetooth headphones'
  1. [8.71] [B07KLZQKL7] Wireless Earbuds, Tepoinn Bluetooth 5.0 True Wireless Bluetooth Headphones with ...
  2. [8.57] [B07Q8G7K48] Wireless Bluetooth Headphones Over-Ear with Deep Bass, Foldable Wireless and Wir...
  3. [8.34] [B07WSKKYPR] Sony Wireless Headphones WH-CH510: Wireless Bluetooth On-Ear Headset with Mic fo...
  4. [8.28] [B08RHMD51S] Bluetooth Headphones Wireless,TUINYO Over Ear Stereo Wireless Headset 40H Playti...
  5. [8.25] [B084YPMMYX] Wireless Earbuds, 5.0 True Wireless Bluetooth Headphones 3D Stereo Sound Wireles...

Query: 'stainless steel kitchen sink'
  1. [8.44] [B096S24CX3] hongyang Commercial Kitchen Sink Stainless Steel Single Bowl Sinks for Outdoor I...
  2. [8.09] [B0947559RS] Giantex Stainless Steel Utility Sink, 304 Commercial Sink w/ Backsplash, Drain S...
  3. [7.95] [B006ZTELZ0] Kraus KTM33 33 inch Topmount 50/50 Double Bowl 18 gauge Stainless Steel Kitchen ...
  4. [7.85] [B0032C4126]

In [8]:
def quick_eval(num_queries=50, k=10, data_path="./data/shopping_queries_dataset_examples.parquet"):
    """Quick Hits@k evaluation on a random sample of test queries.

    A query counts as a hit when at least one of its top-``k`` search results
    is labelled 'E' or 'S' for that query in the dataset. Queries with no
    such labelled product are skipped.

    Args:
        num_queries: Maximum number of distinct test queries to sample.
        k: Cutoff used both for the search call and for the hit check.
        data_path: Parquet file holding the query/product examples.

    Prints the hit rate and the number of queries it was computed on;
    returns nothing.
    """
    import pandas as pd
    from tqdm import tqdm

    # Load test data
    df = pd.read_parquet(data_path)
    df = df[(df['split'] == 'test') & (df['small_version'] == 1) & (df['product_locale'] == 'us')]

    # Get unique test queries (fixed seed keeps the sample reproducible)
    test_queries = df.groupby('query_id').first().reset_index()
    test_queries = test_queries.sample(n=min(num_queries, len(test_queries)), random_state=42)

    # Collect the relevant product set of every query in a single pass,
    # instead of re-filtering the whole frame once per sampled query.
    labelled = df[df['esci_label'].isin(['E', 'S'])]
    relevant_by_query = {}
    for qid, pid in zip(labelled['query_id'], labelled['product_id']):
        relevant_by_query.setdefault(qid, set()).add(pid)

    hits = []

    for _, row in tqdm(test_queries.iterrows(), total=len(test_queries)):
        relevant = relevant_by_query.get(row['query_id'])

        if not relevant:
            continue

        # Search
        results = search(row['query'], k=k, use_reranker=True)
        predicted = [pid for pid, _ in results]

        # Check if we got a hit
        hit = any(p in relevant for p in predicted[:k])
        hits.append(1.0 if hit else 0.0)

    # np.mean([]) warns and yields nan; report nan directly when nothing was scored.
    hit_rate = np.mean(hits) if hits else float('nan')
    print(f"\nHits@{k}: {hit_rate:.3f} (on {len(hits)} queries)")

# Run the quick evaluation on a 50-query sample; it runs the reranked search on every sampled query.
quick_eval(num_queries=50)

100%|██████████| 50/50 [00:06<00:00,  7.56it/s]


Hits@10: 0.860 (on 50 queries)



