In [49]:
import numpy as np
import math
class KDNode(object):
    def __init__(self, value, split, left, right):
        #value = [x, y]
        self.value = value
        self.split = split
        self.right = right
        self.left = left

In [50]:
class KDTree(object):
    def __init__(self, data, split_method='default'):
        k = len(data[0])
        
        def createNode_default(self, data, split=0):
            if data.shape[0] == 0:
                return None
            #np.median()
            #np中的切片返回指向原本数组的引用，这样在find——median中对data部分排序了
            #保证返回的中间数的索引和数组是对应的，如果不对data在partition的时候改变顺序的话
            #那最后用median的索引在原来的数组里取出来的中间点不是数学意义上的中间点，并且必须对
            #整个点的每一个坐标同时改变顺序
            split_pos = self.__find_median(data, split)
            median = data[split_pos]
            split_next = (split+1) % k
            return KDNode(median, split, \
                          createNode_default(self, data[: split_pos], split_next), \
                          createNode_default(self, data[split_pos + 1:], split_next))
                               
        def createNode_maximalVariance(self, data):
            if not data:
                return None
            pass
        
        def unknown_split_method(self, data):
            raise RuntimeError("input split method not exists")
            return None
        
        switcher = {
            'default': createNode_default ,
            'maximalVar': createNode_maximalVariance
        }
        
        fun_name = switcher.get(split_method, unknown_split_method)
        self.root = fun_name(self, data)
        print(self.root)
        
    def show(self, node):
        if node is None:
            return None
        
        print(node.value)
        self.show(node.left)
        self.show(node.right)
        return None
    
    def search(self, root, x, count=1, metric='Euclidean'):
        nearest = []
        for i in range(count):
            nearest.append([-1, None])
            self.nearest = np.array(nearest)
        
        self.__search_helper(root, x, metric)
        return self.nearest
    
    def __search_helper(self, node, x, metric='Euclidean'):
        if node is not None:
            dist_to_split_plane = x[node.split] - node.value[node.split]
            if dist_to_split_plane < 0:
                self.__search_helper(node.left, x, metric)
            else:
                self.__search_helper(node.right, x, metric)
            dist_to_cur_point = self.__dist_measure(x, node.value)
            for i, d in enumerate(self.nearest):
                if d[0] < 0 or dist_to_cur_point < d[0]:  # 如果当前nearest内i处未标记（-1），或者新点与x距离更近
                    self.nearest = np.insert(self.nearest, i, [dist_to_cur_point, node.value], axis=0)  # 插入比i处距离更小的
                    self.nearest = self.nearest[:-1]
                    break
            # 找到nearest集合里距离最大值的位置，为-1值的个数
            n = list(self.nearest[:, 0]).count(-1)
            # 切分轴的距离比nearest中最大的小（存在相交）
            if self.nearest[-n - 1, 0] > abs(dist_to_split_plane):
                if dist_to_split_plane < 0:  # 相交，x[axis]< node.data[axis]时，去右边（左边已经遍历了）
                    self.__search_helper(node.right, x, metric)
                else:  # x[axis]> node.data[axis]时，去左边，（右边已经遍历了）
                    self.__search_helper(node.left, x, metric)
        
    def __dist_measure(self, point1, point2, metric='Euclidean'):
        return math.sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(point1, point2)))

    
    #find median uses double quick sort
    def __select_Kth_num(self, k, array, split, low, high):
        if low >= high or array.shape[0] == 0 or k < 0 :
            return low
        pivot = np.random.randint(0, high-low+1) + low
        pivot_val = array[pivot, split]
        array[low],array[pivot] = array[pivot].copy(),array[low].copy()
        i = low+1
        j = high
        while(True):
            while(i <= high and array[i, split] < pivot_val):
                i += 1
            while(j >= low+1 and array[j, split] > pivot_val):
                j -= 1
            if(i > j):
                break
            array[i],array[j] = array[j].copy(),array[i].copy()
            i += 1
            j -= 1
        array[low],array[j] = array[j].copy(),array[low].copy()
        if j == k:
            return j
        elif j > k:
            return self.__select_Kth_num(k, array, split, low, j-1)
        else:
            return self.__select_Kth_num(k, array, split, j+1, high)
            
            
     #find the index of median
    def __find_median(self, array, split):
        length = len(array)
        if length <= 0:
            return None
        mid = length // 2
        return self.__select_Kth_num(mid, array, split, 0, length-1)

In [51]:
data = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
kd = KDTree(data)
kd.show(kd.root)
n = kd.search(kd.root, [3, 4.5], 3)
print(n)


<__main__.KDNode object at 0x000001F93E03ABE0>
[7 2]
[5 4]
[2 3]
[4 7]
[9 6]
[8 1]
<class 'list'>
<class 'list'>
<class 'list'>
<class 'list'>
<class 'list'>
<class 'list'>
<class 'list'>
<class 'list'>
[[1 array([4, 7])]
 [1 array([5, 4])]
 [1 array([2, 3])]]
