In [163]:
import math
import numpy as np
from unittest.mock import NonCallableMagicMock
from urllib.parse import non_hierarchical

class Node:
    """A single node of a k-d tree.

    Attributes:
        _id: label for the node (KD_Tree._build passes the split value
            rounded to 3 decimals).
        left: child subtree on the "smaller" side, or None.
        right: child subtree on the "larger or equal" side, or None.
        axis: key of the coordinate this node splits on (KD_Tree._build
            passes a DataFrame column name).
        data: the stored point (KD_Tree._build passes a DataFrame row).
    """

    def __init__(self, _id, data, axis, left=None, right=None):
        # Single tuple assignment; attributes are still set in the order
        # _id, left, right, axis, data.
        self._id, self.left, self.right, self.axis, self.data = _id, left, right, axis, data

class KD_Tree:
    """A k-d tree over the rows of a pandas DataFrame.

    Each DataFrame column is one coordinate. Tree levels cycle through the
    columns in order (depth 0 splits on the first column, depth 1 on the
    second, ...), and every node stores the median row of its subset along
    that level's column.

    Usage::

        tree = KD_Tree(df)
        tree.build()
        nearest_node, dist = tree.create_nn(single_row_query_df)
    """

    def __init__(self, data):
        # DataFrame whose rows are the points to index.
        self.data = data
        # Root Node; stays None until build() is called.
        self.tree = None

    def _build(self, points, depth):
        """Recursively build the subtree for ``points`` and return its root.

        Args:
            points: DataFrame (a subset of ``self.data``) to index.
            depth: recursion depth; selects the split column as
                ``points.columns[depth % number_of_columns]``.

        Returns:
            The root ``Node`` of the subtree, or ``None`` if ``points`` is empty.
        """
        # Check emptiness before touching the columns, so an empty frame
        # never reaches the modulo below.
        if len(points) == 0:
            return None

        column = points.columns[depth % len(points.columns)]
        ordered = points.sort_values(by=[column], ascending=True)
        # Exact middle for odd lengths, upper median for even lengths
        # (both cases are plain floor division).
        median_idx = len(ordered) // 2
        median_row = ordered.iloc[median_idx]

        node = Node(_id=round(median_row[column], 3),
                    data=median_row,
                    axis=column)
        # Rows sorted before the median go left; rows after it go right.
        node.left = self._build(ordered.iloc[:median_idx], depth + 1)
        node.right = self._build(ordered.iloc[median_idx + 1:], depth + 1)
        return node

    def build(self):
        """Build the tree from ``self.data`` and store the root in ``self.tree``."""
        self.tree = self._build(self.data, depth=0)

    def distance(self, X, Y):
        """Return the Euclidean distance between point sequences ``X`` and ``Y``.

        Iterates over ``len(X)`` coordinates, so ``Y`` must have at least as
        many entries as ``X``.
        """
        return math.sqrt(sum((X[i] - Y[i]) ** 2 for i in range(len(X))))

    def nearest(self, query, node, best_node, best_distance):
        """Branch-and-bound search for the stored point closest to ``query``.

        Args:
            query: single-row DataFrame whose column names match the indexed
                data (column order may differ); only its first row is used.
            node: root of the subtree to search; ``None`` ends the branch.
            best_node: best Node found so far (``None`` if none yet).
            best_distance: distance from ``query`` to ``best_node``.

        Returns:
            Tuple ``(best_node, best_distance)`` after searching this subtree.
        """
        if node is None:
            return best_node, best_distance

        # Select the query's coordinates in the stored point's column order,
        # so the element-wise distance compares like with like even when the
        # query frame lists its columns in a different order.
        query_point = query[node.data.index].to_numpy()[0]
        d = self.distance(query_point, node.data.values)
        if d < best_distance:
            best_node, best_distance = node, d

        query_coord = query[node.axis].values[0]
        split_value = node.data[node.axis]
        if query_coord < split_value:
            near_side, far_side = node.left, node.right
        else:
            near_side, far_side = node.right, node.left

        best_node, best_distance = self.nearest(query, near_side, best_node, best_distance)
        # The far subtree can only contain a closer point if the query is
        # closer to the splitting plane than to the current best point.
        if abs(query_coord - split_value) < best_distance:
            best_node, best_distance = self.nearest(query, far_side, best_node, best_distance)

        return best_node, best_distance

    def create_nn(self, query):
        """Return ``(nearest_node, distance)`` for a single-row ``query`` DataFrame.

        Returns ``(None, inf)`` when the tree is empty or build() has not run.
        """
        # math.inf rather than np.infty: the np.infty alias was removed in NumPy 2.0.
        return self.nearest(query, self.tree, None, math.inf)
        
    

In [166]:
import pandas as pd

# Toy 2-D point set (columns X, Y) used to sanity-check the k-d tree.
test_df = pd.DataFrame({"X": [2, 5, 9, 4, 8, 7],
                        "Y": [3, 4, 6, 7, 1, 2]})
KD = KD_Tree(test_df)
KD.build()
node = KD.tree  # root of the freshly built tree


7
4
2
4
6
8


In [167]:
# Single query point (6, 7); create_nn returns (nearest_node, distance).
query = pd.DataFrame({"X": [6], "Y": [7]})
res = KD.create_nn(query)
print(res[0].data)  # coordinates of the nearest stored point

X    4
Y    7
Name: 3, dtype: int64
