In [1]:
import numpy as np
import os
import time
from math import sqrt

In [2]:
# Hadamard (power-of-two only)
def hadamard(n):
    assert (n & (n - 1)) == 0
    H = np.array([[1]], dtype=np.float32)
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


# L2 normalise
def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Blockwise INT8 quantisation
def quantize_int8_blockwise(x, block=32):
    x = x.astype(np.float32)
    n = len(x)
    q = np.zeros_like(x, dtype=np.int8)
    scales = []

    for i in range(0, n, block):
        chunk = x[i:i+block]
        s = np.max(np.abs(chunk)) + 1e-8
        scales.append(s / 127.0)
        q[i:i+block] = np.round(chunk / scales[-1]).astype(np.int8)

    return q, np.array(scales, dtype=np.float32)


def dot_q8_v8_blockwise(q1, s1, q2, s2, block=32):
    total = 0.0
    idx = 0
    for b in range(len(s1)):
        size = min(block, len(q1) - idx)
        a = q1[idx:idx+size].astype(np.float32)
        bq = q2[idx:idx+size].astype(np.float32)
        total += np.dot(a, bq) * s1[b] * s2[b]
        idx += size
    return total

In [3]:
GLOVE_PATH = r"G:\train_jw\data\glove\glove.6B.300d.txt"
assert os.path.exists(GLOVE_PATH)

print("Loading 50k GloVe300...")

