In [1]:
import pandas as pd
import numpy as np
import time

from sklearn.cluster import KMeans
from numpy.linalg import norm
from scipy.spatial.distance import euclidean
from scipy.linalg import inv

In [2]:
# Make sure this path points to your downloaded and unzipped GloVe file.
GLOVE_FILE_PATH = '/Users/haneulkim/Desktop/data/wiki_giga_2024_50_MFT20_vectors_seed_123_alpha_0.75_eta_0.075_combined.txt'

def load_glove_embeddings(file_path):
    """
    Loads GloVe word embeddings from a text file.
    GloVe is an unsupervised learning algorithm for obtaining vector
    representations for words [5, 6].
    """
    words, vectors = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word, embs = values[0], values[1:]
            try:
                vector = np.asarray(embs, "float32")
                words.append(word)
                vectors.append(vector)
            except ValueError:
                continue
    embeddings_df = pd.DataFrame({"word": words, "vector": vectors})
    return embeddings_df

embeddings_df = load_glove_embeddings(GLOVE_FILE_PATH)
embeddings_df['dimensions'] = embeddings_df['vector'].apply(lambda x: x.shape[0])
embeddings_df = embeddings_df[embeddings_df['dimensions'] == 50].copy()

In [27]:
n_items = 10_000
n_subvectors = 10    
n_clusters = 8       # K: Number of centroids per subspace (4-bit quantization: 2^4 = 0-15 integer points)
eta = 2.04           # Anisotropic weighting factor

item_embeddings = np.stack(embeddings_df.iloc[:n_items]['vector'].values)
item_embeddings /= norm(item_embeddings, axis=1, keepdims=True)
embedding_dim = item_embeddings.shape[1]

# Calculate the dimension of each subvector
subvector_dim = embedding_dim // n_subvectors

print("subvector_dim: ", subvector_dim)
print(item_embeddings.shape)

subvector_dim:  5
(10000, 50)


In [28]:
# The ScaNN paper notes that for MIPS tasks equivalent to cosine similarity,
# data is unit-normalized. This is a common practice [9].
item_embeddings = np.stack(embeddings_df.iloc[:n_items]['vector'].values)
item_embeddings /= norm(item_embeddings, axis=1, keepdims=True)
print(item_embeddings.shape)

(10000, 50)


How we learn the codebooks(finding optimal cluster) differentiates various quantization methods.

In [29]:
subspaces = [[] for j in range(n_subvectors)]
for v in item_embeddings:
    for j in range(n_subvectors):
        subvector = v[j * subvector_dim:(j + 1) * subvector_dim]
        subspaces[j].append(subvector)
subspaces = [np.array(s) for s in subspaces]

# 1. Quantization

In [30]:
assert (np.append(subspaces[0][0], subspaces[1][0]) == item_embeddings[0][:subvector_dim*2]).all

In [31]:
def encode(vector, codebooks):
    codes = []
    for j in range(n_subvectors):
        sub_vec = vector[j * subvector_dim:(j + 1) * subvector_dim]
        codebook_subspace_j = codebooks[j] 
        # find the nearest centroid for each subvector using euclidean distance.
        dist = np.sum((codebook_subspace_j - sub_vec) ** 2, axis=1)
        codes.append(np.argmin(dist)) # assigning index of the nearest centroid.
    return np.array(codes)

def decode(q_vector, codebooks):
    """
    Note, sum of local errors is same as concatenating then computing distance.
    """
    reconstructed_subvecs = []
    for j in range(n_subvectors):
        # get centroid of cluster(code) that q_vector was assigned to.
        centroid = codebooks[j][q_vector[j]]
        reconstructed_subvecs.append(centroid)
    return np.concatenate(reconstructed_subvecs)

## 1.1. Standard Product Quantization

In [32]:
# Training: Standard PQ: Isotropic loss function (a.k.a. Euclidean distance loss)
# codebook_size = number of clusters
codebooks = [] 
verbose=False
print(f"cluster each subspace of {subvector_dim}-dimensions into {n_clusters} clusters")
for j, subspace in enumerate(subspaces):
    kmeans = KMeans(n_clusters=n_clusters, 
        init='k-means++', n_init=1, max_iter=300, verbose=verbose).fit(subspace)
    codebooks.append(kmeans.cluster_centers_)


print(codebooks[0].shape)

print(f"number of kmeans iterations: {kmeans.n_iter_}")
# so for each (10000, 16) subspace we have 10 clusters.

cluster each subspace of 5-dimensions into 8 clusters
(8, 5)
number of kmeans iterations: 33


In [33]:
q_vectors = []
for v in item_embeddings:
    q_vec = encode(v, codebooks)
    q_vectors.append(q_vec)
