# 第3章 kd Tree
本章节主要实现 kd 树（k-dimensional tree），并复现书中例3.2。

实现kNN算法时，主要考虑如何对训练数据进行快速k近邻搜索，这一点在特征空间维数大及训练数据容量大时尤其重要。

kNN算法最简单的实现是线性扫描，但是这种计算是非常耗时的。

kd Tree 适用于训练实例数远大于空间维数时的kNN搜索，当空间维数接近于训练实例数时，它的效率会迅速下降，几乎接近线性扫描。

### Code
kd tree是对数据点在k维空间中划分的一种数据结构。kd tree实际上是一种二叉树。每个结点的内容如下：
![kdTree_parameters](kdTree_parameters.png)
样本集E由k-d tree的结点的集合表示，每个结点表示一个样本点，dom_elt就是表示该样本点的向量。该样本点根据结点的分割超平面将样本空间分为两个子空间。左子空间中的样本点集合由左子树left表示，右子空间中的样本点集合由右子树right表示。分割超平面是一个通过点dom_elt并且垂直于split所指示的方向轴的平面。举个简单的例子，在二维的情况下，一个样本点可以由二维向量(x,y)表示，其中令x维的序号为0，y维的序号为1。假设一个结点的dom_elt为(7,2)，split的取值为0，那么分割超平面就是x=dom_elt(0)=7，它垂直于x轴且过点(7,2)。

In [1]:
# create kd Tree's node
class KdNode(object):
    """One node of a kd tree: a stored point plus its two subtrees."""

    def __init__(self, dom_elt, split, left, right):
        """Store the four fields of the node exactly as given.

        Args:
            dom_elt: the sample point (coordinate vector) held by this node.
            split: index of the coordinate axis the splitting hyperplane is
                perpendicular to.
            left: root of the left subtree, or None.
            right: root of the right subtree, or None.
        """
        # The point and the axis it splits on.
        self.dom_elt, self.split = dom_elt, split
        # The two child subtrees.
        self.left, self.right = left, right
        
class KdTree(object):
    """A kd tree built by recursively splitting the points at their median.

    After construction ``self.root`` is the root ``KdNode``, or None when the
    tree was built from an empty collection.
    """

    def __init__(self, data):
        """Build the tree from ``data``.

        Args:
            data: collection of points; each point is an indexable sequence
                with as many coordinates as the first point.  The argument
                itself is not reordered: a shallow copy is sorted instead.
                An empty ``data`` produces an empty tree (``root is None``).
        """
        points = list(data)  # shallow copy so the caller's list keeps its order
        k = len(points[0]) if points else 0  # the dimension of data

        def create_node(split, data_set):
            """Build the subtree for ``data_set``, splitting on axis ``split``."""
            if not data_set:  # no points left: empty subtree
                return None

            # data_set is always private here (the copy above or a slice of
            # it), so sorting in place cannot affect the caller.
            data_set.sort(key=lambda x: x[split])
            median_pos = len(data_set) // 2  # '//' is integer division
            median_value = data_set[median_pos]
            split_next = (split + 1) % k  # cycle through the axes

            return KdNode(
                median_value,
                split,
                create_node(split_next, data_set[:median_pos]),  # left subtree
                create_node(split_next, data_set[median_pos + 1:]),  # right subtree
            )

        self.root = create_node(0, points)
        
# Preorder traversal of kd Tree
def preorder(root):
    """Print the point of every node in preorder: node, left subtree, right subtree.

    Args:
        root: root node of the (sub)tree to print; it must expose ``dom_elt``,
            ``left`` and ``right``.  ``None`` stands for an empty tree and
            prints nothing.
    """
    if root is None:  # empty tree or missing child: nothing to print
        return
    print(root.dom_elt)
    preorder(root.left)
    preorder(root.right)

In [2]:
# kd tree search
from math import sqrt
from collections import namedtuple

# Search result: the nearest stored point, its distance to the query, and the
# number of tree nodes that were visited while searching.
result = namedtuple(
    "Result_tuple",
    ["nearest_point", "nearest_dist", "nodes_visited"],
)


