In [1]:
import numpy as np

In [2]:
class KDTree:
    """k-d tree over the rows of a 2-D array, supporting k-nearest-neighbour queries.

    Each node stores one point (``median``) and splits the remaining points
    along ``axis`` (which cycles through the columns with depth) into a
    ``left`` subtree (coordinates <= the median's along ``axis``) and a
    ``right`` subtree (coordinates >= the median's along ``axis``).

    Attributes:
        depth:  Depth of this node in the tree (root is 0).
        axis:   Column index this node splits on (``depth % n_columns``).
        data:   The points handed to this node, as received (unsorted).
        median: 1-D array with this node's point, or ``None`` if the node
                was built from zero points.
        left, right: Child ``KDTree`` nodes, or ``None`` when that side
                holds no points.
    """

    def __init__(self, data, depth=0):
        """Build the (sub)tree for ``data``.

        Args:
            data:  2-D array-like of shape ``(n_points, n_dims)``.
            depth: Depth of this node; callers normally leave the default.
        """
        data = np.asarray(data)
        self.depth = depth
        self.axis = depth % data.shape[1]
        self.data = data
        self.left = None
        self.right = None

        if len(data) == 0:
            # Empty tree: nothing to store and nothing to search.
            self.median = None
        elif len(data) == 1:
            # Store the point itself (1-D), matching what internal nodes store.
            self.median = data[0]
        else:
            ordered = data[data[:, self.axis].argsort()]
            median_idx = len(ordered) // 2
            self.median = ordered[median_idx]
            lower = ordered[:median_idx]
            upper = ordered[median_idx + 1:]
            # median_idx >= 1 here, so `lower` is never empty; `upper` is empty
            # when exactly two points remain, in which case no child is built.
            self.left = KDTree(lower, depth + 1)
            if len(upper) > 0:
                self.right = KDTree(upper, depth + 1)

    def query(self, point, k=1):
        """Return the ``k`` stored points nearest to ``point`` (Euclidean).

        Args:
            point: 1-D array-like with one coordinate per column of the data.
            k:     Number of neighbours wanted.

        Returns:
            A list of at most ``k`` 1-D arrays ordered from nearest to
            farthest. Fewer than ``k`` are returned when the tree holds fewer
            points; an empty list is returned when ``k < 1``.
        """
        if k < 1:
            return []
        point = np.asarray(point)
        best = []
        self._query(point, k, best)
        return [entry[1] for entry in best]

    def _query(self, point, k, best):
        """Recursive search helper.

        ``best`` is a list of ``(distance, point)`` tuples shared by the whole
        recursion. It is kept sorted by distance and is only ever modified in
        place, so every caller sees candidates found deeper in the tree.
        """
        if self.median is None or k < 1:
            return

        dist = np.linalg.norm(self.median - point)
        if len(best) < k or dist < best[-1][0]:
            best.append((dist, self.median))
            # Sort on the distance only: comparing whole tuples would fall
            # through to comparing arrays when two distances tie.
            best.sort(key=lambda entry: entry[0])
            # Trim in place; rebinding `best` would detach it from the caller.
            del best[k:]

        axis_dist = point[self.axis] - self.median[self.axis]
        if axis_dist <= 0:
            near, far = self.left, self.right
        else:
            near, far = self.right, self.left

        if near is not None:
            near._query(point, k, best)
        # The far side can only help if we still lack k candidates or the
        # splitting plane is closer than the current k-th best distance.
        if far is not None and (len(best) < k or abs(axis_dist) < best[-1][0]):
            far._query(point, k, best)


In [3]:
# Example usage: build a tree over 100 random 3-D points and look up the
# three points nearest to the centre of the unit cube.
# A seeded generator keeps the demo reproducible under Restart & Run All.
rng = np.random.default_rng(42)
data = rng.random((100, 3))
kd_tree = KDTree(data)
print(kd_tree.query(np.array([0.5, 0.5, 0.5]), k=3))


[array([0.36198613, 0.46734936, 0.44116492]), array([0.58717493, 0.5778671 , 0.63990788]), array([0.5833976 , 0.74674175, 0.73618636])]
