In [1]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
import faiss
import os
from tqdm import tqdm
import sys
sys.path.append('../src/backend')
from embedding import generate_embedding
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_chunked_embeddings(embeddings_dir):
    """
    Load embeddings from chunked NPZ files.

    Reads every ``fusion_embeddings_chunk_*.npz`` file in ``embeddings_dir``.
    Each file must hold an ``embeddings`` array and a ``product_ids`` array
    with one id per embedding row. The chunks are stacked in chunk order.

    Args:
        embeddings_dir: Path to directory containing NPZ files

    Returns:
        tuple: (embeddings array, product_ids array). Row i of the embeddings
        belongs to product_ids[i].

    Raises:
        ValueError: if no matching chunk file is found, or a chunk's
            embeddings and product_ids have different lengths.
    """
    prefix, suffix = 'fusion_embeddings_chunk_', '.npz'

    def chunk_sort_key(name):
        # Order by the numeric chunk suffix so ..._chunk_10 follows ..._chunk_9.
        # Plain lexicographic order would put it right after ..._chunk_1.
        # Positions in the result are used as FAISS ids downstream, so the order
        # should match how the index was populated (presumably numeric chunk
        # order; TODO confirm). Non-integer suffixes fall back to lexicographic
        # order, after the numeric ones.
        stem = name[len(prefix):-len(suffix)]
        return (0, int(stem), name) if stem.isdecimal() else (1, 0, name)

    npz_files = sorted(
        (f for f in os.listdir(embeddings_dir)
         if f.startswith(prefix) and f.endswith(suffix)),
        key=chunk_sort_key,
    )
    if not npz_files:
        raise ValueError(f"No '{prefix}*{suffix}' files found in {embeddings_dir!r}")

    all_embeddings = []
    all_pids = []

    print("Loading embeddings from chunks...")
    for npz_file in tqdm(npz_files):
        # NpzFile keeps the archive open until closed. The context manager
        # releases it once both arrays have been read into memory.
        with np.load(os.path.join(embeddings_dir, npz_file)) as data:
            chunk_embeddings = data['embeddings']
            chunk_pids = data['product_ids']
        if len(chunk_embeddings) != len(chunk_pids):
            raise ValueError(
                f"{npz_file}: {len(chunk_embeddings)} embeddings but "
                f"{len(chunk_pids)} product ids"
            )
        all_embeddings.append(chunk_embeddings)
        all_pids.append(chunk_pids)

    # Concatenate all chunks. This is far cheaper than extending a Python list
    # with millions of numpy scalars and converting it back to an array.
    return np.vstack(all_embeddings), np.concatenate(all_pids)

