In [59]:
import heapq
import time
from queue import PriorityQueue

import numpy as np

# Number of random sample vectors to index (passed to HNSW as sample_nr).
number_of_samples = 10000
# Passed to HNSW as n_max: the per-node neighbor cap used when pruning and the
# base of the logarithm in the level distribution.
N_max = 16
# Dimensionality of every sample vector and of the query vector.
d = 256

class Node:
    """A graph vertex identified by ``node_id`` with an ordered neighbor list.

    ``neighbors`` holds neighbor ids (not Node objects) in insertion order and
    never contains the same id twice.
    """

    def __init__(self, node_id):
        self.neighbors = []
        self.node_id = node_id

    def get_neighbors(self):
        """Return the live neighbor-id list (not a copy)."""
        return self.neighbors

    def add_neighbor(self, neighbor_id):
        """Append ``neighbor_id`` unless it is already a neighbor."""
        if neighbor_id in self.neighbors:
            return
        self.neighbors.append(neighbor_id)

    def remove_neighbor(self, neighbor_id):
        """Drop ``neighbor_id`` from the neighbor list; no-op when absent."""
        if neighbor_id not in self.neighbors:
            return
        self.neighbors.remove(neighbor_id)


class Edge:
    """A directed edge from ``start_id`` to ``end_id``.

    Self loops are rejected with an ``AssertionError``.
    """

    def __init__(self, start_id, end_id):
        # Guard first so that no attribute is set on an invalid edge.
        assert start_id != end_id, "No self loops"
        self.start_id, self.end_id = start_id, end_id


class Graph:
    """Directed graph: a ``node_id -> node`` mapping plus a list of edges.

    Adjacency is stored on the nodes themselves (via ``add_neighbor``); the
    ``edges`` list is an append-only record of every edge passed to
    ``add_edge``.
    """

    def __init__(self, nodes=None, edges=None):
        # A caller-supplied edge list is kept by reference, not copied, and
        # is not replayed onto the nodes' neighbor lists.
        self.edges = [] if edges is None else edges
        self.nodes = {}
        for node in nodes or ():
            self.nodes[node.node_id] = node

    def add_node(self, node):
        """Register ``node`` under its ``node_id``, replacing any previous one."""
        self.nodes[node.node_id] = node

    def add_edge(self, edge):
        """Record ``edge``; if its start node is registered, link it to the end id."""
        self.edges.append(edge)
        start = self.nodes.get(edge.start_id)
        if start is not None:
            start.add_neighbor(edge.end_id)

    def add_bi_directional_edge(self, start_id, end_id):
        """Add the two directed edges ``start -> end`` and ``end -> start``."""
        forward = Edge(start_id, end_id)
        backward = Edge(end_id, start_id)
        self.add_edge(forward)
        self.add_edge(backward)

    def get_node(self, node_id):
        """Return the node registered under ``node_id``, or ``None``."""
        return self.nodes.get(node_id)


