In [None]:
!pip -q install numpy rich
%mkdir -p quake_adaptive_demo
%cd quake_adaptive_demo


/content/quake_adaptive_demo


In [None]:
%%writefile quake_min.py
from __future__ import annotations
import math, time, random
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import numpy as np

def l2(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.sum((a - b) ** 2))

def l2_batch(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    x2 = np.sum(x*x, axis=1, keepdims=True)
    y2 = np.sum(y*y, axis=1, keepdims=True).T
    return x2 + y2 - 2 * (x @ y.T)

def topk_indices(arr: np.ndarray, k: int) -> np.ndarray:
    if k >= arr.shape[0]:
        return np.argsort(arr)
    idx = np.argpartition(arr, kth=k)[:k]
    return idx[np.argsort(arr[idx])]

def kmeans(x: np.ndarray, k: int, iters: int = 15, seed: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    n, d = x.shape
    centroids = x[rng.choice(n, size=k, replace=False)].copy()
    assign = np.zeros(n, dtype=np.int32)
    for _ in range(iters):
        dists = l2_batch(x, centroids)
        assign = np.argmin(dists, axis=1)
        for j in range(k):
            mask = assign == j
            if np.any(mask):
                centroids[j] = np.mean(x[mask], axis=0)
            else:
                centroids[j] = x[rng.integers(0, n)]
    return centroids, assign

@dataclass
class BasePartition:
    vecs: np.ndarray
    ids: np.ndarray
    centroid: np.ndarray
    hits: int = 0
    last_split_at: int = 0

@dataclass
class CoarseCell:
    centroid: np.ndarray
    base_ids: List[int] = field(default_factory=list)

class AdaptiveIVF:
    def __init__(self, dim: int, k_coarse: int = 16, k_base: int = 4, seed: int = 0):
        self.dim = dim
        self.k_coarse = k_coarse
        self.k_base = k_base
        self.rng = np.random.default_rng(seed)
        self.coarse: List[CoarseCell] = []
        self.base_parts: List[BasePartition] = []
        self.id2loc: Dict[int, Tuple[int,int]] = {}
        self.query_counter = 0
        self.split_size = 3000
        self.merge_size = 300
        self.hot_split_multiplier = 1.5
        self.cold_merge_multiplier = 1.5

    def build(self, x: np.ndarray, ids: Optional[np.ndarray] = None, coarse_k: Optional[int] = None, base_k: Optional[int] = None):
        if ids is None:
            ids = np.arange(x.shape[0], dtype=np.int64)
        if coarse_k: self.k_coarse = coarse_k
        if base_k: self.k_base = base_k
        coarse_centroids, coarse_assign = kmeans(x, self.k_coarse, iters=12, seed=42)
        self.coarse = [CoarseCell(c) for c in coarse_centroids]
        for c_id in range(self.k_coarse):
            mask_c = coarse_assign == c_id
            if not np.any(mask_c): continue
            x_c = x[mask_c]
            ids_c = ids[mask_c]
            kb = min(self.k_base, max(1, x_c.shape[0] // 50))
            base_centroids, base_assign = kmeans(x_c, kb, iters=10, seed=123 + c_id)
            for b in range(kb):
                mask_b = base_assign == b
                if not np.any(mask_b): continue
                xp = x_c[mask_b]
                ip = ids_c[mask_b]
                bp = BasePartition(vecs=xp, ids=ip, centroid=np.mean(xp, axis=0))
                self.base_parts.append(bp)
                bidx = len(self.base_parts) - 1
                self.coarse[c_id].base_ids.append(bidx)
                for off, _id in enumerate(ip):
                    self.id2loc[int(_id)] = (bidx, off)

    def insert(self, v: np.ndarray, _id: int):
        c_d = np.array([l2(v, c.centroid) for c in self.coarse])
        c_idx = int(np.argmin(c_d))
        base_idxs = self.coarse[c_idx].base_ids
        if not base_idxs:
            bp = BasePartition(vecs=v[None,:], ids=np.array([_id]), centroid=v.copy())
            self.base_parts.append(bp)
            bidx = len(self.base_parts) - 1
            self.coarse[c_idx].base_ids.append(bidx)
            self.id2loc[_id] = (bidx, 0)
            return
        dists = [l2(v, self.base_parts[b].centroid) for b in base_idxs]
        bidx = int(base_idxs[int(np.argmin(dists))])
        bp = self.base_parts[bidx]
        bp.vecs = np.vstack([bp.vecs, v])
        bp.ids = np.append(bp.ids, _id)
        bp.centroid = np.mean(bp.vecs, axis=0)
        self.id2loc[_id] = (bidx, bp.vecs.shape[0]-1)

    def delete(self, _id: int):
        loc = self.id2loc.get(int(_id))
        if loc is None: return
        bidx, off = loc
        bp = self.base_parts[bidx]
        last = bp.vecs.shape[0]-1
        bp.vecs[[off, last]] = bp.vecs[[last, off]]
        bp.ids[[off, last]] = bp.ids[[last, off]]
        bp.vecs = bp.vecs[:-1]
        bp.ids = bp.ids[:-1]
        if bp.vecs.shape[0] > 0:
            bp.centroid = np.mean(bp.vecs, axis=0)
        if off < bp.ids.shape[0]:
            self.id2loc[int(bp.ids[off])] = (bidx, off)
        self.id2loc.pop(int(_id), None)

    def _partition_scores(self, q: np.ndarray):
        centroids = np.array([bp.centroid for bp in self.base_parts])
        d2 = l2_batch(q[None,:], centroids)[0]
        size = np.array([bp.vecs.shape[0] for bp in self.base_parts])
        tau = np.median(np.sqrt(d2)) + 1e-6
        logits = -np.sqrt(d2) / (tau + 1e-6) + 0.5*np.log(size + 1.0)
        m = np.max(logits)
        p = np.exp(logits - m); p = p / np.sum(p)
        out = [(i, float(p[i]), int(size[i])) for i in range(len(self.base_parts))]
        out.sort(key=lambda t: t[1], reverse=True)
        return out

    def _choose_nprobe(self, probs, target_recall: float = 0.9, max_probe: int = 64) -> int:
        cum = 0.0; n = 0
        for _, pi, _ in probs:
            cum += pi; n += 1
            if cum >= target_recall: return n
            if n >= max_probe: return n
        return n

    def search(self, q: np.ndarray, k: int = 10, target_recall: float = 0.9, exact_ref: Optional[np.ndarray] = None):
        self.query_counter += 1
        probs = self._partition_scores(q)
        nprobe = self._choose_nprobe(probs, target_recall=target_recall)
        cand_ids, cand_vecs = [], []
        for i in range(nprobe):
            bidx = probs[i][0]
            bp = self.base_parts[bidx]
            bp.hits += 1
            cand_ids.append(bp.ids); cand_vecs.append(bp.vecs)
        if len(cand_ids) == 0:
            return np.array([]), np.array([]), {"nprobe": 0, "scanned": 0}
        C_ids = np.concatenate(cand_ids)
        C = np.vstack(cand_vecs)
        d2 = l2_batch(q[None,:], C)[0]
        topk = topk_indices(d2, min(k, d2.shape[0]))
        found_ids = C_ids[topk]; found_d2 = d2[topk]
        rec = None
        if exact_ref is not None and exact_ref.size > 0:
            inter = len(set(map(int, found_ids)) & set(map(int, exact_ref)))
            rec = inter / max(1, min(k, exact_ref.size))
        return found_ids, found_d2, {"nprobe": nprobe, "scanned": C.shape[0], "recall_at_k": rec}

    def maintain(self, hot_qps_window: int = 2000):
        for bidx, bp in enumerate(list(self.base_parts)):
            size = bp.vecs.shape[0]
            hotness = (bp.hits - bp.last_split_at)
            split_thresh = self.split_size / max(1.0, (hotness / hot_qps_window))
            split_thresh = max(self.split_size / self.hot_split_multiplier, min(self.split_size * 2, split_thresh))
            if size >= split_thresh and size >= 16:
                c, a = kmeans(bp.vecs, k=2, iters=8, seed=17 + bidx)
                mask0 = a == 0; mask1 = ~mask0
                if np.any(mask0) and np.any(mask1):
                    bp0 = BasePartition(vecs=bp.vecs[mask0], ids=bp.ids[mask0], centroid=np.mean(bp.vecs[mask0], axis=0))
                    bp1 = BasePartition(vecs=bp.vecs[mask1], ids=bp.ids[mask1], centroid=np.mean(bp.vecs[mask1], axis=0))
                    self.base_parts[bidx] = bp0
                    self.base_parts.append(bp1)
                    new_idx = len(self.base_parts) - 1
                    for off, _id in enumerate(bp0.ids): self.id2loc[int(_id)] = (bidx, off)
                    for off, _id in enumerate(bp1.ids): self.id2loc[int(_id)] = (new_idx, off)
                    self.base_parts[bidx].last_split_at = self.query_counter
                    self.base_parts[new_idx].last_split_at = self.query_counter
        tiny = [i for i, bp in enumerate(self.base_parts) if bp.vecs.shape[0] <= self.merge_size]
        used = set()
        for i in tiny:
            if i in used: continue
            ci = self.base_parts[i].centroid
            best_j, best_d = None, float('inf')
            for j in tiny:
                if j == i or j in used: continue
                cj = self.base_parts[j].centroid
                d = l2(ci, cj)
                if d < best_d:
                    best_d, best_j = d, j
            if best_j is None: continue
            bpi = self.base_parts[i]
            bpj = self.base_parts[best_j]
            vecs = np.vstack([bpi.vecs, bpj.vecs])
            ids = np.concatenate([bpi.ids, bpj.ids])
            bpi.vecs, bpi.ids, bpi.centroid = vecs, ids, np.mean(vecs, axis=0)
            for off, _id in enumerate(ids):
                self.id2loc[int(_id)] = (i, off)
            bpj.vecs = np.zeros((0, self.dim))
            bpj.ids = np.zeros((0,), dtype=np.int64)

    def exact_topk(self, q: np.ndarray, all_vecs: np.ndarray, all_ids: np.ndarray, k: int) -> np.ndarray:
        d2 = l2_batch(q[None,:], all_vecs)[0]
        topk = topk_indices(d2, min(k, d2.shape[0]))
        return all_ids[topk]


Writing quake_min.py


In [None]:
%%writefile run_demo.py
import time, math, random
from typing import Tuple
import numpy as np
from rich import print, box
from rich.table import Table
from quake_min import AdaptiveIVF

def make_dataset(n=50000, d=64, n_clusters=50, seed=7) -> Tuple[np.ndarray, np.ndarray]:
    rng = np.random.default_rng(seed)
    centers = rng.normal(size=(n_clusters, d)) * 4.0
    sizes = rng.multinomial(n, [1/n_clusters]*n_clusters)
    xs = []
    for i, sz in enumerate(sizes):
        if sz == 0: continue
        pts = centers[i] + rng.normal(size=(sz, d))
        xs.append(pts)
    X = np.vstack(xs).astype(np.float32)
    ids = np.arange(X.shape[0], dtype=np.int64)
    return X, ids

def zipf_partition_sampler(P, alpha=1.2, seed=0):
    rng = np.random.default_rng(seed)
    ranks = np.arange(1, P+1)
    weights = 1 / (ranks ** alpha)
    weights = weights / weights.sum()
    def sample():
        return int(rng.choice(P, p=weights))
    return sample, weights

def main():
    n = 40000
    d = 64
    k_coarse = 16
    k_base = 4
    queries = 500
    k = 10
    target_recall = 0.9
    seed = 7

    print(f"[bold]Building dataset n={n}, d={d}[/bold]")
    X, ids = make_dataset(n=n, d=d, n_clusters=60, seed=seed)

    print(f"[bold]Building index (k_coarse={k_coarse}, k_base={k_base})[/bold]")
    idx = AdaptiveIVF(dim=d, k_coarse=k_coarse, k_base=k_base, seed=seed)
    t0 = time.time()
    idx.build(X, ids)
    build_s = time.time() - t0
    print(f"Build time: {build_s:.2f}s, base partitions: {len(idx.base_parts)}")

    P = len(idx.base_parts)
    sample_partition, weights = zipf_partition_sampler(P, alpha=1.1, seed=seed+1)
    rng = np.random.default_rng(seed+2)

    qlat, qrec, qnprobe, qscan = [], [], [], []
    print("[bold]Running queries with APS + maintenance...[/bold]")
    for t in range(1, queries+1):
        p = sample_partition()
        bp = idx.base_parts[p]
        if bp.vecs.shape[0] == 0:
            v = X[rng.integers(0, X.shape[0])]
        else:
            v = bp.vecs[rng.integers(0, bp.vecs.shape[0])]
        q = (v + rng.normal(size=v.shape)*0.1).astype(np.float32)
        exact_ids = idx.exact_topk(q, X, ids, k=k)
        t1 = time.time()
        found_ids, found_d2, meta = idx.search(q, k=k, target_recall=target_recall, exact_ref=exact_ids)
        dt = (time.time() - t1) * 1000.0
        qlat.append(dt); qnprobe.append(meta["nprobe"]); qscan.append(meta["scanned"]); qrec.append(meta["recall_at_k"] or 0.0)

        if t % 20 == 0:
            for _ in range(10):
                v_new = rng.normal(size=(d,)).astype(np.float32) * 0.5 + rng.normal(size=(d,)).astype(np.float32)
                new_id = int(ids.max() + rng.integers(1, 1000))
                idx.insert(v_new, new_id)
            for _ in range(10):
                del_id = int(ids[rng.integers(0, ids.shape[0])])
                idx.delete(del_id)

        if t % 50 == 0:
            idx.maintain()

        if t % 100 == 0:
            print(f"[dim].. q{t}: latency {dt:.2f} ms, nprobe {meta['nprobe']}, scanned {meta['scanned']}, recall@{k}={meta['recall_at_k']:.2f}[/dim]")

    def mean(x): return float(np.mean(x))
    def p(x,q): return float(np.percentile(x,q))

    tbl = Table(title="Adaptive IVF Demo — Summary", box=box.SIMPLE_HEAVY)
    tbl.add_column("Metric"); tbl.add_column("Value")
    tbl.add_row("Queries", str(queries))
    tbl.add_row("Avg latency (ms)", f"{mean(qlat):.2f}")
    tbl.add_row("p50 / p95 latency (ms)", f"{p(qlat,50):.2f} / {p(qlat,95):.2f}")
    tbl.add_row("Avg nprobe", f"{mean(qnprobe):.2f}")
    tbl.add_row("Avg vectors scanned", f"{mean(qscan):.0f}")
    tbl.add_row(f"Avg recall@{k}", f"{mean(qrec):.3f}")
    print(); print(tbl)

if __name__ == "__main__":
    main()


Writing run_demo.py


In [4]:
!python run_demo.py

[1mBuilding dataset [0m[1;33mn[0m[1m=[0m[1;36m40000[0m[1m, [0m[1;33md[0m[1m=[0m[1;36m64[0m
[1mBuilding index [0m[1m([0m[1;33mk_coarse[0m[1m=[0m[1;36m16[0m[1m, [0m[1;33mk_base[0m[1m=[0m[1;36m4[0m[1m)[0m
Build time: [1;36m0.[0m54s, base partitions: [1;36m64[0m
[1mRunning queries with APS + maintenance[0m[1;33m...[0m
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
[2m.. q100: latency [0m[1;2;36m12.05[0m[2m ms, nprobe [0m[1;2;36m64[0m[2m, scanned [0m[1;2;36m36556[0m[2m, recall@[0m[1;2;36m10[0m[2m=[0m[1;2;36m1[0m[1;2;36m.00[0m
[2m.. q200: latency [0m[1;2;36m18.67[0m[2m ms, nprobe [0m[1;2;36m64[0m[2m, scanned [0m[1;2;36m35206[0m[2m, recall@[0m[1;2;36m10[0m[2m=[0m[1;2;36m1[0m[1;2;36m.00[0m
[2m.. q300: latency [0m[1;2;36m12.21[0m[2m ms, nprobe [0m[1;2;36m64[0m[2m, scanned [0m[1;2;36m35215[0m[2m, recall@[0m[1;2;36m10[0m[2m=[0m[1;2;36m1[0m[1;2;36m.00[0m
[2m.. 