In [1]:
import numpy as np

In [2]:
from sklearn.datasets import load_iris
# Load the bundled iris dataset and pull out its feature matrix and labels.
# Note: X is reassigned to a small 2-D toy array in a later cell.
data = load_iris()
X = data["data"]
y = data["target"]

In [3]:
# 
class KDTree:
    """First attempt at a k-d tree that keeps all node fields on the tree object itself.

    NOTE: ``buildTree`` writes every node's ``fi``/``fv``/``data`` onto the same
    ``self`` and returns ``self``. Each recursive call therefore overwrites the
    values stored by its caller, and ``self.left``/``self.right`` end up
    referring to this very object instead of to distinct child nodes.
    """
    def buildTree(self,X,depth):
        """Recursively split ``X`` on feature ``depth % self.n_feature``, printing each chosen point.

        Reads ``self.n_feature`` (set by ``fit``) and always returns ``self``.
        """
        n_size = len(X)
        # Recursion base case: a single remaining point is stored as-is.
        if n_size == 1:
            self.data = X[0]
            print(f"depth:{depth},data:{X[0]}")
            return self
        
        # Split feature cycles through the columns as depth grows.
        self.fi = depth % self.n_feature
        argsort = np.argsort(X[:,self.fi])
        # The row at position n_size // 2 of the sorted order becomes this node's point.
        middle_idx = argsort[n_size // 2]
        # Rows sorted before the middle go left, rows sorted after it go right.
        left_idxs,right_idxs = argsort[:n_size//2],argsort[n_size//2+1:]
        
        self.fv = X[middle_idx,self.fi]
        self.data = X[middle_idx]
        print(f"depth:{depth},data:{X[middle_idx]}")
        # The recursive calls below mutate and return this same object, so the
        # fi/fv/data assigned above are overwritten by deeper levels.
        if len(left_idxs) > 0:
            self.left = self.buildTree(X[left_idxs],depth+1)
        if len(right_idxs) > 0:
            self.right = self.buildTree(X[right_idxs],depth+1)
        return self
        
    def fit(self,X):
        """Record the number of columns of ``X``, build the tree from depth 0 and return ``self``."""
        self.n_feature = X.shape[-1]
        self.buildTree(X,0)
        return self

In [4]:
# Six 2-D sample points used as the toy dataset for the k-d tree cells below.
X = np.array([
    [2, 3],
    [5, 4],
    [9, 6],
    [4, 7],
    [8, 1],
    [7, 2],
])
X

array([[2, 3],
       [5, 4],
       [9, 6],
       [4, 7],
       [8, 1],
       [7, 2]])

In [5]:
# fit() returns the tree object itself, so construction and fitting can be chained.
tree = KDTree().fit(X)

depth:0,data:[7 2]
depth:1,data:[5 4]
depth:2,data:[2 3]
depth:2,data:[4 7]
depth:1,data:[9 6]
depth:2,data:[8 1]


In [6]:
# Show the point left on the tree object after fitting. The recorded output is
# [8 1], the last point printed during the build, rather than the depth-0 point [7 2].
print(tree.data)

[8 1]


In [9]:
# The code above gets its state mixed up during recursion, so the node
# container (TreeNode) and the tree-building logic are kept separate instead.
class TreeNode:
    """A single k-d tree node: its point, its split description and its children.

    Attributes are plain instance fields; all default to ``None``.
    """

    def __init__(self,data=None,fi=None,fv=None,left=None,right=None):
        """Store the point (``data``), split feature index (``fi``), split value (``fv``) and child nodes."""
        # Point held by this node plus the feature/value it splits on.
        self.data, self.fi, self.fv = data, fi, fv
        # Subtrees on either side of the split.
        self.left, self.right = left, right

class KDTree:
    """Builds a k-d tree out of ``TreeNode`` objects and keeps its root in ``self.tree``."""

    def buildTree(self,X,depth):
        """Return the root ``TreeNode`` of the subtree built from the rows of ``X``.

        The split feature at each level is ``depth % n_feature``; the row at
        position ``n_size // 2`` of the sorted order is stored in the node.
        """
        n_size,n_feature = X.shape
        # Base case: one remaining row becomes a leaf without split information.
        if n_size == 1:
            return TreeNode(data=X[0])

        axis = depth % n_feature
        order = np.argsort(X[:,axis])
        half = n_size // 2
        pivot = order[half]
        lower,upper = order[:half],order[half+1:]

        # Recurse only into sides that actually contain rows.
        left = self.buildTree(X[lower],depth+1) if len(lower) > 0 else None
        right = self.buildTree(X[upper],depth+1) if len(upper) > 0 else None
        return TreeNode(X[pivot],axis,X[pivot,axis],left,right)

    def fit(self,X):
        """Build the tree for ``X`` starting at depth 0 and store its root on ``self.tree``."""
        self.tree = self.buildTree(X,0)

In [10]:
# Build the tree for X; fit() stores the root node on kdtree.tree.
kdtree = KDTree()
kdtree.fit(X)

In [11]:
# Inspect the point stored in the root's right child.
kdtree.tree.right.data

array([9, 6])

In [12]:
def distance(a,b):
    """Return the Euclidean distance between ``a`` and ``b``."""
    diff = a - b
    squared = diff * diff
    return np.sqrt(squared.sum())

In [13]:
# Query point, plus the initial "best so far" candidate for find_nearest:
# start from the root's point and its distance to the query.
x = np.array([2, 4.5])
nearest_point = kdtree.tree.data
nearest_dis = distance(nearest_point, x)

In [14]:
def find_nearest(kdtree,x):
    """Search the subtree rooted at ``kdtree`` for the stored point closest to ``x``.

    The best candidate found so far is kept in the module-level globals
    ``nearest_point`` and ``nearest_dis``. Both must be initialised before the
    first call (for example to the root's point and its distance to ``x``);
    the function updates them in place and returns ``None``.

    Args:
        kdtree: a tree node exposing ``data``, ``fi``, ``fv``, ``left`` and
            ``right``, or ``None`` for an empty subtree.
        x: the query point, indexable by feature index.
    """
    global nearest_dis,nearest_point
    if kdtree is None:
        return

    # Update the best candidate when this node's point is strictly closer.
    # The distance is computed once and reused.
    dis = distance(kdtree.data,x)
    if dis < nearest_dis:
        nearest_dis = dis
        nearest_point = kdtree.data

    # Nodes without split information are leaves: nothing further to descend into.
    if kdtree.fi is None or kdtree.fv is None:
        return

    # Descend into the side containing x first, then visit the other side only
    # if the ball of radius nearest_dis around x crosses the splitting plane.
    if x[kdtree.fi] < kdtree.fv:
        find_nearest(kdtree.left,x)
        if x[kdtree.fi] + nearest_dis > kdtree.fv:
            find_nearest(kdtree.right,x)
    elif x[kdtree.fi] > kdtree.fv:
        find_nearest(kdtree.right,x)
        if x[kdtree.fi] - nearest_dis < kdtree.fv:
            find_nearest(kdtree.left,x)
    else:
        # x lies exactly on the splitting plane: both sides are candidates.
        find_nearest(kdtree.left,x)
        find_nearest(kdtree.right,x)

In [15]:
# Run the search from the root, then report the distance and the point it
# left in the module-level globals, one per line.
find_nearest(kdtree.tree, x)
print(nearest_dis, nearest_point, sep="\n")

1.5
[2 3]


In [18]:
# Everything consolidated into a single KDTree class.
class KDTree:
    """k-d tree over a fixed point set with exact nearest-neighbour lookup.

    Nodes are ``TreeNode`` instances (defined in an earlier cell). Call
    ``fit`` once to build the tree, then ``find_nearest`` for each query.
    """

    def buildTree(self,X,depth):
        """Recursively build the subtree for the rows of ``X``.

        Args:
            X: 2-D NumPy array with one point per row.
            depth: depth of the node being built; the split feature is
                ``depth % n_feature``.

        Returns:
            The root ``TreeNode`` of the subtree, or ``None`` when ``X`` has
            no rows.
        """
        n_size,n_feature = X.shape
        # Nothing to build from an empty set of rows.
        if n_size == 0:
            return None
        # Recursion base case: one remaining row becomes a leaf with no split info.
        if n_size == 1:
            return TreeNode(data=X[0])

        fi = depth % n_feature
        order = np.argsort(X[:,fi])
        mid = n_size // 2
        # The row at the middle of the sorted order is stored in this node;
        # rows sorted before it go left, rows sorted after it go right.
        middle_idx = order[mid]
        left_idxs,right_idxs = order[:mid],order[mid+1:]

        fv = X[middle_idx,fi]
        data = X[middle_idx]
        left,right = None,None
        if len(left_idxs) > 0:
            left = self.buildTree(X[left_idxs],depth+1)
        if len(right_idxs) > 0:
            right = self.buildTree(X[right_idxs],depth+1)
        return TreeNode(data,fi,fv,left,right)

    def fit(self,X):
        """Build the tree from ``X`` and store its root on ``self.tree``.

        Args:
            X: array-like of shape ``(n_points, n_features)``; converted with
                ``np.asarray`` so nested lists/tuples are accepted too.

        Raises:
            ValueError: if ``X`` is not 2-D, or has no rows or no columns.
        """
        X = np.asarray(X)
        if X.ndim != 2 or X.shape[0] == 0 or X.shape[1] == 0:
            raise ValueError(
                "X must be a non-empty 2-D array of shape (n_points, n_features); "
                f"got shape {X.shape}"
            )
        self.tree = self.buildTree(X,0)

    @staticmethod
    def distance(a,b):
        """Return the Euclidean distance between ``a`` and ``b``."""
        return np.sqrt(((a-b)**2).sum())

    def find_nearest(self,x):
        """Return the stored point closest to ``x`` and its distance.

        Args:
            x: query point (array-like with one value per feature).

        Returns:
            ``(nearest_point, nearest_dis)``: the closest stored point and its
            Euclidean distance to ``x``.
        """
        x = np.asarray(x)
        # Start with the root as the best candidate.
        nearest_point = self.tree.data
        nearest_dis = KDTree.distance(nearest_point,x)

        def travel(node):
            nonlocal nearest_dis,nearest_point
            if node is None:
                return

            # Update the best candidate when this node's point is strictly
            # closer. The distance is computed once and reused.
            dis = KDTree.distance(node.data,x)
            if dis < nearest_dis:
                nearest_dis = dis
                nearest_point = node.data

            # Leaves carry no split information: nothing below them.
            if node.fi is None or node.fv is None:
                return

            # Visit the side containing x first; visit the other side only if
            # the ball of radius nearest_dis around x crosses the split plane.
            if x[node.fi] < node.fv:
                travel(node.left)
                if x[node.fi] + nearest_dis > node.fv:
                    travel(node.right)
            elif x[node.fi] > node.fv:
                travel(node.right)
                if x[node.fi] - nearest_dis < node.fv:
                    travel(node.left)
            else:
                # x lies exactly on the split plane: both sides are candidates.
                travel(node.left)
                travel(node.right)

        travel(self.tree)
        return nearest_point,nearest_dis

In [19]:
# Query point for the nearest-neighbour lookup.
x = np.array([2, 4.5])

# Build the consolidated tree on X and look up the stored point closest to x.
kdtree = KDTree()
kdtree.fit(X)
kdtree.find_nearest(x)

(array([2, 3]), 1.5)