In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import time
import random

In [2]:
def distance(a,b):
    return np.sqrt(np.sum((a - b)**2))                      

In [3]:
def generate_dataset(N, D):
    return np.random.rand(N, D)

def generate_query(D):
    return np.random.rand(D)

# Naive

In [4]:
def naive_knn(query, dataset, K):
    distances = [distance(query, point) for point in dataset]
    sorted_indices = np.argsort(distances)

    k_sorted_points =  np.array([dataset[i] for i in sorted_indices[:K]])
    k_sorted_distances = np.array(distances)[sorted_indices[:K]]

    return sorted_indices[:K], k_sorted_points,  k_sorted_distances

In [5]:
N,D,K= 1000, 5, 10
ds= generate_dataset(N,D)
q= generate_query(D)
if D==2:
    plt.scatter(ds[:, 0], ds[:, 1])
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.title('Scatter plot of 2D Dataset')
    plt.show()

In [6]:
indices, points, distances=naive_knn(q, ds, K)
print(points.shape)
print(indices)
print(points)
print(distances)

(10, 5)
[508 576 743 972  44 731 250 148 124 128]
[[0.18880225 0.89881759 0.48061463 0.02509819 0.29025713]
 [0.23831656 0.89943075 0.40399306 0.34239942 0.38415212]
 [0.0030833  0.97159352 0.50304068 0.31152268 0.32471155]
 [0.27152427 0.95949309 0.22054768 0.16690242 0.4779965 ]
 [0.32895058 0.85367267 0.43239671 0.18279736 0.50334501]
 [0.13749313 0.85925127 0.16984635 0.25746505 0.4861598 ]
 [0.24316426 0.95432088 0.43741774 0.22182947 0.14731658]
 [0.21392793 0.80113744 0.55656834 0.12313497 0.60877014]
 [0.16051373 0.93612223 0.68887516 0.24799755 0.3066494 ]
 [0.01215896 0.93522769 0.30348597 0.0258052  0.67729439]]
[0.1857689  0.24937569 0.25017533 0.25543898 0.26914711 0.28540006
 0.28814683 0.32720837 0.333737   0.34049964]


# KD Tree

In [7]:
class KDTree:
    def __init__(self, data, leaf_size):
        self.data = np.hstack((data, np.arange(len(data)).reshape(-1, 1))) ##stacking the indices
        self.leaf_size = leaf_size
        self.tree = self.build_kdtree(self.data)
    
    def build_kdtree(self, data, depth=0):
        if len(data) <= self.leaf_size:
            return data  
        
        axis = depth % (data.shape[1] - 1)  # Alternate splitting axis, ignore index column
        sorted_data = data[data[:, axis].argsort()]
        median_index = len(sorted_data) // 2
        left = self.build_kdtree(sorted_data[:median_index], depth + 1)
        right = self.build_kdtree(sorted_data[median_index + 1:], depth + 1)
        
        return (sorted_data[median_index], left, right)
    
    def query(self, query, K):
        indices, distances = self._query(self.tree, query, K, depth=0) 
        points = self.data[indices, :-1]
        return indices, points, distances
    
    def _query(self, node, query, K, depth):
    
        if isinstance(node, np.ndarray): #leaf node is an array
            points = node[:, :-1] 
            original_indices = node[:, -1].astype(int)  
            distances = np.array([distance(query, point) for point in points])
            sorted_indices = np.argsort(distances)
            nearest_indices = original_indices[sorted_indices[:K]]
            nearest_distances = distances[sorted_indices[:K]]
            return nearest_indices, nearest_distances
        
        if isinstance(node, tuple) and len(node) == 3: # non-leaf node is a tuple (median point, left subtree, right subtree)
            median, left, right = node
            axis = depth % (query.shape[0])
            
            if query[axis] < median[axis]:
                primary, other = left, right
            else:
                primary, other = right, left
            
            # Recursively search the primary side
            indices, distances = self._query(primary, query, K, depth + 1)
            
            
            if len(indices) < K or abs(query[axis] - median[axis]) < max(distances): # if the other tree's data also need to be combined
                other_indices, other_distances = self._query(other, query, K, depth + 1)
                combined_indices = np.concatenate([indices, other_indices])
                combined_distances = np.concatenate([distances, other_distances])
                sorted_combined = np.argsort(combined_distances)
                indices = combined_indices[sorted_combined][:K]
                distances = combined_distances[sorted_combined][:K]
            
            return indices, distances
        
       
        return np.array([]), np.array([])




In [8]:

tree = KDTree(ds, leaf_size=20)


indices, points, distances = tree.query(q, K)

print("Indices of nearest neighbors:", indices)
print("Nearest points:", points)
print("Distances to nearest points:", distances)


Indices of nearest neighbors: [508 576 743 972 731 250 148 124 128 854]
Nearest points: [[0.18880225 0.89881759 0.48061463 0.02509819 0.29025713]
 [0.23831656 0.89943075 0.40399306 0.34239942 0.38415212]
 [0.0030833  0.97159352 0.50304068 0.31152268 0.32471155]
 [0.27152427 0.95949309 0.22054768 0.16690242 0.4779965 ]
 [0.13749313 0.85925127 0.16984635 0.25746505 0.4861598 ]
 [0.24316426 0.95432088 0.43741774 0.22182947 0.14731658]
 [0.21392793 0.80113744 0.55656834 0.12313497 0.60877014]
 [0.16051373 0.93612223 0.68887516 0.24799755 0.3066494 ]
 [0.01215896 0.93522769 0.30348597 0.0258052  0.67729439]
 [0.08493355 0.9788638  0.21929512 0.3956276  0.26531005]]