vecs = []
with open(GLOVE_PATH, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= 50000:
            break
        parts = line.strip().split()
        vec = np.array(parts[1:], dtype=np.float32)
        vecs.append(vec)

X_all = np.stack(vecs)
X_all = l2_normalize(X_all)

N, d = X_all.shape
print("Loaded:", X_all.shape)

Loading 50k GloVe300...
Loaded: (50000, 300)


In [4]:
bag_size = 32
num_bags = N // bag_size
X_all = X_all[:num_bags * bag_size]  # trim

X_bags = X_all.reshape(num_bags, bag_size, d)

print("Bags:", X_bags.shape)

Bags: (1562, 32, 300)


In [5]:
print("Building spectral index...")

Hn = hadamard(bag_size).astype(np.float32) / sqrt(bag_size)

index = []

t0 = time.time()

for B in X_bags:
    Y = Hn @ B
    Yq, Ys = [], []
    for k in range(bag_size):
        q, s = quantize_int8_blockwise(Y[k], block=32)
        Yq.append(q)
        Ys.append(s)
    index.append((np.stack(Yq), np.stack(Ys)))

print("Build time:", round(time.time() - t0, 2), "seconds")

Building spectral index...
Build time: 8.59 seconds


In [6]:
def flat_search(q):
    sims = X_all @ q
    return np.argmax(sims)

# Test set
Q = X_all[np.random.choice(N, size=300, replace=False)]

correct = 0
t0 = time.time()

for q in Q:
    gt = flat_search(q)
    pred = flat_search(q)
    if gt == pred:
        correct += 1

print("Flat Recall@1:", correct / len(Q))
print("Flat time:", round(time.time() - t0, 2), "seconds")

Flat Recall@1: 1.0
Flat time: 1.93 seconds


In [7]:
def spectral_search_two_stage(q, K=32, early_exit=True):
    """
    Two-stage spectral ANN over Hadamard bags.

    Stage 1: For each bag, compute s_y (n dot products) and compute a SAFE upper bound:
             U_b = ||s_y||_2 / sqrt(n).  (Because max_i (H^T s_y)_i <= ||s_y||_2 / sqrt(n))

             Rank bags by U_b and keep top-K.

    Stage 2: For top-K bags, do progressive reconstruction in best-first |s_y| order.
             Optionally stop early within a bag using L2 tail bound once the bag's top-1 is stable enough
             (or simply run full reconstruction for those K bags for maximum correctness).

    Returns: (global_idx, stats_dict)
    """
    q8, qs = quantize_int8_blockwise(q, block=32)

    # ---- Stage 1: score all bags with a safe bound ----
    Ub = np.empty(len(index), dtype=np.float32)
    Sy_cache = [None] * len(index)  # cache s_y for the K selected bags

    inv_sqrt_n = 1.0 / sqrt(bag_size)

    for b, (Yq, Ys) in enumerate(index):
        s_y = np.array([
            dot_q8_v8_blockwise(q8, qs, Yq[k], Ys[k], block=32)
            for k in range(bag_size)
        ], dtype=np.float32)

        # safe upper bound on max similarity in this bag
        Ub[b] = np.linalg.norm(s_y) * inv_sqrt_n
        Sy_cache[b] = s_y  # caching all is ok at 1562 bags; can optimise later

    # pick top-K bags
    K = min(K, len(index))
    top_bags = np.argsort(-Ub)[:K]

    # ---- Stage 2: resolve only top-K bags ----
    global_best = -1e9
    global_idx = -1

    channels_used = 0
    bags_fully_resolved = 0

    for b in top_bags:
        s_y = Sy_cache[b]
        order = np.argsort(-np.abs(s_y))
        s_y_ord = s_y[order]

        # progressive inverse reconstruction
        partial = np.zeros(bag_size, dtype=np.float32)

        # L2 tail energy for within-bag early exit (optional)
        tail_energy = float(np.sum(s_y_ord**2))

        best_local = -1e9
        best_local_idx = -1

        if not early_exit:
            # full reconstruction for this bag
            for m in range(bag_size):
                k = order[m]
                partial += Hn[:, k] * s_y[k]
                channels_used += 1
            best_local_idx = int(np.argmax(partial))
            best_local = float(partial[best_local_idx])
            bags_fully_resolved += 1
        else:
            # progressive with within-bag stability bound
            # stop when top-1 margin > 2 * R2 (energy tail bound)
            sqrt_n = sqrt(bag_size)

            for m in range(bag_size):
                k = order[m]
                partial += Hn[:, k] * s_y[k]
                channels_used += 1

                tail_energy -= float(s_y[k]**2)
                R2 = sqrt(max(tail_energy, 0.0)) / sqrt_n

                # current top-1 and top-2
                idx_sorted = np.argsort(-partial)
                i1, i2 = int(idx_sorted[0]), int(idx_sorted[1])
                margin = float(partial[i1] - partial[i2])

                # early-exit: top1 cannot change within this bag
                if margin > 2.0 * R2:
                    best_local_idx = i1
                    best_local = float(partial[i1])
                    break

            # if never broke, we effectively resolved it fully
            if best_local_idx < 0:
                best_local_idx = int(np.argmax(partial))
                best_local = float(partial[best_local_idx])
                bags_fully_resolved += 1

        # update global best across bags
        if best_local > global_best:
            global_best = best_local
            global_idx = b * bag_size + best_local_idx

    stats = {
        "K": K,
        "channels_used": channels_used,
        "bags_fully_resolved": bags_fully_resolved,
        "top_bag_bound": float(Ub[top_bags[0]]),
        "bound_median_topK": float(np.median(Ub[top_bags])),
    }
    return global_idx, stats

In [8]:
# ============================================================
# Evaluate Two-Stage Spectral ANN (Recall@1 vs K)
# ============================================================

def evaluate_two_stage(K=32, early_exit=True):
    correct = 0
    ch_used = 0
    fully = 0

    t0 = time.time()
    for q in Q:
        gt = flat_search(q)
        pred, st = spectral_search_two_stage(q, K=K, early_exit=early_exit)
        if gt == pred:
            correct += 1
        ch_used += st["channels_used"]
        fully += st["bags_fully_resolved"]

    dt = time.time() - t0
    return {
        "K": K,
        "early_exit": early_exit,
        "recall@1": correct / len(Q),
        "avg_channels": ch_used / len(Q),
        "avg_full_bags": fully / len(Q),
        "time_s": dt,
    }

for K in [8, 16, 32, 64, 128, 256]:
    res = evaluate_two_stage(K=K, early_exit=True)
    print(res)

{'K': 8, 'early_exit': True, 'recall@1': 0.87, 'avg_channels': 121.85666666666667, 'avg_full_bags': 0.0, 'time_s': 724.2387630939484}
{'K': 16, 'early_exit': True, 'recall@1': 0.88, 'avg_channels': 249.82333333333332, 'avg_full_bags': 0.0, 'time_s': 725.2589881420135}
{'K': 32, 'early_exit': True, 'recall@1': 0.8733333333333333, 'avg_channels': 507.23, 'avg_full_bags': 0.0033333333333333335, 'time_s': 712.7815494537354}
{'K': 64, 'early_exit': True, 'recall@1': 0.8733333333333333, 'avg_channels': 1026.8866666666668, 'avg_full_bags': 0.01, 'time_s': 720.0615191459656}
{'K': 128, 'early_exit': True, 'recall@1': 0.8733333333333333, 'avg_channels': 2072.47, 'avg_full_bags': 0.02, 'time_s': 731.7740693092346}
{'K': 256, 'early_exit': True, 'recall@1': 0.8733333333333333, 'avg_channels': 4197.68, 'avg_full_bags': 0.02, 'time_s': 730.3368244171143}


In [9]:
correct = 0
total_pruned = 0
total_channels = 0

t0 = time.time()

for q in Q:

    gt = flat_search(q)
    pred, pruned, channels = spectral_search(q)

    if gt == pred:
        correct += 1

    total_pruned += pruned
    total_channels += channels

print("\n=== Spectral ANN Results ===")
print("Recall@1:", correct / len(Q))
print("Avg bags pruned:", total_pruned / len(Q))
print("Avg channels used per query:", total_channels / len(Q))
print("Total time:", round(time.time() - t0, 2), "seconds")

NameError: name 'spectral_search' is not defined