Fast Exact Retrieval for Nearest-neighbor Search

In [2]:
!pip install sortedcontainers matplotlib numpy
import matplotlib.pyplot as plt
import numpy as np
import time
from sortedcontainers import SortedDict


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
## Create red black tree datastructure

class VectorNode:
    def __init__(self, vector):
        self.vector = vector
        self.left = None
        self.right = None
        self.parent = None
        self.color = 'red'

class VectorTree:
    def __init__(self):
        self.root = None

    def insert(self, vector):
        if self.root is None:
            self.root = VectorNode(vector)
            self.root.color = 'black'
        else:
            node = self._insert(vector, self.root)
            self._fix_tree(node)

    def _insert(self, vector, node):
        if self._euclidean_distance(vector, node.vector) < 0:
            if node.left is None:
                node.left = VectorNode(vector)
                node.left.parent = node
                return node.left
            else:
                return self._insert(vector, node.left)
        else:
            if node.right is None:
                node.right = VectorNode(vector)
                node.right.parent = node
                return node.right
            else:
                return self._insert(vector, node.right)

    def _fix_tree(self, node):
        while node.parent is not None and node.parent.color == 'red':
            if node.parent == node.parent.parent.left:
                uncle = node.parent.parent.right
                if uncle is not None and uncle.color == 'red':
                    node.parent.color = 'black'
                    uncle.color = 'black'
                    node.parent.parent.color = 'red'
                    node = node.parent.parent
                else:
                    if self._euclidean_distance(node.vector, node.parent.right.vector) < 0:
                        node = node.parent
                        self._rotate_left(node)
                    node.parent.color = 'black'
                    node.parent.parent.color = 'red'
                    self._rotate_right(node.parent.parent)
            else:
                uncle = node.parent.parent.left
                if uncle is not None and uncle.color == 'red':
                    node.parent.color = 'black'
                    uncle.color = 'black'
                    node.parent.parent.color = 'red'
                    node = node.parent.parent
                else:
                    if self._euclidean_distance(node.vector, node.parent.left.vector) > 0:
                        node = node.parent
                        self._rotate_right(node)
                    node.parent.color = 'black'
                    node.parent.parent.color = 'red'
                    self._rotate_left(node.parent.parent)
        self.root.color = 'black'

    def _rotate_left(self, node):
        right_child = node.right
        node.right = right_child.left
        if right_child.left is not None:
            right_child.left.parent = node
        right_child.parent = node.parent
        if node.parent is None:
            self.root = right_child
        elif node == node.parent.left:
            node.parent.left = right_child
        else:
            node.parent.right = right_child
        right_child.left = node
        node.parent = right_child

    def _rotate_right(self, node):
        left_child = node.left
        node.left = left_child.right
        if left_child.right is not None:
            left_child.right.parent = node
        left_child.parent = node.parent
        if node.parent is None:
            self.root = left_child
        elif node == node.parent.right:
            node.parent.right = left_child
        else:
            node.parent.left = left_child
        left_child.right = node
        node.parent = left_child

    def retrieve(self, vector, k):
        """
        Retrieve the k nearest neighbors to the given vector.
        """
        distances = SortedDict()
        self._retrieve(vector, k, self.root, distances)
        return distances

    def _retrieve(self, vector, k, node, distances):
        if node is None:
            return
        distance = self._euclidean_distance(vector, node.vector)
        if len(distances) < k or distance < distances.peekitem(-1)[0]:
            distances[distance] = node.vector
            if len(distances) > k:
                distances.popitem()
        if node.left is not None and self._euclidean_distance(vector, node.left.vector) >= -distances.peekitem(-1)[0]:
            self._retrieve(vector, k, node.left, distances)
        if node.right is not None and self._euclidean_distance(vector, node.right.vector) <= distances.peekitem(-1)[0]:
            self._retrieve(vector, k, node.right, distances)

    def _euclidean_distance(self, v1, v2):
        return ((v1 - v2) ** 2).sum() ** 0.5


In [4]:
## test

# create a list of num_vectors
num_vectors = [1, 10, 50, 100, 500, 1000]

# for each num_vector
for num_vector in num_vectors:
    # create a VectorTree object
    tree = VectorTree()
    # create an empty list to store retrieval times for this trial
    retrieval_times = []

    # set random vector to be last vector inserted
    selected_vector = None

    # insert num_vector random vectors
    for i in range(num_vector):
        vector = np.random.rand(10)
        tree.insert(vector)
        selected_vector = vector
    
    # create an empty list to store retrieval times for this selected vector
    retrieval_times_for_selected_vector = []
    # for each retrieval
    for retrieval in range(5):
        start_time = time.time()
        distances = tree.retrieve(selected_vector, 1)
        end_time = time.time()
        retrieval_time = end_time - start_time
        retrieval_times_for_selected_vector.append(retrieval_time)

    # calculate the average retrieval time across the retrievals for the selected vector and append it to retrieval_times
    average_retrieval_time_for_selected_vector = sum(retrieval_times_for_selected_vector) / len(retrieval_times_for_selected_vector)
    retrieval_times.append(average_retrieval_time_for_selected_vector)

# plot the average retrieval times for this num_vector
plt.plot(range(1, 7), retrieval_times, label=str(num_vector))
    
# set the x-axis label and legend
plt.xlabel('Trial')
plt.ylabel('Average retrieval time')
plt.legend(title='Number of vectors')
plt.show()


AttributeError: 'NoneType' object has no attribute 'vector'

: 