class HNSW:
    """Hierarchical Navigable Small World index over random sample vectors.

    The index owns ``sample_nr`` random vectors (``self.samples``), assigns
    each one a maximum level (``self.max_levels``) and keeps one ``Graph`` per
    level (``self.level_graphs``). A point with maximum level ``L`` is present
    in the graphs of levels ``0..L``.

    Relies on the module-level ``Graph`` and ``Node`` classes and, when ``dim``
    is not given, on the module-level constant ``d``.
    """

    def __init__(self, sample_nr, n_max, ef_construction=200, dim=None):
        """Sample the vectors, draw their levels and create the empty graphs.

        Args:
            sample_nr: number of random vectors to generate (>= 1).
            n_max: maximum neighbors kept per node; also the base of the
                logarithm in the level distribution, hence must be >= 2.
            ef_construction: size of the candidate set kept by beam search
                while building (>= 1).
            dim: vector dimensionality; defaults to the module constant ``d``.

        Raises:
            ValueError: if ``sample_nr``, ``n_max`` or ``ef_construction`` is
                out of range.
        """
        if sample_nr < 1:
            raise ValueError("sample_nr must be at least 1")
        if n_max < 2:
            # log(1) == 0 would make the level multiplier infinite.
            raise ValueError("n_max must be at least 2")
        if ef_construction < 1:
            raise ValueError("ef_construction must be at least 1")
        if dim is None:
            dim = d

        self.samples = np.random.random((sample_nr, dim))
        self.sample_nr = sample_nr
        self.n_max = n_max
        self.ef_construction = ef_construction
        self.entry_point_id = None
        self.initialize_graph()

    def draw_levels(self):
        """Draw one maximum level per sample as an int32 array.

        Levels are ``floor(-ln(u) / ln(n_max))`` for uniform ``u``.
        """
        seeds = np.random.uniform(low=0, high=1, size=self.sample_nr)
        m_L = 1 / np.log(self.n_max)
        # uniform() samples [0, 1); flipping to (0, 1] keeps log() finite.
        return np.floor(-np.log(1.0 - seeds) * m_L).astype(np.int32)

    def setup_level_graphs(self):
        """Create one empty graph for every level from 0 to the highest drawn."""
        self.level_graphs = [
            Graph([], []) for _ in range(self.max_levels.max() + 1)
        ]

    def distance(self, id1, id2):
        """Return the Euclidean distance between two stored samples."""
        return np.linalg.norm(self.samples[id1] - self.samples[id2])

    def distance_vector(self, vector, id2):
        """Return the Euclidean distance between ``vector`` and a stored sample."""
        return np.linalg.norm(vector - self.samples[id2])

    def _greedy_descent(self, dist_to, entry_point_id, stop_at, stats=None):
        """Shared greedy walk behind ``greedy_search`` and ``greedy_search_vector``.

        ``dist_to(node_id)`` gives the distance from the query to a node.
        Starting at the entry point's top level and descending to ``stop_at``
        (inclusive), repeatedly move to the nearest neighbor while that
        strictly reduces the distance to the query.
        """
        entry_point_max_level = self.max_levels[entry_point_id]
        assert stop_at <= entry_point_max_level, "Entry point started below stop_at level"

        closest_match = entry_point_id
        min_distance = dist_to(closest_match)
        if stats is not None:
            # The entry-point distance is a real computation: count it.
            stats['distance_computations'] += 1

        for level in range(entry_point_max_level, stop_at - 1, -1):
            if stats is not None:
                stats['levels_visited'].append(level)

            graph = self.level_graphs[level]
            while True:
                node = graph.get_node(closest_match)
                if node is None:
                    break

                neighbors = node.get_neighbors()
                if len(neighbors) == 0:
                    break

                if stats is not None:
                    stats['distance_computations'] += len(neighbors)

                best_id, best_dist = min(
                    ((n_id, dist_to(n_id)) for n_id in neighbors),
                    key=lambda pair: pair[1],
                )
                if best_dist < min_distance:
                    closest_match, min_distance = best_id, best_dist
                else:
                    # Local minimum on this level; continue one level down.
                    break

        return closest_match

    def greedy_search(self, query_id, entry_point_id, stop_at):
        """Greedily descend towards stored sample ``query_id``.

        Returns the id of the node where the walk ends on level ``stop_at``.
        """
        return self._greedy_descent(
            lambda other_id: self.distance(query_id, other_id),
            entry_point_id,
            stop_at,
        )

    def greedy_search_vector(self, query_vector, entry_point_id, stop_at, stats=None):
        """Greedily descend towards an arbitrary ``query_vector``.

        When ``stats`` is given, visited levels are appended to
        ``stats['levels_visited']`` and every distance evaluation (including
        the one for the entry point) is added to
        ``stats['distance_computations']``.
        """
        return self._greedy_descent(
            lambda other_id: self.distance_vector(query_vector, other_id),
            entry_point_id,
            stop_at,
            stats=stats,
        )

    def _beam_search(self, dist_to, entry_point_ids, level, ef=None, top_k=None, stats=None):
        """Shared best-first search behind ``beam_search`` and ``beam_search_vector``.

        Keeps up to ``ef`` closest nodes seen so far and returns the ``top_k``
        closest of them as ``(node_id, distance)`` pairs, nearest first. Nodes
        at equal distance keep their discovery order.
        """
        ef = self.ef_construction if ef is None else ef
        ef = max(ef, 1)  # an empty result set would make the search meaningless
        top_k = self.n_max if top_k is None else top_k
        if not isinstance(entry_point_ids, list):
            entry_point_ids = [entry_point_ids]

        graph = self.level_graphs[level]
        visited = set()
        candidates = []  # min-heap of (distance, node_id): next node to expand
        # Heap of (-distance, -discovery_index, node_id): results[0] is the
        # furthest kept node, and among ties the most recently discovered one,
        # which is the one that gets evicted first.
        results = []
        discovered = 0

        for ep_id in entry_point_ids:
            dist = dist_to(ep_id)
            if stats is not None:
                stats['distance_computations'] += 1
            visited.add(ep_id)
            heapq.heappush(candidates, (dist, ep_id))
            heapq.heappush(results, (-dist, -discovered, ep_id))
            discovered += 1

        while candidates:
            current_dist, current_id = heapq.heappop(candidates)

            # Once the result set is full, a candidate further away than the
            # worst kept result cannot improve it, nor can any later one.
            if len(results) >= ef and current_dist > -results[0][0]:
                break

            node = graph.get_node(current_id)
            if node is None:
                continue

            for neighbor_id in node.get_neighbors():
                if neighbor_id in visited:
                    continue
                visited.add(neighbor_id)
                dist = dist_to(neighbor_id)
                if stats is not None:
                    stats['distance_computations'] += 1

                if len(results) >= ef and not dist < -results[0][0]:
                    continue

                heapq.heappush(candidates, (dist, neighbor_id))
                heapq.heappush(results, (-dist, -discovered, neighbor_id))
                discovered += 1
                while len(results) > ef:
                    heapq.heappop(results)

        # Descending (-distance, -index) is ascending (distance, index).
        ordered = sorted(results, reverse=True)
        return [(node_id, -neg_dist) for neg_dist, _, node_id in ordered[:top_k]]

    def beam_search(self, query_id, entry_point_ids, level, ef=None, top_k=None):
        """Best-first search on ``level`` for neighbors of sample ``query_id``.

        Args:
            query_id: id of the stored sample used as the query.
            entry_point_ids: a list of start ids, or a single id.
            level: index into ``self.level_graphs``.
            ef: candidate-set size; defaults to ``self.ef_construction``.
            top_k: maximum number of results; defaults to ``self.n_max``.

        Returns:
            ``(node_id, distance)`` pairs sorted by increasing distance.
        """
        return self._beam_search(
            lambda other_id: self.distance(query_id, other_id),
            entry_point_ids,
            level,
            ef=ef,
            top_k=top_k,
        )

    def beam_search_vector(self, query_vector, entry_point_ids, level, stats=None,
                           ef=None, top_k=None):
        """Best-first search on ``level`` for neighbors of ``query_vector``.

        Same contract as ``beam_search``; when ``stats`` is given, every
        distance evaluation is added to ``stats['distance_computations']``.
        """
        return self._beam_search(
            lambda other_id: self.distance_vector(query_vector, other_id),
            entry_point_ids,
            level,
            ef=ef,
            top_k=top_k,
            stats=stats,
        )

    def add_point(self, point_id, entry_point_id):
        """Insert sample ``point_id`` into every level up to its maximum level.

        The point is linked bidirectionally to the neighbors found by beam
        search on each level from ``min(point level, entry level)`` down to 0;
        neighbors that end up with more than ``n_max`` links are pruned.
        """
        point_level = self.max_levels[point_id]
        for level in range(point_level + 1):
            self.level_graphs[level].add_node(Node(point_id))

        if point_id == entry_point_id:
            return

        entry_level = self.max_levels[entry_point_id]
        if point_level > entry_level:
            # Levels above the entry point contain no other node to link to.
            start_level = entry_level
            beam_entry_point_ids = [entry_point_id]
        else:
            start_level = point_level
            closest_id = self.greedy_search(point_id, entry_point_id, stop_at=start_level)
            beam_entry_point_ids = [closest_id]

        for level in range(start_level, -1, -1):
            graph = self.level_graphs[level]
            neighbors = self.beam_search(point_id, beam_entry_point_ids, level)

            for neighbor_id, _ in neighbors:
                if neighbor_id == point_id:
                    continue
                graph.add_bi_directional_edge(point_id, neighbor_id)
                if len(graph.get_node(neighbor_id).get_neighbors()) > self.n_max:
                    self.prune_connections(neighbor_id, level)

            # The neighbors found here seed the search one level below.
            beam_entry_point_ids = [found_id for found_id, _ in neighbors]

    def prune_connections(self, node_id, level):
        """Trim ``node_id`` on ``level`` to its ``n_max`` nearest neighbors.

        Each dropped link is removed in both directions. The ``edges`` list of
        the level graph is not touched.
        """
        graph = self.level_graphs[level]
        node = graph.get_node(node_id)
        neighbors = node.get_neighbors()

        if len(neighbors) <= self.n_max:
            return

        nearest_first = sorted(neighbors, key=lambda n_id: self.distance(node_id, n_id))
        kept_neighbors = set(nearest_first[:self.n_max])

        # Iterate over a snapshot: remove_neighbor mutates the live list, and
        # removing while iterating it would skip entries.
        for neighbor_id in list(neighbors):
            if neighbor_id in kept_neighbors:
                continue
            node.remove_neighbor(neighbor_id)
            neighbor_node = graph.get_node(neighbor_id)
            if neighbor_node is not None:
                neighbor_node.remove_neighbor(node_id)

    def build_index(self):
        """Insert all samples in id order, tracking the highest-level entry point."""
        self.entry_point_id = 0
        self.add_point(0, 0)

        for point_id in range(1, self.sample_nr):
            self.add_point(point_id, self.entry_point_id)

            if self.max_levels[point_id] > self.max_levels[self.entry_point_id]:
                self.entry_point_id = point_id

            if (point_id + 1) % 100 == 0:
                print(f"Inserted {point_id + 1}/{self.sample_nr} points")

        print(f"HNSW index built with {self.sample_nr} points")

    def search(self, query_id, k=10, ef_search=None):
        """Return up to ``k`` nearest neighbors of stored sample ``query_id``.

        Args:
            query_id: id of the stored sample used as the query.
            k: number of results wanted.
            ef_search: candidate-set size on level 0; defaults to
                ``max(ef_construction, k)``. Fewer than ``k`` results may come
                back when an explicit value is below ``k``.

        Returns:
            ``(node_id, distance)`` pairs sorted by increasing distance, or
            ``[]`` when the index has not been built.
        """
        if ef_search is None:
            ef_search = max(self.ef_construction, k)

        if self.entry_point_id is None:
            return []

        closest_id = self.greedy_search(query_id, self.entry_point_id, stop_at=0)
        # ef and k are passed explicitly, so index state is never mutated and
        # k is not capped by n_max.
        return self.beam_search(query_id, [closest_id], level=0, ef=ef_search, top_k=k)

    def search_vector(self, query_vector, k=10, ef_search=None, return_stats=False):
        """Return up to ``k`` nearest stored samples for ``query_vector``.

        Args:
            query_vector: query with the same dimensionality as the samples.
            k: number of results wanted.
            ef_search: candidate-set size on level 0; defaults to
                ``max(ef_construction, k)``.
            return_stats: when true, return ``(results, stats)`` instead of
                just ``results``.

        Returns:
            ``(node_id, distance)`` pairs sorted by increasing distance. With
            ``return_stats`` a dict is returned as well, holding
            ``'levels_visited'``, ``'distance_computations'`` and
            ``'entry_point_level'`` (``None`` when the index is not built).
        """
        if ef_search is None:
            ef_search = max(self.ef_construction, k)

        stats = {
            'levels_visited': [],
            'distance_computations': 0,
            'entry_point_level': None,
        }

        if self.entry_point_id is None:
            # Keep the return shape stable so callers can always unpack it.
            return ([], stats) if return_stats else []

        stats['entry_point_level'] = self.max_levels[self.entry_point_id]

        closest_id = self.greedy_search_vector(
            query_vector, self.entry_point_id, stop_at=0, stats=stats
        )
        results = self.beam_search_vector(
            query_vector, [closest_id], level=0, stats=stats, ef=ef_search, top_k=k
        )

        if return_stats:
            return results, stats
        return results

    def brute_force_search(self, query_vector, k=10):
        """Return the exact ``k`` nearest samples by scanning every sample."""
        distances = [
            (sample_id, self.distance_vector(query_vector, sample_id))
            for sample_id in range(self.sample_nr)
        ]
        distances.sort(key=lambda pair: pair[1])
        return distances[:k]

    def initialize_graph(self):
        """Draw the per-sample levels and allocate the per-level graphs."""
        self.max_levels = self.draw_levels()
        self.maximum_height = self.max_levels.max()
        self.setup_level_graphs()