In [7]:
def visualize_query_clusters(index: faiss.IndexIVFFlat,
                           embeddings: np.ndarray,
                           query_text: str,
                           pids: np.ndarray,
                           df: pd.DataFrame,
                           nprobe: int = 1,
                           top_k: int = 5):
    """
    Visualize the nearest neighbours of a text query, coloured by IVF cluster.

    Steps:
    1. Embed the query.
    2. Search the index while probing ``nprobe`` coarse centroids and keep
       the ``top_k * nprobe`` nearest hits.
    3. Project the hits and the query to 2-D with PCA.
    4. Plot them with Plotly, colouring each hit by the IVF cluster its
       vector is assigned to.

    Note that this is the global top ``top_k * nprobe``, not ``top_k`` per
    cluster. A single visited cluster may supply most of the hits.

    Args:
        index: FAISS IndexIVFFlat (trained + populated). Its ``nprobe`` is set
            to ``nprobe`` for this search only and restored afterwards.
        embeddings: (N, D) L2-normalized array of all your item vectors. FAISS
            id i is looked up as ``embeddings[i]``. This assumes the index was
            populated with ``add`` in this same row order (TODO confirm).
        query_text: the text to embed and query.
        pids: (N,) array of product IDs, aligned with embeddings.
        df: DataFrame with at least ['Pid','Name'] columns.
        nprobe: how many IVF centroids to visit.
        top_k: hit budget per visited cluster. ``top_k * nprobe`` hits are
            requested in total.

    Returns:
        A plotly Figure, or None if the search returned no hits.
    """
    # 1) Embed the query as a (1, D) float32 row, the layout FAISS expects.
    query = np.asarray(generate_embedding(query_text), dtype='float32').reshape(1, -1)
    query_vec = query[0]

    # 2) Search with the requested nprobe. Restore the caller's setting so this
    #    helper does not silently change later searches on the same index.
    saved_nprobe = index.nprobe
    index.nprobe = nprobe
    try:
        _, indices = index.search(query, top_k * nprobe)
    finally:
        index.nprobe = saved_nprobe

    # FAISS pads with id -1 when the probed lists hold fewer than k vectors.
    # embeddings[-1] would silently fetch the last row, so drop the padding.
    hit_ids = indices[0][indices[0] >= 0]
    if hit_ids.size == 0:
        print("🚫 No points found in the visited clusters!")
        return None

    # 3) Recover each hit's IVF list by asking the coarse quantizer for the
    #    nearest centroid. The quantizer needs contiguous float32 input.
    hit_vectors = np.ascontiguousarray(embeddings[hit_ids], dtype='float32')
    _, assigned = index.quantizer.search(hit_vectors, 1)

    # 4) Build the result frame. 'similarity' is the dot product with the
    #    query (cosine similarity when both are normalized): larger = closer.
    result_df = pd.DataFrame({
        'pid': pids[hit_ids],
        'cluster': [f"Cluster {c}" for c in assigned.ravel()],
        'similarity': (hit_vectors @ query_vec).astype(float),
    })

    # Look up names only for the hits instead of building a dict over all of
    # df. keep='last' matches dict semantics for duplicate Pids.
    names = (df.loc[df['Pid'].isin(result_df['pid']), ['Pid', 'Name']]
               .drop_duplicates('Pid', keep='last')
               .set_index('Pid')['Name'])
    result_df['name'] = result_df['pid'].map(names)

    # 5) PCA fitted on just the hits plus the query (query is the last row).
    coords = PCA(n_components=2).fit_transform(np.vstack([hit_vectors, query_vec]))
    result_df[['x', 'y']] = coords[:-1]
    qx, qy = coords[-1]

    # Plotly rejects negative marker sizes. Hits with non-positive similarity
    # are drawn with size 0.
    result_df['marker_size'] = result_df['similarity'].clip(lower=0)

    # 6) Plot with Plotly
    fig = px.scatter(
        result_df,
        x='x', y='y',
        color='cluster',
        size='marker_size',
        hover_name='name',
        title=f"Top {len(result_df)} results from {nprobe} probed clusters for {query_text}",
        labels={'x': 'PCA 1', 'y': 'PCA 2'},
        opacity=0.75,
        # Hide plot coordinates and the derived size. Show the similarity instead.
        hover_data={'x': False, 'y': False, 'marker_size': False, 'similarity': ':.3f'},
    )
    fig.add_trace(go.Scatter(
        x=[qx], y=[qy],
        mode='markers',
        marker=dict(symbol='diamond', size=15, color='red'),
        name='Query',
        hovertemplate=query_text + '<extra></extra>'
    ))
    fig.update_layout(showlegend=True)
    return fig

In [8]:
INDEX_PATH = '../data/faiss_indices/fusion_index_nlist4000.faiss'
EMBEDDINGS_DIR = '../data/embeddings'
PRODUCTS_CSV = '../data/csv/sample_1M.csv'

index = faiss.read_index(INDEX_PATH)
embeddings, pids = load_chunked_embeddings(EMBEDDINGS_DIR)
df = pd.read_csv(PRODUCTS_CSV)

# visualize_query_clusters uses FAISS ids as row positions into `embeddings`
# and `pids`. Both must be aligned and cover every id the index can return.
assert len(embeddings) == len(pids), (
    f"{len(embeddings)} embeddings but {len(pids)} product ids were loaded"
)
assert index.ntotal <= len(embeddings), (
    f"index holds {index.ntotal} vectors but only {len(embeddings)} embeddings were loaded"
)

fig = visualize_query_clusters(
    index,
    embeddings,
    "comfortable shoes for running",
    pids,
    df,
    nprobe=3,
    top_k=5,
)

# The helper returns None when the search finds no hits. Guard so the cell
# does not fail with AttributeError in that case.
if fig is not None:
    fig.show()

Loading embeddings from chunks...


100%|██████████| 2/2 [00:02<00:00,  1.11s/it]