q_vectors = np.array(q_vectors, dtype=np.int8)

reconstructed_vectors = []  
for q_vec in q_vectors:
    reconstructed_vector = decode(q_vec, codebooks)
    reconstructed_vectors.append(reconstructed_vector)
reconstructed_vectors = np.array(reconstructed_vectors)
reconstructed_vectors.shape

(10000, 50)

## 1.2. Quantization w anisotropic loss function

In [34]:
def anisotropic_loss(residuals, x, eta=4.125, verbose=False):
    x_norm_sq = np.sum(x**2, axis=-1, keepdims=True)
    # Avoid division by zero for zero-norm vectors
      # our numerator is inner_prod which will be 0 if x is 0 vector. so when we divide
      # by x_norm_sq 0 vector should get 0. In order to do that, we set x_norm_sq to 1 when it is 0.
    x_norm_sq[x_norm_sq == 0] = 1.0

    # a1b2 + a2b2+ ... + anbn
    # (N, k, d) * (N, 1, d) -> (N, k, d) => for all N rows, d is projection of residuals onto x.
    inner_prod = np.sum(residuals * x, axis=-1, keepdims=True)
    if verbose:
        print("residuals * x.shape: ", (residuals * x).shape)
        print("inner_prod.shape: ", inner_prod.shape)

    r_parallel = inner_prod / x_norm_sq * x
    r_orthogonal = residuals - r_parallel
    if verbose:
        print("r_parallel.shape: ", r_parallel.shape)
        print("r_orthogonal.shape: ", r_orthogonal.shape)

    r_parallel_sq_norm = np.sum(r_parallel**2, axis=-1, keepdims=True)
    r_orthogonal_sq_norm = np.sum(r_orthogonal**2, axis=-1, keepdims=True)
    if verbose:
        print("r_parallel_sq_norm.shape: ", r_parallel_sq_norm.shape)
        print("r_orthogonal_sq_norm.shape: ", r_orthogonal_sq_norm.shape)
    # --- Step 4: Combine into the final anisotropic loss ---
    # This is the final loss function from the paper [3, 4].
    # ` ∝ η * ||r‖(xi, x̃i)||² + ||r⊥(xi, x̃i)||²
    # Paper defines eta as ratio of parallel and orthogonal components (eta = h‖ / h⊥). section 3.2.
     # This allows us to avoid computing complex integrals.
     # Refer to my note.
    return eta * r_parallel_sq_norm + r_orthogonal_sq_norm

In [35]:
# Training: PQ with anisotropic loss function
n_iter = 100 # for training with anisotropic loss function
scann_codebooks = np.zeros((n_subvectors, n_clusters, subvector_dim), dtype=np.float32)

loss_tracking = []
verbose = True
for j, subspace in enumerate(subspaces):
    centroids = np.array([subspace[np.random.choice(n_items)] for _ in range(n_clusters)]) 
    # # Kmeans : 
    # # "We note that we can also optionally initialize the codebook by first training the codebook under 
    # # regular `2-reconstruction loss, which speed up training process"
    # kmeans = KMeans(n_clusters=n_clusters, 
    #     init='k-means++', n_init=1, max_iter=300).fit(subspace)
    # centroids = kmeans.cluster_centers_
    print(f"--Training subspace {j} --")
    for i in range(n_iter):
        # Step 2: Partition Assignment.
        residuals = subspace[:, np.newaxis, :] - centroids
        losses = anisotropic_loss(residuals, subspace[:, np.newaxis, :], eta=eta, verbose=False)
        assignments = np.argmin(losses, axis=1).flatten()

        # losses(10000, 8, 1). losses[0,0,0] represent anisotropic loss for 0th subvector to 0th centroid. losses[0, 1, 0] represent anisotropic loss for 0th subvector to 1st centroid, ...
        # so np.min() => for each subvector, find the minimum anisotropic loss (centroid)
        # np.mean() => average over all items
        current_loss = np.mean(np.min(losses, axis=1))
        loss_tracking.append({
            "eta": eta,
            "subspace": j,
            "iteration": i,
            "loss": current_loss
        })
        if verbose:
            print(f"Iter {i+1}/{n_iter} | Avg. Anisotropic Loss: {current_loss:.6f}")
        
        # Step 3: Codebook update.
        new_centroids = np.zeros(centroids.shape)
        # The goal is to find the single new position for the centroid cj 
        # that minimizes the total anisotropic loss for all the points in its cluster, Xj
        for k in range(n_clusters):
            assigned_points = subspace[assignments == k]
            if len(assigned_points) == 0:
                # If a cluster is empty, re-initialize its centroid to a random point
                new_centroids[k] = subspace[np.random.choice(n_items)]
                continue
            
            # === Based on Theorem 4.2, with h_par = eta and h_orth = 1. ===
            # Since our vectors are unit-norm, ||x||^2 is constant for the full vector,
            # but not necessarily for the subvectors.

            # This creates an identity matrix I and scales it by the subvector_dimension
            A = len(assigned_points) * np.identity(subvector_dim, dtype=np.float32)
            xxt_sum = np.zeros((subvector_dim, subvector_dim), dtype=np.float32)
            for x in assigned_points:
                x_norm_sq = np.sum(x**2)
                xxt_sum += ((eta - 1) / x_norm_sq) * np.outer(x, x)
            
            # Right side of the equation: b = sum(eta*x)
            b = eta * np.sum(assigned_points, axis=0)
            
            # simplified equation from theorem 4.2.
            # LinAlgError: singular matrix
            try:
                new_centroids[k] = inv(A + xxt_sum) @ b
            except np.linalg.LinAlgError:
                # If matrix is singular, just keep the old centroid
                new_centroids[k] = centroids[k]

        # returns True if all elements are equal with a tolerance of 1e-5
        if np.allclose(centroids, new_centroids, atol=1e-5):
            print(f"subspace {j} - Converged after {i+1} iterations.")
            break
        centroids = new_centroids
    # verbose = False
    scann_codebooks[j] = centroids