# Build an index over number_of_samples random vectors; build_index prints
# its own progress every 100 inserted points.
hnsw = HNSW(number_of_samples, N_max)
print(f"Building HNSW index with {number_of_samples} points...")
hnsw.build_index()

# Generate a random query vector with the same dimensionality as the samples
query_vector = np.random.random(d)
k = 10

# HNSW search; return_stats=True also yields the stats dict printed below
hnsw_results, stats = hnsw.search_vector(query_vector, k=k, return_stats=True)

print(f"Distance computations: {stats['distance_computations']}")
print(f"Entry point level: {stats['entry_point_level']}")
print(f"\nTop {k} results:")
for i, (neighbor_id, dist) in enumerate(hnsw_results, 1):
    print(f"  {i}. Point {neighbor_id}: distance = {dist:.4f}")

# Brute force search: exact top-k, for comparison with the HNSW results
print("BRUTE FORCE SEARCH")
brute_results = hnsw.brute_force_search(query_vector, k=k)

# brute_force_search evaluates one distance per stored sample
print(f"Distance computations: {number_of_samples}")
print(f"\nTop {k} results:")
for i, (neighbor_id, dist) in enumerate(brute_results, 1):
    print(f"  {i}. Point {neighbor_id}: distance = {dist:.4f}")

