Skip to content

Commit 553f8ef

Browse files
Create approx_nearest_neighbours.py
1 parent 7530a41 commit 553f8ef

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Approximate Nearest Neighbor (ANN) Search
3+
https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor
4+
5+
ANN search finds "close enough" vectors instead of the exact nearest neighbor,
6+
which makes it much faster for large datasets.
7+
8+
This implementation uses a simple **random projection hashing** method.
9+
Steps:
10+
1. Generate random hyperplanes to hash vectors into buckets.
11+
2. Place dataset vectors into buckets.
12+
3. For a query vector, look into its bucket (or the whole dataset if that bucket is empty).
13+
4. Return the approximate nearest neighbor from those candidates.
14+
15+
Each result contains:
16+
1. The nearest (approximate) vector.
17+
2. Its distance from the query vector.
18+
"""
19+
from __future__ import annotations
20+
21+
import math
22+
from collections import defaultdict
23+
24+
import numpy as np
25+
def euclidean(input_a: np.ndarray, input_b: np.ndarray) -> float:
    """
    Calculate the Euclidean distance between two vectors.

    Both inputs are converted to float64 before they are subtracted, so
    narrow or unsigned integer dtypes (e.g. ``uint8``) cannot wrap around.

    :param input_a: first vector
    :param input_b: second vector, same shape as ``input_a``
    :return: the Euclidean (L2) distance as a Python float
    :raises ValueError: if the two inputs do not have the same shape

    >>> euclidean(np.array([0]), np.array([1]))
    1.0
    >>> euclidean(np.array([1, 2]), np.array([1, 5]))
    3.0
    >>> euclidean(np.array([0], dtype=np.uint8), np.array([20], dtype=np.uint8))
    20.0
    """
    vec_a = np.asarray(input_a, dtype=np.float64)
    vec_b = np.asarray(input_b, dtype=np.float64)
    # Fail loudly instead of silently comparing only a common prefix
    # (or broadcasting) when the vectors are not the same shape.
    if vec_a.shape != vec_b.shape:
        msg = f"shape mismatch: {vec_a.shape} vs {vec_b.shape}"
        raise ValueError(msg)
    return float(np.linalg.norm(vec_a - vec_b))
34+
35+
36+
class ANN:
    """
    Approximate Nearest Neighbor using random projection hashing.

    Every vector is hashed to a bit string with one bit per random
    hyperplane (the side of the plane the vector lies on). A query is only
    compared against the dataset vectors that share its bit string; if there
    are none, the whole dataset is searched instead.
    """

    def __init__(self, dataset: np.ndarray, n_planes: int = 5, seed: int = 42) -> None:
        """
        :param dataset: ndarray of shape (n_samples, n_features)
        :param n_planes: number of random hyperplanes for hashing
        :param seed: random seed for reproducibility
        """
        self.dataset = np.asarray(dataset)
        self.n_planes = n_planes
        rng = np.random.default_rng(seed)
        # One row per hyperplane: the normal vector used for the sign test.
        self.planes = rng.standard_normal((n_planes, self.dataset.shape[1]))
        self.buckets: dict[str, list[np.ndarray]] = defaultdict(list)
        self._build_index()

    def _hash_vector(self, vec: np.ndarray) -> str:
        """
        Hash a vector based on which side of each hyperplane it falls on.
        Returns a bit string with one character per hyperplane.

        >>> dataset = np.array([[1, 2]])
        >>> ann = ANN(dataset, n_planes=2, seed=0)
        >>> h = ann._hash_vector(np.array([1, 2]))
        >>> isinstance(h, str)
        True
        >>> len(h) == ann.n_planes
        True
        """
        signs = (vec @ self.planes.T) >= 0
        return "".join("1" if s else "0" for s in signs)

    def _build_index(self) -> None:
        """
        Build hash buckets for all dataset vectors.

        Any existing buckets are discarded first, so calling this again
        does not duplicate entries.

        >>> dataset = np.array([[0, 0], [1, 1]])
        >>> ann = ANN(dataset, n_planes=2, seed=0)
        >>> all(isinstance(k, str) for k in ann.buckets.keys())
        True
        >>> sum(len(v) for v in ann.buckets.values()) == len(dataset)
        True
        >>> ann._build_index()
        >>> sum(len(v) for v in ann.buckets.values()) == len(dataset)
        True
        """
        self.buckets.clear()
        for vec in self.dataset:
            self.buckets[self._hash_vector(vec)].append(vec)

    def query(self, query_vectors: np.ndarray) -> list[list[list[float] | float]]:
        """
        Find approximate nearest neighbor for query vector(s).

        :param query_vectors: ndarray of shape (m, n_features); a single
            vector of shape (n_features,) is treated as one query
        :return: list of [nearest_vector, distance], one entry per query
        :raises IndexError: if the dataset is empty

        >>> dataset = np.array([[0, 0], [1, 1], [2, 2], [10, 10]])
        >>> ann = ANN(dataset, n_planes=4, seed=0)
        >>> ann.query(np.array([[0, 1]]))  # doctest: +NORMALIZE_WHITESPACE
        [[[0, 0], 1.0]]
        >>> ann.query(np.array([2, 2]))
        [[[2, 2], 0.0]]
        """
        queries = np.asarray(query_vectors)
        if queries.ndim == 1 and queries.size:
            queries = queries[np.newaxis, :]

        results: list[list[list[float] | float]] = []
        for vec in queries:
            # Use .get(): indexing the defaultdict would insert an empty
            # bucket into the index for every query that misses.
            candidates = self.buckets.get(self._hash_vector(vec))

            if not candidates:  # fallback: search entire dataset
                if len(self.dataset) == 0:
                    # Same exception type that indexing an empty dataset raises.
                    raise IndexError("cannot query an empty dataset")
                candidates = self.dataset

            # Distances are computed in float64 so that narrow integer
            # dtypes cannot overflow; argmin keeps the first of equal minima.
            diffs = np.asarray(candidates, dtype=np.float64) - vec.astype(np.float64)
            distances = np.linalg.norm(diffs, axis=1)
            best = int(np.argmin(distances))
            results.append([candidates[best].tolist(), float(distances[best])])
        return results
113+
114+
115+
if __name__ == "__main__":
    # Run the examples embedded in the docstrings when executed as a script.
    from doctest import testmod

    testmod()
118+

0 commit comments

Comments
 (0)