# SIFT1M Trajectory Router — v2 (Filtered GT Recall)

**Fix applied:** Ground truth is filtered to only count neighbours that exist
within BASE_LIMIT. Queries where fewer than 3 valid GT neighbours exist are
skipped entirely. This gives a fair evaluation of routing quality on the
partial index, rather than being penalised for neighbours that were never loaded.

**Other fixes retained from v1:**
- L2-consistent node scoring: `2·q·mean - ||mean||²`
- Precomputed node squared norms
- Increased candidate budget

In [1]:
import os, time
import numpy as np

SIFT_DIR  = r"G:\train_jw\datasets\sift1m"
BASE_NPY  = os.path.join(SIFT_DIR, "sift_base.npy")
QUERY_NPY = os.path.join(SIFT_DIR, "sift_query.npy")
GT_IVEC   = os.path.join(SIFT_DIR, "sift_groundtruth.ivecs")

assert os.path.exists(BASE_NPY),  BASE_NPY
assert os.path.exists(QUERY_NPY), QUERY_NPY
assert os.path.exists(GT_IVEC),   GT_IVEC

BASE_LIMIT  = 200_000
QUERY_LIMIT = 2_000
TOPK_EVAL   = 10
MIN_VALID_GT = 3   # skip query if fewer than this many GT neighbours are in-index

NLIST        = 256
NPROBE       = 8
KMEANS_TRAIN = 50_000

BAG_SIZE    = 32
assert (BAG_SIZE & (BAG_SIZE-1)) == 0
TREE_LEVELS = int(np.log2(BAG_SIZE))  # 5

TOP_BAGS_PER_LIST = 32
BEAM              = 2
CANDS_PER_LIST    = 128
FINAL_RERANK      = 500

SEED = 0
rng  = np.random.default_rng(SEED)

def read_ivecs(fname):
    a = np.fromfile(fname, dtype=np.int32)
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:]

print("Config OK")

Config OK


In [2]:
print("Loading SIFT...")
X  = np.load(BASE_NPY,  mmap_mode="r")[:BASE_LIMIT].astype(np.float32)
Q  = np.load(QUERY_NPY, mmap_mode="r")[:QUERY_LIMIT].astype(np.float32)
GT = read_ivecs(GT_IVEC)[:QUERY_LIMIT, :100]

n, d = X.shape
print("X:", X.shape, "Q:", Q.shape, "GT:", GT.shape)

# --- Filter GT per query: keep only indices < BASE_LIMIT ---
# For each query store the valid GT neighbours and whether the query is usable
GT_valid = []
query_usable = []
for i in range(len(Q)):
    valid = [int(g) for g in GT[i] if g < BASE_LIMIT]
    GT_valid.append(valid)
    query_usable.append(len(valid) >= MIN_VALID_GT)

n_usable = sum(query_usable)
mean_valid = np.mean([len(v) for v in GT_valid])
print(f"Usable queries (>= {MIN_VALID_GT} valid GT neighbours in index): {n_usable}/{len(Q)}")
print(f"Mean valid GT neighbours per query: {mean_valid:.1f} / {TOPK_EVAL}")
print(f"Expected recall ceiling (if routing perfect): ~{mean_valid/TOPK_EVAL:.2f}")

Loading SIFT...
X: (200000, 128) Q: (2000, 128) GT: (2000, 100)
Usable queries (>= 3 valid GT neighbours in index): 1995/2000
Mean valid GT neighbours per query: 21.7 / 10
Expected recall ceiling (if routing perfect): ~2.17


In [3]:
from sklearn.cluster import MiniBatchKMeans

print("Training IVF centroids...")
t0 = time.time()

train_idx = rng.choice(n, size=min(KMEANS_TRAIN, n), replace=False)
Xtr = np.array(X[train_idx], dtype=np.float32)

kmeans = MiniBatchKMeans(
    n_clusters=NLIST, batch_size=4096,
    n_init=1, max_iter=200,
    random_state=SEED, verbose=0
)
kmeans.fit(Xtr)
C    = kmeans.cluster_centers_.astype(np.float32)
C_sq = np.sum(C**2, axis=1)

print("Centroid train time:", round(time.time()-t0, 2), "s")

Training IVF centroids...
Centroid train time: 3.99 s