# Comparison: how many times fewer distance evaluations HNSW needed
print("COMPARISON")
print(f"Distance computation reduction: {number_of_samples/stats['distance_computations']:.2f}x")

Building HNSW index with 10000 points...
Inserted 100/10000 points
Inserted 200/10000 points
Inserted 300/10000 points
Inserted 400/10000 points
Inserted 500/10000 points
Inserted 600/10000 points
Inserted 700/10000 points
Inserted 800/10000 points
Inserted 900/10000 points
Inserted 1000/10000 points
Inserted 1100/10000 points
Inserted 1200/10000 points
Inserted 1300/10000 points
Inserted 1400/10000 points
Inserted 1500/10000 points
Inserted 1600/10000 points
Inserted 1700/10000 points
Inserted 1800/10000 points
Inserted 1900/10000 points
Inserted 2000/10000 points
Inserted 2100/10000 points
Inserted 2200/10000 points
Inserted 2300/10000 points
Inserted 2400/10000 points
Inserted 2500/10000 points
Inserted 2600/10000 points
Inserted 2700/10000 points
Inserted 2800/10000 points
Inserted 2900/10000 points
Inserted 3000/10000 points
Inserted 3100/10000 points
Inserted 3200/10000 points
Inserted 3300/10000 points
Inserted 3400/10000 points
Inserted 3500/10000 points
Inserted 3600/10000 poi