def find_nearest(tree, point):
    """Find the point stored in ``tree`` that is nearest to ``point``.

    Distances are Euclidean.  ``point`` is expected to have the same number
    of coordinates as the points stored in the tree.

    Args:
        tree: object whose ``root`` attribute is the root node of a kd tree
            (or None); every node exposes ``dom_elt``, ``split``, ``left``
            and ``right``.
        point: the query point, an indexable sequence of numbers.

    Returns:
        A ``Result_tuple(nearest_point, nearest_dist, nodes_visited)``.  When
        the tree has no nodes the result is a zero vector of ``len(point)``
        entries, an infinite distance and zero visited nodes.
    """
    dims = len(point)

    def search(node, radius):
        """Search the subtree at ``node``; ``radius`` is the best distance known so far."""
        if node is None:
            # Placeholder for an empty subtree: never beats a real candidate.
            return result([0] * dims, float("inf"), 0)

        axis_index = node.split
        pivot = node.dom_elt

        # Descend first into the side of the splitting plane holding the query.
        query_on_left = point[axis_index] <= pivot[axis_index]
        near_child = node.left if query_on_left else node.right
        far_child = node.right if query_on_left else node.left

        best_point, best_dist, visited = search(near_child, radius)
        visited += 1  # count this node as well

        # Shrink the search sphere if the near side found something closer.
        radius = min(radius, best_dist)

        # If the sphere does not reach the splitting plane, neither this
        # node's point nor the far side can be closer: stop here.
        plane_gap = abs(pivot[axis_index] - point[axis_index])
        if plane_gap > radius:
            return result(best_point, best_dist, visited)

        # Candidate: the point stored at this node.
        pivot_dist = sqrt(sum((a - b) ** 2 for a, b in zip(pivot, point)))
        if pivot_dist < best_dist:
            best_point, best_dist = pivot, pivot_dist
            radius = pivot_dist

        # The far side may still hold a closer point.
        far = search(far_child, radius)
        visited += far.nodes_visited
        if far.nearest_dist < best_dist:
            best_point, best_dist = far.nearest_point, far.nearest_dist

        return result(best_point, best_dist, visited)

    return search(tree.root, float("inf"))  # start the recursion at the root

例3.2

In [3]:
# Training set of Example 3.2: six points in the plane.
data = [
    [2, 3], [5, 4], [9, 6],
    [4, 7], [8, 1], [7, 2],
]
kd = KdTree(data)  # build the kd tree over the six points
preorder(kd.root)  # print the stored points in preorder

[7, 2]
[5, 4]
[2, 3]
[4, 7]
[9, 6]
[8, 1]


In [4]:
from random import random
from time import perf_counter

try:
    from time import clock
except ImportError:  # time.clock was removed in Python 3.8
    clock = perf_counter

# create one random k-dimensional vector
def random_point(k):
    """Return a list of ``k`` values, each drawn with ``random()``."""
    coordinates = [random() for _axis in range(k)]
    return coordinates

# create n random k-dimensional vectors
def random_points(k, n):
    """Return a list of ``n`` points, each produced by ``random_point(k)``."""
    samples = [random_point(k) for _sample in range(n)]
    return samples

In [5]:
# Query with a 2-dimensional point so that it matches the dimension of the
# points stored in `kd`; a third coordinate would never be looked at by the
# search (the splits only use axes 0 and 1 and the distance stops at the
# shorter vector).
ret = find_nearest(kd, [3, 4])
print(ret)

Result_tuple(nearest_point=[2, 3], nearest_dist=1.4142135623730951, nodes_visited=4)


In [6]:
N = 400000
# perf_counter() replaces time.clock(), which was deprecated in Python 3.3
# and removed in 3.8.
t0 = perf_counter()
kd2 = KdTree(random_points(3, N))  # build a kd tree from N random 3-dimensional samples
ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])
t1 = perf_counter()

# Elapsed time covers both building the tree and the single query.
print("time: ", t1 - t0, "s")
print(ret2)

time:  4.6178349999999995 s
Result_tuple(nearest_point=[0.10099331205659678, 0.5031323363804084, 0.811750046284898], nearest_dist=0.012200892907449641, nodes_visited=50)