In [4]:
print("Assigning vectors to IVF lists...")
t0 = time.time()

assign = np.empty(n, dtype=np.int32)
chunk  = 50_000
for i in range(0, n, chunk):
    Xc   = np.array(X[i:i+chunk], dtype=np.float32)
    x2   = np.sum(Xc**2, axis=1, keepdims=True)
    dist = x2 + C_sq[None, :] - 2.0 * (Xc @ C.T)
    assign[i:i+len(Xc)] = np.argmin(dist, axis=1).astype(np.int32)

lists = [[] for _ in range(NLIST)]
for i, li in enumerate(assign):
    lists[int(li)].append(i)

list_sizes = np.array([len(l) for l in lists])
print("Assign time:", round(time.time()-t0, 2), "s")
print(f"List sizes — mean: {list_sizes.mean():.0f}  max: {list_sizes.max()}  min: {list_sizes.min()}")

Assigning vectors to IVF lists...
Assign time: 0.41 s
List sizes — mean: 781  max: 1855  min: 246


In [5]:
print("Building trajectory tree index...")
t0 = time.time()

traj = []
total_bags = 0

for li in range(NLIST):
    idxs = np.array(lists[li], dtype=np.int32)
    nb   = len(idxs) // BAG_SIZE
    if nb == 0:
        traj.append(None)
        continue

    idxs       = idxs[:nb*BAG_SIZE]
    bag_leaves = idxs.reshape(nb, BAG_SIZE)
    V          = np.array(X[idxs], dtype=np.float32).reshape(nb, BAG_SIZE, d)

    nodes    = []
    nodes_sq = []

    for L in range(TREE_LEVELS):
        nodesL = 2**L
        group  = BAG_SIZE // nodesL
        M  = np.empty((nb, nodesL, d), dtype=np.float32)
        M2 = np.empty((nb, nodesL),    dtype=np.float32)
        for j in range(nodesL):
            seg        = V[:, j*group:(j+1)*group, :].mean(axis=1)
            M[:, j, :] = seg
            M2[:, j]   = np.sum(seg**2, axis=1)
        nodes.append(M)
        nodes_sq.append(M2)

    traj.append({"bag_leaves": bag_leaves, "nodes": nodes, "nodes_sq": nodes_sq})
    total_bags += nb

print("Tree build time:", round(time.time()-t0, 2), "s")
print("Total bags:", total_bags)

Building trajectory tree index...
Tree build time: 0.36 s
Total bags: 6134


In [6]:
def centroid_shortlist(qv, nprobe=NPROBE):
    q2   = float(np.dot(qv, qv))
    dist = q2 + C_sq - 2.0 * (C @ qv)
    return np.argsort(dist)[:nprobe].astype(np.int32)


def route_within_list(qv, q2, li,
                      top_bags=TOP_BAGS_PER_LIST,
                      beam=BEAM,
                      cands=CANDS_PER_LIST):
    t = traj[li]
    if t is None:
        return np.empty((0,), dtype=np.int32)

    bag_leaves = t["bag_leaves"]
    nodes      = t["nodes"]
    nodes_sq   = t["nodes_sq"]
    nb         = bag_leaves.shape[0]

    # Stage A: score all bags by L2-proxy to root mean
    # score = 2*q·mean - ||mean||^2  (higher = closer in L2)
    root    = nodes[0][:, 0, :]   # (nb, d)
    root_sq = nodes_sq[0][:, 0]   # (nb,)
    scores  = 2.0 * (root @ qv) - root_sq
    bag_idx = np.argsort(-scores)[:min(top_bags, nb)]

    # Stage B: beam descent
    cand_global = []
    for b in bag_idx:
        cand_nodes = [0]
        for L in range(1, TREE_LEVELS):
            next_nodes = []
            for parent in cand_nodes:
                c0, c1 = parent*2, parent*2+1
                s0 = 2.0*float(np.dot(nodes[L][b, c0], qv)) - float(nodes_sq[L][b, c0])
                s1 = 2.0*float(np.dot(nodes[L][b, c1], qv)) - float(nodes_sq[L][b, c1])
                next_nodes.append((c0, s0))
                next_nodes.append((c1, s1))
            next_nodes.sort(key=lambda x: -x[1])
            cand_nodes = [i for i, _ in next_nodes[:beam]]

        group_size = BAG_SIZE // (2**(TREE_LEVELS-1))
        for node in cand_nodes:
            start = node * group_size
            for i in range(start, start + group_size):
                cand_global.append(int(bag_leaves[b, i]))

    if not cand_global:
        return np.empty((0,), dtype=np.int32)

    cand_global = np.unique(np.array(cand_global, dtype=np.int32))
    if len(cand_global) > cands:
        V    = np.array(X[cand_global], dtype=np.float32)
        dist = np.sum((V - qv)**2, axis=1)
        cand_global = cand_global[np.argsort(dist)[:cands]]

    return cand_global


