In [5]:
import numpy as np

In [121]:
class KDTreeNode:
    def __init__(self, data, split_axis, father):
        self.data = data
        self.split_axis = split_axis
        self.father = father
        self.left = None
        self.right = None
        self.dist = float('inf')
    @property
    def mate(self):
        if not self.father:
            return None
        if self.father.left == self:
            return self.father.right
        else:
            return self.father.left
    def __lt__(self, other):
        return self.dist < other.dist

In [128]:
class KNN:
    def __init__(self, use_kdtree=True, np_dataset=False):
        self.use_kdtree = use_kdtree
        self.np_dataset = np_dataset
    def train(self, dataset):
        if self.use_kdtree:
            if self.np_dataset:
                self.create_kdtree_np(dataset)
            else:
                self.create_kdtree(dataset)
        else:
            self.dataset = dataset
    def test(self, query_data, k):
        self.k = k
        if self.use_kdtree:
            if self.k == 1:
                return self.kdtree_search_nearest(query_data)
            return self.kdtree_search(query_data)
        else:
            return self.scan_search(query_data)
    def create_kdtree_np(self, dataset):
        def dfs(dataset, split_axis, father):
            if not dataset.size:
                return None
            lens = len(dataset)
            dataset = dataset[dataset[:, split_axis].argsort()]
            split_idx = lens//2
            split_data = dataset[split_idx]
            left_dataset = dataset[:split_idx]
            right_dataset = dataset[split_idx+1:]
            root = KDTreeNode(split_data, split_axis, father)
            root.left = dfs(left_dataset, (split_axis+1)%len(dataset[0]), root)
            root.right = dfs(right_dataset, (split_axis+1)%len(dataset[0]), root)
            return root
        root = dfs(dataset, 0, None)
        self.kdtree = root
    def create_kdtree(self, dataset):
        def dfs(dataset, split_axis, father):
            if not dataset:
                return None
            lens = len(dataset)
            dataset.sort(key=lambda x:x[split_axis])
            split_idx = lens//2
            split_data = dataset[split_idx]
            left_dataset = dataset[:split_idx]
            right_dataset = dataset[split_idx+1:]
            root = KDTreeNode(split_data, split_axis, father)
            root.left = dfs(left_dataset, (split_axis+1)%len(dataset[0]), root)
            root.right = dfs(right_dataset, (split_axis+1)%len(dataset[0]), root)
            return root
        root = dfs(dataset, 0, None)
        self.kdtree = root
    def distance(self, a, b, dist_type='l2'):
        if dist_type == 'l2':
            if self.np_dataset:
                return np.sum((a-b)**2)
            if not isinstance(a, (list, tuple)):
                a = [a]
                b = [b]
            return sum([(i-j)**2 for (i, j) in zip(a, b)])
    def scan_search(self, query_data):
        dist_map = [(data, self.distance(query_data, data)) for data in self.dataset]
        dist_map.sort(key=lambda x:x[1])
        return [item[0] for item in dist_map[:self.k]]
    def kdtree_search(self, query_data):
        import heapq
        heap = []
        def dfs(root, query_data):
            if not root.left and not root.right:
                nearest = root
                min_dist = self.distance(root.data, query_data)
                return (nearest, min_dist)
            split_axis = root.split_axis
            if query_data[split_axis] < root.data[split_axis]:
                if root.left:
                    return dfs(root.left, query_data)
                else:
                    return dfs(root.right, query_data)
            else:
                if root.right:
                    return dfs(root.right, query_data)
                else:
                    return dfs(root.left, query_data)
        def region_search(root, query_data):
            nearest, max_min_dist = dfs(root, query_data)
            node = nearest
            while True:
                curr_dist = self.distance(node.data, query_data)
                node.dist = -curr_dist
                if len(heap) < self.k:
                    heapq.heappush(heap, node)
                else:
                    if heapq.nsmallest(1, heap)[0].dist < -curr_dist:
                        heapq.heapreplace(heap, node)
                if node == root:
                    break
                father = node.father
                plane_dist = self.distance(father.data[father.split_axis], query_data[father.split_axis]) #查询点和超平面的距离
                if (len(heap) < self.k or plane_dist < max_min_dist) and node.mate:
                    region_search(node.mate, query_data)
                node = node.father
        region_search(self.kdtree, query_data)
        return [item.data for item in heap]
    def kdtree_search_nearest(self, query_data):
        #最近邻搜索
        def dfs(root, query_data):
            if not root.left and not root.right:
                nearest = root
                min_dist = self.distance(root.data, query_data)
                return (nearest, min_dist)
            split_axis = root.split_axis
            if query_data[split_axis] < root.data[split_axis]:
                if root.left:
                    return dfs(root.left, query_data)
                else:
                    return dfs(root.right, query_data)
            else:
                if root.right:
                    return dfs(root.right, query_data)
                else:
                    return dfs(root.left, query_data)
        def region_search(root, query_data):
            nearest, min_dist = dfs(root, query_data)
            node = nearest
            while True:
                curr_dist = self.distance(node.data, query_data)
                if curr_dist < min_dist:
                    nearest = node
                    min_dist = curr_dist
                if node == root:
                    break
                father = node.father
                plane_dist = self.distance(father.data[father.split_axis], query_data[father.split_axis]) #查询点和超平面的距离
                if plane_dist < min_dist and node.mate:
                    mate_nearest, mate_mindist = region_search(node.mate, query_data)
                    if mate_mindist < min_dist:
                        min_dist = mate_mindist
                        nearest = mate_nearest
                node = node.father
            return nearest, min_dist
        self.nearest, _ = region_search(self.kdtree, query_data)
        return self.nearest.data