# After training, create a DataFrame to track all loss for each subspace
loss_df = pd.DataFrame(loss_tracking)
# Each row contains: subspace, iteration, all_losses (flattened), min_loss_per_item, avg_min_loss
losses.shape


--Training subspace 0 --
Iter 1/100 | Avg. Anisotropic Loss: 0.080056


Iter 2/100 | Avg. Anisotropic Loss: 0.060392
Iter 3/100 | Avg. Anisotropic Loss: 0.055772
Iter 4/100 | Avg. Anisotropic Loss: 0.054222
Iter 5/100 | Avg. Anisotropic Loss: 0.053536
Iter 6/100 | Avg. Anisotropic Loss: 0.052996
Iter 7/100 | Avg. Anisotropic Loss: 0.052434
Iter 8/100 | Avg. Anisotropic Loss: 0.051923
Iter 9/100 | Avg. Anisotropic Loss: 0.051473
Iter 10/100 | Avg. Anisotropic Loss: 0.051070
Iter 11/100 | Avg. Anisotropic Loss: 0.050708
Iter 12/100 | Avg. Anisotropic Loss: 0.050469
Iter 13/100 | Avg. Anisotropic Loss: 0.050314
Iter 14/100 | Avg. Anisotropic Loss: 0.050216
Iter 15/100 | Avg. Anisotropic Loss: 0.050152
Iter 16/100 | Avg. Anisotropic Loss: 0.050099
Iter 17/100 | Avg. Anisotropic Loss: 0.050065
Iter 18/100 | Avg. Anisotropic Loss: 0.050043
Iter 19/100 | Avg. Anisotropic Loss: 0.050032
Iter 20/100 | Avg. Anisotropic Loss: 0.050026
Iter 21/100 | Avg. Anisotropic Loss: 0.050021
Iter 22/100 | Avg. Anisotropic Loss: 0.050013
Iter 23/100 | Avg. Anisotropic Loss: 0.050

(10000, 8, 1)

In [None]:
import plotly.express as px
fig = px.line(
    loss_df, 
    x="iteration", 
    y="loss", 
    color="subspace",
    markers=True,
    title=f"Anisotropic PQ Loss per Subspace over Iterations. eta = {eta}"
)
fig.update_layout(
    xaxis_title="Iteration",
    yaxis_title="Loss",
    legend_title="Subspace"
)
fig.show()›

In [37]:
"""
residuals = (N, k, d) = (10000, 10, 16)
    - axis0: number of items
    - axis1: number of clusters
    - axis2: residual of subvector and centroid
Example:
- residuals[0, 0] or residuals[0, 0, :] → residual for 0th subvector w.r.t. centroid 0.
- residuals[0, 3] → residual for 0th subvector w.r.t. centroid 3.
- residuals[0] → (10, 16) residuals for 0th subvector to all centroids.
"""
print(subspace.shape)
print(centroids.shape)
print(residuals.shape)

assert (subspace[0] - centroids[0] == residuals[0, 0, :]).all()

(10000, 5)
(8, 5)
(10000, 8, 5)


In [38]:
q_vectors_scann = []
for v in item_embeddings:
    q_vec = encode(v, scann_codebooks)
    q_vectors_scann.append(q_vec)
