# K Nearest Neighbours

In [2]:
import numpy as np
import scipy as sp

%run basic_model.ipynb

In [4]:
import heapq
from collections import namedtuple

class BallTree(BasicModel):
    """Ball tree over the rows of ``X`` for k-nearest-neighbour queries.

    Every internal node stores a pivot observation and the radius of the
    ball (centred on the pivot) that contains all observations of the node.
    The remaining observations are split around the pivot's value along the
    dimension with the largest standard deviation.  Nodes holding at most
    ``leaf_size`` observations are leaves and keep their observations
    unsplit.

    NOTE(review): ``check_and_transform_X`` and ``check_value_type_and_set``
    come from ``BasicModel``, which is not visible here.  This class assumes
    the former returns a 2-D NumPy array and the latter stores the value as
    an attribute with the given name (``self.leaf_size``, ``self.metric``)
    -- confirm against ``BasicModel``.
    """

    class Node:
        """A single ball of the tree.

        Attributes are plain data:
        ``observation_indexes`` -- row indexes of ``X`` covered by the node;
        ``pivot_index`` -- row index of the ball centre, ``None`` for leaves;
        ``radius`` -- largest distance from the pivot to a covered row,
        ``None`` for leaves;
        ``left`` / ``right`` -- child nodes, ``None`` for leaves.
        """

        def __init__(
            self,
            observation_indexes,
            pivot_index=None,
            radius=None,
            left=None,
            right=None
        ):
            self.observation_indexes = observation_indexes
            self.pivot_index = pivot_index
            self.radius = radius
            self.left = left
            self.right = right

        def is_leaf(self):
            """Return ``True`` when the node has no pivot (i.e. is a leaf)."""
            # Compare with None explicitly: row index 0 is a valid pivot.
            return self.pivot_index is None

    # Heap entry used by the search.  ``dist_to_target`` holds the NEGATED
    # distance so that ``heapq`` (a min-heap) keeps the farthest of the
    # current candidates at position 0.
    PrioritizedPointIndex = namedtuple(
        'PrioritizedPointIndex',
        ['dist_to_target', 'point_index']
    )

    def __init__(
        self,
        X,
        leaf_size=40,
        metric='minkowski'
    ):
        """Build the tree.

        Parameters
        ----------
        X : observations, one per row.
        leaf_size : maximum number of observations kept in a leaf.
        metric : metric name forwarded to ``scipy.spatial.distance.pdist``.
        """
        self.X = super().check_and_transform_X(X)

        super().check_value_type_and_set(
            'leaf_size',
            leaf_size,
            int
        )

        super().check_value_type_and_set(
            'metric',
            metric,
            str
        )

        # Use the transformed array: the raw ``X`` may be a plain list
        # without a ``shape`` attribute.
        observation_indexes = np.arange(self.X.shape[0])

        self.root = self.__construct_ball(
            observation_indexes
        )

    @staticmethod
    def most_spreaded_dimensionality(X):
        """Return the index of the column of ``X`` with the largest
        standard deviation (the first one on ties)."""
        return int(np.argmax(X.std(axis=0)))

    @staticmethod
    def arg_median(array):
        """Locate the median element(s) of a 1-D array.

        Returns
        -------
        int
            For odd length: index of the first element equal to the median.
        tuple of (int, int)
            For even length: indexes of the first elements equal to the
            lower and the upper middle values.

        Raises
        ------
        ValueError
            If ``array`` is empty.
        """
        array = np.asarray(array)
        size = len(array)

        if size == 0:
            raise ValueError('arg_median() of an empty array')

        if size % 2 == 1:
            middle = size // 2
            median_value = np.partition(array, middle)[middle]
            return int(np.flatnonzero(array == median_value)[0])

        lower, upper = size // 2 - 1, size // 2
        # One partition call places both middle order statistics.
        partitioned = np.partition(array, [lower, upper])

        return (
            int(np.flatnonzero(array == partitioned[lower])[0]),
            int(np.flatnonzero(array == partitioned[upper])[0])
        )

    def dist(self, x, y):
        """Return the distance between points ``x`` and ``y`` under
        ``self.metric``."""
        # Import the submodule explicitly: ``import scipy`` alone is not
        # guaranteed to expose ``scipy.spatial`` on every SciPy version.
        from scipy.spatial import distance

        # pdist returns a condensed 1-D vector; for two points it has
        # exactly one entry.
        return distance.pdist(
            [x, y],
            self.metric
        )[0]

    def __construct_ball(self, observation_indexes):
        """Recursively build the node covering ``observation_indexes``.

        Returns a leaf when the node is empty or holds at most
        ``self.leaf_size`` observations; otherwise picks the median
        observation along the most spread dimension as the pivot and
        recurses on the observations on each side of it.
        """
        observation_indexes = np.asarray(observation_indexes, dtype=np.intp)
        sample_count = len(observation_indexes)

        # The explicit emptiness check keeps a non-positive leaf_size from
        # trying to split an empty node.
        if sample_count == 0 or sample_count <= self.leaf_size:
            return BallTree.Node(
                observation_indexes=observation_indexes
            )

        node_sample = self.X[observation_indexes, :]

        most_spreaded_dim = BallTree.most_spreaded_dimensionality(node_sample)
        split_values = node_sample[:, most_spreaded_dim]

        # find pivot (the lower median for an even number of observations)
        median_sample_index = BallTree.arg_median(split_values)

        if isinstance(median_sample_index, tuple):
            median_sample_index = median_sample_index[0]

        # calculate ball radius: farthest observation from the pivot
        median_point = node_sample[median_sample_index, :]

        radius = max(
            (
                self.dist(node_sample[i, :], median_point)
                for i in range(sample_count)
                if i != median_sample_index
            ),
            default=0.0
        )

        # split observations by pivot; the pivot itself goes to neither
        # side, so every child is strictly smaller than its parent.
        is_pivot = np.arange(sample_count) == median_sample_index
        goes_left = split_values <= split_values[median_sample_index]

        left_split_indexes = observation_indexes[goes_left & ~is_pivot]
        right_split_indexes = observation_indexes[~goes_left & ~is_pivot]

        return BallTree.Node(
            observation_indexes=observation_indexes,
            pivot_index=int(observation_indexes[median_sample_index]),
            radius=radius,
            left=self.__construct_ball(left_split_indexes),
            right=self.__construct_ball(right_split_indexes)
        )

    @staticmethod
    def _offer_candidate(heap, k, distance_to_target, point_index):
        """Keep ``point_index`` in ``heap`` if it is among the ``k`` nearest
        candidates seen so far."""
        candidate = BallTree.PrioritizedPointIndex(
            dist_to_target=-float(distance_to_target),
            point_index=int(point_index)
        )

        if len(heap) < k:
            heapq.heappush(heap, candidate)
        elif distance_to_target < -heap[0].dist_to_target:
            # Closer than the current worst candidate: swap them.
            heapq.heapreplace(heap, candidate)

    def __process_leaf(
        self,
        k,
        target_point,
        node,
        heap
    ):
        """Offer every observation of the leaf ``node`` to ``heap``."""
        for observation_index in node.observation_indexes:
            from_target_to_current_observation = self.dist(
                target_point,
                self.X[observation_index, :]
            )

            BallTree._offer_candidate(
                heap,
                k,
                from_target_to_current_observation,
                observation_index
            )

    def k_nearest_neighbours_search(
        self,
        k,
        target_point,
        node=None,
        heap=None
    ):
        """Collect the ``k`` observations nearest to ``target_point``.

        Parameters
        ----------
        k : number of neighbours to find (must be >= 1).
        target_point : query point, same dimensionality as the rows of X.
        node : subtree to search; defaults to the tree root.
        heap : candidate heap shared between recursive calls; a new one is
            created when omitted.

        Returns
        -------
        list of PrioritizedPointIndex
            The heap itself (heap order, not sorted).  ``dist_to_target``
            is stored NEGATED; ``point_index`` is the row index in X.
            Holds fewer than ``k`` entries when the subtree has fewer than
            ``k`` observations.

        Raises
        ------
        ValueError
            If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError('k must be a positive integer')

        # ``heap or []`` would replace an empty shared heap with a fresh
        # list and lose the results of the recursive calls.
        if heap is None:
            heap = []

        if node is None:
            node = self.root

        # Leaves have neither pivot nor radius, so they cannot be pruned.
        if node.is_leaf():
            self.__process_leaf(k, target_point, node, heap)
            return heap

        from_target_to_node_center = self.dist(
            target_point,
            self.X[node.pivot_index, :]
        )

        # Prune only once k candidates are known: no observation in this
        # ball can be closer than (distance to pivot - radius).
        if len(heap) >= k and \
           from_target_to_node_center - node.radius >= -heap[0].dist_to_target:
            return heap

        # The pivot belongs to neither child, so offer it here.
        BallTree._offer_candidate(
            heap,
            k,
            from_target_to_node_center,
            node.pivot_index
        )

        self.k_nearest_neighbours_search(
            k,
            target_point,
            node.left,
            heap
        )

        self.k_nearest_neighbours_search(
            k,
            target_point,
            node.right,
            heap
        )

        return heap