In [129]:
knn1 = KNN(np_dataset=True)
knn2 = KNN(np_dataset=True, use_kdtree=False)

In [130]:
for _ in range(1):
    dataset_np = np.random.rand(128).reshape(32,4)
    knn1.train(dataset_np)
    knn2.train(dataset_np)
    node1 = knn1.test(np.array([0.5,0.5,0.5,0.5]), 4)
    node2 = knn2.test(np.array([0.5,0.5,0.5,0.5]), 4)
    print(node1)
    print(node2)

[array([0.28050092, 0.32473445, 0.41601298, 0.66665668]), array([0.71771779, 0.40721258, 0.45705213, 0.72531417]), array([0.51532176, 0.35700408, 0.35264483, 0.68879033]), array([0.23581586, 0.43221634, 0.63873245, 0.52189906])]
[array([0.51532176, 0.35700408, 0.35264483, 0.68879033]), array([0.23581586, 0.43221634, 0.63873245, 0.52189906]), array([0.71771779, 0.40721258, 0.45705213, 0.72531417]), array([0.28050092, 0.32473445, 0.41601298, 0.66665668])]


In [135]:
a = np.random.rand(40000).reshape(10000,4)

In [132]:
import time

In [136]:
stime = time.time()
knn1.train(a)
node1 = knn1.test(np.array([0.5,0.5,0.5,0.5]), 4)
etime = time.time()
print(node1)
print(etime-stime)

[array([0.51826398, 0.55013395, 0.56642872, 0.51797208]), array([0.48096842, 0.49805931, 0.42931577, 0.54564063]), array([0.52429017, 0.54667381, 0.44003482, 0.51805642]), array([0.51620391, 0.55687684, 0.44156029, 0.47703573])]
0.12363815307617188


In [137]:
stime = time.time()
knn2.train(a)
node1 = knn2.test(np.array([0.5,0.5,0.5,0.5]), 4)
etime = time.time()
print(node1)
print(etime-stime)

[array([0.52429017, 0.54667381, 0.44003482, 0.51805642]), array([0.51620391, 0.55687684, 0.44156029, 0.47703573]), array([0.48096842, 0.49805931, 0.42931577, 0.54564063]), array([0.51826398, 0.55013395, 0.56642872, 0.51797208])]
0.07579922676086426