q_vectors_scann = np.array(q_vectors_scann, dtype=np.int8)

scann_reconstructed_vectors = []
for q_vec in q_vectors_scann:
    reconstructed_vector = decode(q_vec, scann_codebooks)
    scann_reconstructed_vectors.append(reconstructed_vector)
scann_reconstructed_vectors = np.array(scann_reconstructed_vectors)
scann_reconstructed_vectors.shape


(10000, 50)

# 2. Comparison

In [39]:
n_queries = 100
# Exclude the first n_items (item_embeddings) and randomly sample n_queries queries
query_candidates = embeddings_df.iloc[n_items:]
query_sample = query_candidates.sample(n=n_queries, random_state=42)
query_embeddings = np.stack(query_sample['vector'].values).astype(np.float32)
query_embeddings /= norm(query_embeddings, axis=1, keepdims=True)

In [None]:
# 1. Compute Ground Truth (exact search)
k_values = [1, 5, 10, 20]
print("\nComputing ground truth neighbors...")
# Don't forget we've already normalized item_embeddings 
# therefore cosine similarity is equivalent to inner product.
ground_truth_scores = query_embeddings @ item_embeddings.T  # (n_queries, n_items)
ground_truth_neighbors = {}
for k in k_values:
    ground_truth_neighbors[k] = np.argsort(-ground_truth_scores, axis=1)[:, :k]


# 2. Approximate search with quantized representations
def approximate_search_pq(query, reconstructed_items, k):
    """Search using reconstructed vectors"""
    scores = query @ reconstructed_items.T
    return np.argsort(-scores)[:k]


# 3. Calculate Recall@k
def calculate_recall_at_k(true_neighbors, approx_neighbors):
    """Calculate what fraction of true top-k are in approximate top-k"""
    intersection = len(np.intersect1d(true_neighbors, approx_neighbors))
    return intersection / len(true_neighbors)

print("\n--- Recall @k ---")
for k in k_values:
    standard_recalls = []
    scann_recalls = []
    
    for q_idx in range(len(query_embeddings)):
        query = query_embeddings[q_idx]
        true_neighbors = ground_truth_neighbors[k][q_idx]
        
        # Standard PQ
        standard_approx = approximate_search_pq(query, reconstructed_vectors, k)
        standard_recalls.append(calculate_recall_at_k(true_neighbors, standard_approx))
        
        # ScaNN PQ
        scann_approx = approximate_search_pq(query, scann_reconstructed_vectors, k)
        scann_recalls.append(calculate_recall_at_k(true_neighbors, scann_approx))
    print(f"Recall@{k:2d}: Standard PQ = {np.mean(standard_recalls):.4f}, ScaNN PQ = {np.mean(scann_recalls):.4f}")


# 4. Calculate Recall 1@k
def calculate_recall_1_at_k(true_neighbors, approx_neighbors):
    """Calculate what fraction of true top-k are in approximate top-k"""
    return 1 if true_neighbors[0] in approx_neighbors else 0

print("\n--- Recall 1@k ---")
for k in k_values:
    standard_recalls = []
    scann_recalls = []
    
    for q_idx in range(len(query_embeddings)):
        query = query_embeddings[q_idx]
        true_neighbors = ground_truth_neighbors[k][q_idx]
        
        # Standard PQ
        standard_approx = approximate_search_pq(query, reconstructed_vectors, k)
        standard_recalls.append(calculate_recall_1_at_k(true_neighbors, standard_approx))
        
        # ScaNN PQ
        scann_approx = approximate_search_pq(query, scann_reconstructed_vectors, k)
        scann_recalls.append(calculate_recall_1_at_k(true_neighbors, scann_approx))
    
    print(f"Recall@{k:2d}: Standard PQ = {np.mean(standard_recalls):.4f}, ScaNN PQ = {np.mean(scann_recalls):.4f}")



Computing ground truth neighbors...

--- Recall @k ---
Recall@ 1: Standard PQ = 0.0300, ScaNN PQ = 0.0900
Recall@ 5: Standard PQ = 0.1120, ScaNN PQ = 0.1100
Recall@10: Standard PQ = 0.1590, ScaNN PQ = 0.1510
Recall@20: Standard PQ = 0.1775, ScaNN PQ = 0.1940

--- Recall 1@k ---
Recall@ 1: Standard PQ = 0.0300, ScaNN PQ = 0.0900
Recall@ 5: Standard PQ = 0.1200, ScaNN PQ = 0.1600
Recall@10: Standard PQ = 0.2000, ScaNN PQ = 0.2200
Recall@20: Standard PQ = 0.3000, ScaNN PQ = 0.3700
