In [28]:
import os
import cupy as cp
import numpy as np
import torch
import math
import time
import random
from collections import defaultdict
from tqdm import tqdm

In [29]:
print(torch.cuda.is_available())  # cuda test
print(torch.__version__)          # version test

True
2.5.1+cu121


HNSW

In [30]:
class HNSW:
    def __init__(self, max_layers=3, ef_construction=80, M=12, ef_search=50, device='cuda'):
        self.max_layers = max_layers
        self.ef_construction = ef_construction
        self.M = M
        self.ef_search = ef_search
        self.device = device
        self.enter_point = None
        self.graphs = [defaultdict(list) for _ in range(max_layers)]
        self.vectors = None
        self.cluster_centers = None
        self.cluster_assignments = None
    
    def _select_neighbors_simple(self, q, candidates, layer, K):
        if len(candidates) <= K:
            return candidates
        candidate_vectors = self.vectors[torch.tensor(candidates, device=self.device)]
        # distances = torch.norm(candidate_vectors - q, dim=1)
        distances = torch.cdist(q.unsqueeze(0), candidate_vectors).squeeze(0)  # faster than norm
        _, indices = torch.topk(distances, K, largest=False)
        return [candidates[i] for i in indices.cpu().numpy()]
    
    def _search_layer(self, q, ep, ef, layer):
        visited = set(ep)
        candidates = ep.copy()
        heap = []
        
        for node in ep:
            dist = torch.norm(self.vectors[node] - q)
            heap.append((dist, node))
        
        heap.sort()
        result = ep.copy()
        
        while candidates:
            current = candidates.pop(0)
            if len(result) >= ef and heap[0][0] > heap[ef-1][0]:
                break
                
            for neighbor in self.graphs[layer][current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    dist = torch.norm(self.vectors[neighbor] - q)
                    if len(result) < ef or dist < heap[-1][0]:
                        result.append(neighbor)
                        candidates.append(neighbor)
                        heap.append((dist, neighbor))
                        heap.sort()
                        if len(heap) > ef:
                            heap = heap[:ef]
        
        return result[:ef]
    
    def _get_random_layer(self):
        return min(int(-math.log(random.random()) * 1.0), self.max_layers - 1)
    
    def build_index(self, vectors):
        start_time = time.time()
        self.vectors = torch.tensor(vectors, device=self.device)
        N, D = self.vectors.shape
        self.enter_point = random.randint(0, N-1)
        
        for node in range(N):
            self.graphs[0][node] = []
        
        for node in range(N):
            l = self._get_random_layer()
            ep = [self.enter_point]
            
            for layer in reversed(range(l+1)):
                ep = self._search_layer(self.vectors[node], ep, self.ef_construction, layer)
                neighbors = self._select_neighbors_simple(self.vectors[node], ep, layer, self.M)
                
                for neighbor in neighbors:
                    if node not in self.graphs[layer][neighbor] and neighbor != node:
                        self.graphs[layer][neighbor].append(node)
                        self.graphs[layer][node].append(neighbor)
                
                for n in neighbors:
                    if len(self.graphs[layer][n]) > self.M * 2:
                        self.graphs[layer][n] = self._select_neighbors_simple(
                            self.vectors[n], self.graphs[layer][n], layer, self.M)
            
            if l > 0:
                self.enter_point = node
        
        build_time = time.time() - start_time
        return build_time
    
    def kmeans(self, K, max_iters=30):
        start_time = time.time()
        N, D = self.vectors.shape
        centroids = self.vectors[torch.randperm(N)[:K]].clone()
        
        for _ in range(max_iters):
            distances = torch.cdist(self.vectors, centroids)
            cluster_ids = torch.argmin(distances, dim=1)
            
            new_centroids = torch.zeros_like(centroids)
            counts = torch.zeros(K, device=self.device)
            
            for k in range(K):
                mask = cluster_ids == k
                if mask.any():
                    new_centroids[k] = self.vectors[mask].mean(dim=0)
                    counts[k] = mask.sum()
            
            empty_clusters = counts == 0
            if empty_clusters.any():
                new_centroids[empty_clusters] = self.vectors[torch.randperm(N)[:empty_clusters.sum()]]
            
            if torch.allclose(centroids, new_centroids, rtol=1e-4):
                break
                
            centroids = new_centroids
        
        self.cluster_centers = centroids
        self.cluster_assignments = cluster_ids
        return time.time() - start_time
    
    def search(self, query, K):
        start_time = time.time()
        query = torch.tensor(query, device=self.device)
        ep = [self.enter_point]
        
        for layer in reversed(range(self.max_layers)):
            ep = self._search_layer(query, ep, self.ef_search, layer)
        
        candidates = ep
        candidate_vectors = self.vectors[torch.tensor(candidates, device=self.device)]
        distances = torch.norm(candidate_vectors - query, dim=1)
        _, indices = torch.topk(distances, min(K, len(candidates)), largest=False)
        
        search_time = time.time() - start_time
        return [candidates[i] for i in indices.cpu().numpy()], search_time
    
    def search_with_clusters(self, query, K, cluster_K=2):
        start_time = time.time()
        if self.cluster_centers is None:
            return self.search(query, K)
        
        query = torch.tensor(query, device=self.device)
        cluster_distances = torch.norm(self.cluster_centers - query, dim=1)
        _, closest_clusters = torch.topk(cluster_distances, cluster_K, largest=False)
        
        candidates = []
        for cluster_id in closest_clusters:
            mask = self.cluster_assignments == cluster_id
            candidates.extend(torch.where(mask)[0].tolist())
        
        candidate_vectors = self.vectors[torch.tensor(candidates, device=self.device)]
        distances = torch.norm(candidate_vectors - query, dim=1)
        _, indices = torch.topk(distances, min(K, len(candidates)), largest=False)
        
        search_time = time.time() - start_time
        return [candidates[i] for i in indices.cpu().numpy()], search_time


In [31]:
def compare_ann_performance(A, X, K, num_queries=10):
    """
    Compare performance of GPU/CPU HNSW and exact KNN
    
    Args:
        A: Dataset (N x D numpy array)
        X: Query vector (D,)
        K: Number of nearest neighbors to return
        num_queries: Number of test queries
    
    Returns:
        dict: Contains build time, search times and recall for each method
    """
    # Convert to numpy array for consistency
    if isinstance(A, torch.Tensor):
        A = A.cpu().numpy()
    if isinstance(X, torch.Tensor):
        X = X.cpu().numpy()
    
    print(f"\nStarting performance comparison (Dataset: {A.shape[0]} vectors of dim {A.shape[1]}, K={K}, queries={num_queries})")
    results = {
        'hnsw_gpu': {'build_time': 0, 'search_times': [], 'recall': 0},
        'hnsw_cpu': {'build_time': 0, 'search_times': [], 'recall': 0},
        'exact_knn': {'build_time': 0, 'search_times': [], 'recall': 1.0}
    }
    
    # 1. Test GPU HNSW
    print("\n[1/3] Building GPU HNSW index...")
    torch.cuda.synchronize()
    hnsw_gpu = HNSW(device='cuda')
    build_time = hnsw_gpu.build_index(A)
    results['hnsw_gpu']['build_time'] = build_time
    print(f"GPU HNSW index built in: {build_time:.4f}s")
    
    # 2. Test CPU HNSW
    print("\n[2/3] Building CPU HNSW index...")
    hnsw_cpu = HNSW(device='cpu')
    build_time = hnsw_cpu.build_index(A)
    results['hnsw_cpu']['build_time'] = build_time
    print(f"CPU HNSW index built in: {build_time:.4f}s")
    
    # 3. Exact KNN (no build time, only search)
    print("\n[3/3] Preparing exact KNN tests...")
    A_tensor = torch.tensor(A, device='cuda')
    X_tensor = torch.tensor(X, device='cuda')
    
    # Run multiple queries for averaging
    exact_results = []
    print("\nRunning exact KNN queries...")
    for _ in tqdm(range(num_queries), desc="Exact KNN Progress"):
        start_time = time.time()
        distances = torch.norm(A_tensor - X_tensor, dim=1)
        _, indices = torch.topk(distances, K, largest=False)
        search_time = time.time() - start_time
        results['exact_knn']['search_times'].append(search_time)
        exact_results.append(indices.cpu().numpy())
    
    # Calculate recall rates
    print("\nRunning HNSW queries and calculating recall...")
    for i in tqdm(range(num_queries), desc="Overall Progress"):
        # GPU HNSW query
        gpu_result, search_time = hnsw_gpu.search(X, K)
        results['hnsw_gpu']['search_times'].append(search_time)
        results['hnsw_gpu']['recall'] += calculate_recall(gpu_result, exact_results[0], K)
        
        # CPU HNSW query
        cpu_result, search_time = hnsw_cpu.search(X, K)
        results['hnsw_cpu']['search_times'].append(search_time)
        results['hnsw_cpu']['recall'] += calculate_recall(cpu_result, exact_results[0], K)
        
        # Show interim progress
        if (i+1) % max(1, num_queries//5) == 0 or (i+1) == num_queries:
            current_gpu_time = sum(results['hnsw_gpu']['search_times']) / len(results['hnsw_gpu']['search_times'])
            current_cpu_time = sum(results['hnsw_cpu']['search_times']) / len(results['hnsw_cpu']['search_times'])
            current_exact_time = sum(results['exact_knn']['search_times']) / len(results['exact_knn']['search_times'])
            print(f"\n--- Progress {i+1}/{num_queries} ---")
            print(f"Current avg query times: GPU HNSW={current_gpu_time:.6f}s, CPU HNSW={current_cpu_time:.6f}s, Exact KNN={current_exact_time:.6f}s")
    
    # Calculate averages
    for method in results:
        if method != 'exact_knn':
            results[method]['recall'] /= num_queries
        results[method]['avg_search_time'] = sum(results[method]['search_times']) / num_queries
    
    return results


def calculate_recall(ann_result, knn_result, K):
    ann_set = set(ann_result[:K])
    knn_set = set(knn_result[:K])
    return len(ann_set & knn_set) / K

In [32]:
# Generate random data
N = 100  # Number of vectors
D = 32    # Dimension
K = 3      # Number of clusters/top-K
A = np.random.randn(N, D).astype(np.float32)
X = np.random.randn(D).astype(np.float32)
num_queries = 1  # Number of queries
# Run comparison
results = compare_ann_performance(A, X, K, num_queries=num_queries)
    
# Print final results
print("\n=== Final Performance Results ===")
print(f"Dataset: {N} vectors of dim {D} | K={K} | Test queries={num_queries}")
print("\n=== Performance Summary ===")
print("{:<15} {:<15} {:<15} {:<15}".format(
    "Method", "Build Time(s)", "Avg Query Time(s)", "Recall"))
    
for method, data in results.items():
    print("{:<15} {:<15.6f} {:<15.6f} {:<15.6f}".format(
        method,
        data['build_time'],
        data['avg_search_time'],
        data['recall']))


Starting performance comparison (Dataset: 100 vectors of dim 32, K=3, queries=1)

[1/3] Building GPU HNSW index...
GPU HNSW index built in: 21.6745s

[2/3] Building CPU HNSW index...
CPU HNSW index built in: 0.6772s

[3/3] Preparing exact KNN tests...

Running exact KNN queries...


Exact KNN Progress: 100%|██████████| 1/1 [00:00<?, ?it/s]



Running HNSW queries and calculating recall...


Overall Progress: 100%|██████████| 1/1 [00:00<00:00,  2.37it/s]


--- Progress 1/1 ---
Current avg query times: GPU HNSW=0.409381s, CPU HNSW=0.012204s, Exact KNN=0.000000s

=== Final Performance Results ===
Dataset: 100 vectors of dim 32 | K=3 | Test queries=1

=== Performance Summary ===
Method          Build Time(s)   Avg Query Time(s) Recall         
hnsw_gpu        21.674493       0.409381        0.666667       
hnsw_cpu        0.677247        0.012204        0.666667       
exact_knn       0.000000        0.000000        1.000000       