def search_query(qv, nprobe=NPROBE, top_bags=TOP_BAGS_PER_LIST, beam=BEAM):
    q2          = float(np.dot(qv, qv))
    probe_lists = centroid_shortlist(qv, nprobe=nprobe)

    all_cands = []
    for li in probe_lists:
        c = route_within_list(qv, q2, int(li), top_bags=top_bags, beam=beam)
        if len(c) > 0:
            all_cands.append(c)

    if not all_cands:
        return np.empty((0,), dtype=np.int32)

    all_cands = np.unique(np.concatenate(all_cands))
    if len(all_cands) > FINAL_RERANK:
        V    = np.array(X[all_cands], dtype=np.float32)
        dist = np.sum((V - qv)**2, axis=1)
        all_cands = all_cands[np.argsort(dist)[:FINAL_RERANK]]

    V    = np.array(X[all_cands], dtype=np.float32)
    dist = np.sum((V - qv)**2, axis=1)
    return all_cands[np.argsort(dist)[:TOPK_EVAL]]


def recall_at_k_filtered(found, valid_gt, k=10):
    """
    Recall against only the GT neighbours that exist in the index.
    Denominator = min(k, len(valid_gt)) so it stays in [0,1].
    """
    denom = min(k, len(valid_gt))
    if denom == 0:
        return None
    return len(set(found[:k].tolist()) & set(valid_gt)) / denom


print("Functions defined OK")

Functions defined OK


In [7]:
print("Running main evaluation...")
t0 = time.time()

recalls  = []
skipped  = 0

for i in range(len(Q)):
    if not query_usable[i]:
        skipped += 1
        continue
    top = search_query(Q[i])
    r   = recall_at_k_filtered(top, GT_valid[i], k=TOPK_EVAL)
    if r is not None:
        recalls.append(r)

print(f"Queries evaluated: {len(recalls)}  (skipped: {skipped})")
print(f"Mean recall@{TOPK_EVAL} (filtered GT): {np.mean(recalls):.3f}")
print(f"Eval time: {round(time.time()-t0, 2)} s")

Running main evaluation...
Queries evaluated: 1995  (skipped: 5)
Mean recall@10 (filtered GT): 0.903
Eval time: 16.48 s


In [8]:
# Sweep over nprobe and top_bags
print(f"{'nprobe':>8} {'top_bags':>10} {'beam':>6} {'recall@10':>12} {'n_eval':>8} {'time_s':>8}")
print("-" * 58)

for nprobe in [4, 8, 16]:
    for top_bags in [16, 32, 64]:
        t0 = time.time()
        recalls_sw = []
        for i in range(len(Q)):
            if not query_usable[i]:
                continue
            top = search_query(Q[i], nprobe=nprobe, top_bags=top_bags, beam=BEAM)
            r   = recall_at_k_filtered(top, GT_valid[i], k=TOPK_EVAL)
            if r is not None:
                recalls_sw.append(r)
        dt = time.time() - t0
        print(f"{nprobe:>8} {top_bags:>10} {BEAM:>6} {np.mean(recalls_sw):>12.3f} {len(recalls_sw):>8} {dt:>8.1f}")

  nprobe   top_bags   beam    recall@10   n_eval   time_s
----------------------------------------------------------
       4         16      2        0.708     1995      5.4
       4         32      2        0.812     1995      8.3
       4         64      2        0.820     1995      8.8
       8         16      2        0.814     1995     10.7
       8         32      2        0.903     1995     16.3
       8         64      2        0.908     1995     17.4
      16         16      2        0.856     1995     20.9
      16         32      2        0.930     1995     31.7
      16         64      2        0.934     1995     33.8
