In [61]:
import random
import math
import numpy as np

from result_set import KNNResultSet, RadiusNNResultSet

In [62]:
# Node class: the basic building block of the kd-tree
class Node:
    """A single kd-tree node.

    The constructor stores its five arguments unchanged as attributes:
    ``axis``, ``value`` (``None`` marks a leaf), ``left``, ``right`` and
    ``point_indices``.
    """

    def __init__(self, axis, value, left, right, point_indices):
        self.axis, self.value = axis, value
        self.left, self.right = left, right
        self.point_indices = point_indices

    def is_leaf(self):
        """Return True when this node has no split value, False otherwise."""
        return self.value is None

    def __str__(self):
        """Describe the node: axis, split value (or 'leaf') and its point indices."""
        axis_text = 'axis %d, ' % self.axis
        if self.value is None:
            split_text = 'split value: leaf, '
        else:
            split_text = 'split value: %.2f, ' % self.value
        # point_indices must provide .tolist() (e.g. a numpy array).
        indices_text = str(self.point_indices.tolist())
        return ''.join([axis_text, split_text, 'point_indices: ', indices_text])

In [63]:
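# Purpose: sort `value` in ascending order and reorder the corresponding `key`
#          entries the same way (used before splitting a node during tree construction)
# Input:
#     key: keys (1-D array with the same shape as value)
#     value: values to sort by
# Output:
#     key_sorted: keys reordered to follow the sorted values
#     value_sorted: values in sorted order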
def sort_key_by_vale(key, value):
    """Sort ``value`` ascending and permute ``key`` with the same ordering.

    Both arguments must be 1-D arrays of identical shape (checked with
    assertions). Returns the tuple ``(key_sorted, value_sorted)``.
    """
    assert key.shape == value.shape
    assert key.ndim == 1
    order = np.argsort(value)
    return key[order], value[order]


def axis_round_robin(axis, dim):
    """Return the next split axis: wrap to 0 after the last axis, else ``axis + 1``."""
    return 0 if axis == dim - 1 else axis + 1

In [64]:
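# Purpose: build the kd-tree recursively
# Input:
#     root: node to build on (a new node is created when it is None)
#     db: point cloud data
#     point_indices: indices into db of the points that belong to this subtree
#     axis: scalar, the axis along which this node is split
#     leaf_size: scalar, a node holding more than leaf_size points is split further
# Output:
#     root: root of the constructed (sub)tree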
def kdtree_recursive_build(root, db, point_indices, axis, leaf_size):
    """Recursively build a kd-tree over the points selected by ``point_indices``.

    Args:
        root: Node to populate; a new leaf-like Node is created when it is None.
        db: Point data, indexed as ``db[point_indices, axis]``.
        point_indices: 1-D array of indices into ``db`` for this subtree.
        axis: Axis along which this node is split.
        leaf_size: A node holding more than this many points is split further.
            Values below 1 are treated as 1, because a node with a single
            point cannot be split.

    Returns:
        The (possibly newly created) root node of this subtree.
    """
    if root is None:
        root = Node(axis, None, None, None, point_indices)

    # Split only when there are more points than a leaf may hold. The threshold
    # is clamped at 1: splitting needs two points (a left and a right median),
    # so a smaller leaf_size would otherwise index past the end of the array.
    if len(point_indices) > max(leaf_size, 1):
        # --- get the split position ---
        point_indices_sorted, _ = sort_key_by_vale(point_indices, db[point_indices, axis])

        # Homework 1 -- solution begins
        # Index of the last element in the left half (the lower median).
        mid_left_idx = math.ceil(point_indices_sorted.shape[0] / 2) - 1
        mid_left_point_val = db[point_indices_sorted[mid_left_idx], axis]

        # First element of the right half.
        mid_right_idx = mid_left_idx + 1
        mid_right_point_val = db[point_indices_sorted[mid_right_idx], axis]

        # The split plane sits halfway between the two middle points.
        root.value = (mid_left_point_val + mid_right_point_val) * 0.5

        # Both children split along the next axis in round-robin order.
        next_axis = axis_round_robin(axis, dim=db.shape[1])
        root.left = kdtree_recursive_build(root.left,
                                           db,
                                           point_indices_sorted[0:mid_right_idx],
                                           next_axis,
                                           leaf_size)
        root.right = kdtree_recursive_build(root.right,
                                            db,
                                            point_indices_sorted[mid_right_idx:],
                                            next_axis,
                                            leaf_size)
        # Homework 1 -- solution ends
    return root


In [65]:
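# Purpose: traverse a kd-tree, printing every leaf and tracking the maximum depth
# Input:
#     root: root node of the kd-tree (or subtree)
#     depth: current depth, stored in element 0 of a mutable sequence so it is updated in place
#     max_depth: maximum depth reached so far, stored in element 0 of a mutable sequence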
def traverse_kdtree(root: Node, depth, max_depth):
    """Walk the tree depth-first, printing each leaf and recording the deepest level.

    ``depth[0]`` is incremented on entry and decremented on exit;
    ``max_depth[0]`` is raised whenever the current depth exceeds it.
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])

    if not root.is_leaf():
        # Internal node: visit the left subtree first, then the right.
        for child in (root.left, root.right):
            traverse_kdtree(child, depth, max_depth)
    else:
        print(root)

    depth[0] -= 1

In [66]:
# Purpose: build a kd-tree (public wrapper around kdtree_recursive_build)
# Input:
#     db_np: raw data
#     leaf_size: scalar
# Output:
#     root: the constructed kd-tree
def kdtree_construction(db_np, leaf_size):
    """Build a kd-tree over every row of ``db_np`` and return its root node.

    Construction starts on axis 0 and is delegated to
    ``kdtree_recursive_build`` with the full index range ``0..N-1``.
    """
    # shape[1] is read but unused; it is kept so input without a second
    # dimension fails here exactly as before.
    num_points, _ = db_np.shape[0], db_np.shape[1]

    all_indices = np.arange(num_points)
    return kdtree_recursive_build(None, db_np, all_indices, axis=0, leaf_size=leaf_size)


In [67]:
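# Purpose: kNN search on a kd-tree, i.e. find the k nearest neighbours of a query point
# Input:
#     root: kd-tree
#     db: raw data
#     result_set: container that accumulates the search result
#     query: the query point
# Output:
#     False (every return path of this function returns False)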
def kdtree_knn_search(root: Node, db: np.ndarray, result_set: KNNResultSet, query: np.ndarray):
    """Recursively search the kd-tree for the nearest neighbours of ``query``.

    Candidates are reported through ``result_set.add_point(distance, index)``;
    pruning uses ``result_set.worstDist()``.

    Args:
        root: Current (sub)tree root; None is tolerated and ignored.
        db: Point data; leaf points are read as ``db[root.point_indices, :]``.
        result_set: Result container that collects the neighbours.
        query: The query point, indexed per axis as ``query[root.axis]``.

    Returns:
        Always False.
    """
    if root is None:
        return False

    if root.is_leaf():
        # Compare the query against every point stored in this leaf.
        leaf_points = db[root.point_indices, :]
        diff = np.linalg.norm(np.expand_dims(query, 0) - leaf_points, axis=1)
        for i in range(diff.shape[0]):
            result_set.add_point(diff[i], root.point_indices[i])
        return False

    # Homework 2 -- solution begins (search is still done recursively)
    # Descend first into the child on the query's side of the split plane.
    if query[root.axis] <= root.value:
        near_child, far_child = root.left, root.right
    else:
        near_child, far_child = root.right, root.left

    # The recursion must call this function itself; the earlier code called an
    # undefined name `knn_search`, which raised NameError on any non-leaf node.
    kdtree_knn_search(near_child, db, result_set, query)

    # Visit the other side only if the split plane is closer than the current
    # worst accepted distance, i.e. it could still hold a better neighbour.
    if math.fabs(query[root.axis] - root.value) < result_set.worstDist():
        kdtree_knn_search(far_child, db, result_set, query)
    # Homework 2 -- solution ends

    return False


In [69]:
def main():
    """Demo: build a kd-tree over random points, print its leaves and its max depth."""
    # configuration
    db_size, dim = 64, 3
    leaf_size = 4
    k = 1  # referenced only by the disabled kNN check below

    db_np = np.random.rand(db_size, dim)
    print("db_np is", db_np)

    root = kdtree_construction(db_np, leaf_size=leaf_size)

    # One-element lists so traverse_kdtree can update the counters in place.
    depth, max_depth = [0], [0]
    traverse_kdtree(root, depth, max_depth)
    print("tree max depth: %d" % max_depth[0])

    # Disabled checks, kept for reference.
    # NOTE(review): `knn_search` and `radius_search` are not defined in the
    # visible part of this notebook (the kNN function here is named
    # `kdtree_knn_search`) -- confirm the names before re-enabling.
    #
    # query = np.asarray([0, 0, 0])
    # result_set = KNNResultSet(capacity=k)
    # knn_search(root, db_np, result_set, query)
    #
    # print(result_set)
    #
    # diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
    # nn_idx = np.argsort(diff)
    # nn_dist = diff[nn_idx]
    # print(nn_idx[0:k])
    # print(nn_dist[0:k])
    #
    #
    # print("Radius search:")
    # query = np.asarray([0, 0, 0])
    # result_set = RadiusNNResultSet(radius = 0.5)
    # radius_search(root, db_np, result_set, query)
    # print(result_set)


# Run the demo only when executed as a script / notebook cell, not on import.
if __name__ == '__main__':
    main()

db_np is [[0.77829522 0.90465558 0.78029335]
 [0.6060835  0.9876036  0.79371144]
 [0.13929295 0.53280759 0.13525921]
 [0.01986096 0.5986425  0.72680043]
 [0.13843965 0.78014549 0.34605709]
 [0.1579894  0.06921405 0.56646989]
 [0.88745421 0.73713546 0.50555584]
 [0.77733099 0.75559697 0.85397762]
 [0.36587958 0.72211325 0.3666477 ]
 [0.36059893 0.70520732 0.08535405]
 [0.23713277 0.28376374 0.68021894]
 [0.01476498 0.7117849  0.92219346]
 [0.58492873 0.95938796 0.57998443]
 [0.4414286  0.89483841 0.66275199]
 [0.61608037 0.58729979 0.02593442]
 [0.6330328  0.95669243 0.11356709]
 [0.39536586 0.45450982 0.95307751]
 [0.6242124  0.60912613 0.92343955]
 [0.06837642 0.38733411 0.18377745]
 [0.51194069 0.92919338 0.12731619]
 [0.36077365 0.89288475 0.4208398 ]
 [0.41527412 0.45362026 0.27810713]
 [0.7888532  0.46467792 0.64108233]
 [0.81603502 0.42848016 0.32677614]
 [0.15904848 0.38473848 0.67679895]
 [0.79234842 0.94496245 0.01747973]
 [0.0711919  0.44431494 0.51627973]
 [0.67916478 0.4246