Distances to nearest points: [0.1857689  0.24937569 0.25017533 0.25543898 0.28540006 0.28814683
 0.32720837 0.333737   0.34049964 0.34055163]


# LSH

In [9]:
class LSH:
    def __init__(self, dataset, num_hashes=5):
        self.dataset = dataset
        self.num_hashes = num_hashes
        self.dataset_augmented = np.hstack((self.dataset, np.ones((self.dataset.shape[0], 1))))  # Adding bias term
        self.hashes = self._generate_hashes()  # Hashes with bias term included

    def _generate_hashes(self):
        return np.random.randn(self.num_hashes, self.dataset.shape[1] + 1)

    def _hash_(self, point):
        point_augmented = np.append(point, 1)  
        return np.sign(np.dot(self.hashes, point_augmented))  

    def query(self, query, K):
        query_hash = self._hash_(query)

        hash_buckets = {}
        for index, point in enumerate(self.dataset):
            point_hash = tuple(self._hash_(point))  # Made it into a tuple so it can be used as a dictionary key
            if point_hash not in hash_buckets:
                hash_buckets[point_hash] = []
            hash_buckets[point_hash].append(index)

        query_hash_tuple = tuple(query_hash)

        nearest_neighbors = self.find_k(query, hash_buckets.get(query_hash_tuple, []), K) # looking into same bucket

        if len(nearest_neighbors) < K: #neighbouring buckets
            for i in range(self.num_hashes):
                if len(nearest_neighbors) >= K:
                    break
 
                modified_query_hash = list(query_hash_tuple)
                modified_query_hash[i] = 1 - modified_query_hash[i]  # Flipping the ith hash value
                modified_query_hash_tuple = tuple(modified_query_hash)
                
                if modified_query_hash_tuple in hash_buckets:
                    remaining_neighbors = self.find_k(query, hash_buckets[modified_query_hash_tuple], K - len(nearest_neighbors))
                    nearest_neighbors.extend(remaining_neighbors)

        nearest_neighbors = sorted(nearest_neighbors, key=lambda x: x[1])[:K] 

        nearest_indices = [index for index, _ in nearest_neighbors]
        nearest_points = self.dataset[nearest_indices]
        nearest_distances = [dist for _, dist in nearest_neighbors]

        return nearest_indices, nearest_points, nearest_distances

    def find_k(self, query, bucket_indices, K):
        distances = []
        for index in bucket_indices:
            point = self.dataset[index]
            dist = distance(query, point) 
            distances.append((index, dist))

        return sorted(distances, key=lambda x: x[1])[:K] 


In [10]:

num_hashes = 15 

lsh = LSH(ds, num_hashes)
nearest_indices, nearest_points, nearest_distances = lsh.query(q, K)

print("Query Point:", q)
print(nearest_indices)
print(nearest_points)
print(nearest_distances)

Query Point: [0.11266692 0.9322325  0.38959672 0.13000921 0.38139635]
[508, 148, 583, 593, 999, 364, 761, 412]
[[0.18880225 0.89881759 0.48061463 0.02509819 0.29025713]
 [0.21392793 0.80113744 0.55656834 0.12313497 0.60877014]
 [0.33565034 0.83473074 0.57266212 0.10849671 0.55249725]
 [0.20840855 0.72291725 0.69173774 0.12780284 0.60568076]
 [0.4865881  0.95498749 0.72282343 0.04510985 0.58884273]
 [0.34295209 0.90922486 0.88340219 0.01973479 0.54705592]
 [0.32059099 0.70603833 0.72078127 0.11239619 0.76921046]
 [0.16313142 0.88322389 0.89265186 0.14603094 0.76827729]]
[0.18576890153757664, 0.32720837286017884, 0.34997062958318065, 0.4411087069082088, 0.549196484520224, 0.580523893594054, 0.595642026351071, 0.6387069317080944]


In [11]:
indices, points, distances=naive_knn(q, ds, K)
print(indices)
print(points)
print(distances)

[508 576 743 972  44 731 250 148 124 128]
[[0.18880225 0.89881759 0.48061463 0.02509819 0.29025713]
 [0.23831656 0.89943075 0.40399306 0.34239942 0.38415212]
 [0.0030833  0.97159352 0.50304068 0.31152268 0.32471155]
 [0.27152427 0.95949309 0.22054768 0.16690242 0.4779965 ]
 [0.32895058 0.85367267 0.43239671 0.18279736 0.50334501]
 [0.13749313 0.85925127 0.16984635 0.25746505 0.4861598 ]
 [0.24316426 0.95432088 0.43741774 0.22182947 0.14731658]
 [0.21392793 0.80113744 0.55656834 0.12313497 0.60877014]
 [0.16051373 0.93612223 0.68887516 0.24799755 0.3066494 ]
 [0.01215896 0.93522769 0.30348597 0.0258052  0.67729439]]
[0.1857689  0.24937569 0.25017533 0.25543898 0.26914711 0.28540006
 0.28814683 0.32720837 0.333737   0.34049964]
