In [None]:
import random
import numpy as np
from statistics import mean
from random import randrange
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import time
import sys
from scipy.spatial import distance
from operator import itemgetter
%matplotlib inline

In [None]:
def compute_distance(test_X, train_X):
    """Pairwise Euclidean distances between two sets of row vectors.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so the whole
    distance matrix comes from a single matrix product instead of a Python loop.

    Args:
        test_X: 2-D array-like of shape (m, d), one point per row.
        train_X: 2-D array-like of shape (n, d), one point per row.

    Returns:
        ndarray of shape (m, n); entry [i, j] is the distance between
        test_X[i] and train_X[j].
    """
    test_X = np.asarray(test_X)
    train_X = np.asarray(train_X)
    test_sq = np.sum(test_X ** 2, axis=1, keepdims=True)      # (m, 1)
    train_sq = np.sum(train_X ** 2, axis=1, keepdims=True).T  # (1, n)
    sq_dist = test_sq + train_sq - 2 * np.dot(test_X, train_X.T)
    # Cancellation in the expansion can leave tiny negative values for
    # (near-)identical points; clamp at 0 so sqrt never yields NaN.
    return np.sqrt(np.maximum(sq_dist, 0))

In [1]:
class QuickSort(object):
    """Quick-select / quick-sort over the slice array[low:high + 1] of a mutable sequence.

    ``get_medium_num`` returns the element that would sit at index
    ``(low + high + 1) // 2`` if the slice were sorted: the median for an
    odd-length slice, the upper of the two middle elements for an even-length
    slice. All methods reorder ``array`` in place.
    """

    def __init__(self, low, high, array):
        self._array = array
        self._low = low
        self._high = high
        self._medium = (low + high + 1) // 2  # integer (floor) division

    def get_medium_num(self):
        """Return the median element of array[low:high + 1] (see class docstring)."""
        return self.quick_sort_for_medium(self._low, self._high,
                                          self._medium, self._array)

    def quick_sort_for_medium(self, low, high, medium, array):
        """Quick-select: return the element of sorted position ``medium`` in array[low:high + 1].

        Implemented as a loop rather than recursion. Adversarial input
        (e.g. already sorted) needs one partition step per element, and
        recursion would then exceed Python's recursion limit on large arrays.

        Returns:
            The selected element, or None if the search range is empty
            (high < low).
        """
        while high > low:
            index, partition = self.sort_partition(low, high, array)
            if index == medium:
                return partition
            if index > medium:
                high = index - 1
            else:
                low = index + 1
        if high == low:
            return array[low]
        return None

    def quick_sort(self, low, high, array):
        """Sort array[low:high + 1] in place (ordinary quicksort).

        Recurses only into the smaller partition and loops over the larger one,
        so recursion depth stays O(log n) even on already-sorted input.
        """
        while high > low:
            index, _ = self.sort_partition(low, high, array)
            if index - low < high - index:
                self.quick_sort(low, index - 1, array)
                low = index + 1
            else:
                self.quick_sort(index + 1, high, array)
                high = index - 1

    def sort_partition(self, low, high, array):
        """Partition array[low:high + 1] around its first element, array[low].

        After the call, elements before the returned index are < pivot and
        elements after it are >= pivot; the pivot itself is stored at the index.

        Returns:
            (index, pivot): the pivot's final position and its value.
        """
        index_i = low
        index_j = high
        partition = array[low]  # pivot; slot `low` is now a free "hole"
        while index_i < index_j:
            # Scan from the right for an element < pivot and move it into the left hole.
            while (index_i < index_j) and (array[index_j] >= partition):
                index_j -= 1
            if index_i < index_j:
                array[index_i] = array[index_j]
                index_i += 1
            # Scan from the left for an element >= pivot and move it into the right hole.
            while (index_i < index_j) and (array[index_i] < partition):
                index_i += 1
            if index_i < index_j:
                array[index_j] = array[index_i]
                index_j -= 1
        array[index_i] = partition
        return index_i, partition

In [2]:
class KDTree(object):
    """KD-tree built over the rows of ``input_x``.

    Each node is the list ``[depth, axes, medium, data_list, left, right]``:
    - ``depth``: node depth (root is 0).
    - ``axes``: the split dimension; it cycles through the dimensions with depth.
    - ``medium``: the split value on that dimension, as returned by
      ``QuickSort.get_medium_num``.
    - ``data_list``: the points whose coordinate on ``axes`` equals ``medium``.
    - ``left`` / ``right``: subtrees of points strictly smaller / not smaller
      and not equal on that dimension. Each is ``None`` when empty.
    """

    def __init__(self, input_x, input_y):
        """Store the points and labels.

        Args:
            input_x: 2-D array-like of points, one row per point.
            input_y: labels for the points; stored but not used to build the tree.
        """
        self._input_x = np.array(input_x)
        self._input_y = np.array(input_y)
        (data_num, axes_num) = np.shape(self._input_x)
        self._data_num = data_num
        self._axes_num = axes_num
        self._nearest = None  # intended to hold the nearest node; not set by the code in this class

    def construct_kd_tree(self):
        """Build the tree over all stored points and return its root node (or None)."""
        return self._construct_kd_tree(0, 0, self._input_x)

    def _construct_kd_tree(self, depth, axes, data):
        """Recursively build the subtree for the 2-D array ``data``, splitting on ``axes``."""
        # Test emptiness by row count. `not data.any()` is also True for
        # non-empty data whose entries are all zero (e.g. the point [0, 0]),
        # which silently dropped such points from the tree.
        if len(data) == 0:
            return None
        axes_data = data[:, axes].copy()  # scratch copy: QuickSort reorders it in place
        qs = QuickSort(0, axes_data.shape[0] - 1, axes_data)
        medium = qs.get_medium_num()  # median value on this axis

        # Split rows by comparing their coordinate with the median; boolean
        # masks keep the original row order within each group.
        column = data[:, axes]
        equal_mask = column == medium
        left_mask = column < medium
        right_mask = ~(equal_mask | left_mask)  # everything else goes right
        data_list = list(data[equal_mask])
        left_data = data[left_mask]
        right_data = data[right_mask]

        next_axes = (axes + 1) % self._axes_num
        left = self._construct_kd_tree(depth + 1, next_axes, left_data)
        right = self._construct_kd_tree(depth + 1, next_axes, right_data)
        # [depth, split axis, median, points at this node, left subtree, right subtree]
        root = [depth, axes, medium, data_list, left, right]
        return root

    def print_kd_tree(self, root):
        """Print the tree in pre-order, one node per line, indented by depth.

        Only the first point stored at each node is printed.
        """
        if not root:
            return
        depth, _axes, _medium, data_list, left, right = root
        print('{} {}'.format('    ' * depth, data_list[0]))
        self.print_kd_tree(left)
        self.print_kd_tree(right)

In [None]:
# Small 2-D example; every sample carries the same label (1).
input_x = [[2, 3], [6, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
input_y = [1 for _ in input_x]

# Build the tree over the example points and print it in pre-order.
kd = KDTree(input_x, input_y)
tree = kd.construct_kd_tree()
kd.print_kd_tree(tree)