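Ball Tree: recursively partitions points into nested hyperspheres ("balls"), each described by a center and a radius; for nearest-neighbor search it often prunes better than a k-d tree on moderately high-dimensional data.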

In [1]:
import heapq
import itertools

import numpy as np

In [2]:
class BallTree:
    """Ball tree for exact k-nearest-neighbor search (Euclidean distance).

    Every node stores a ``center`` (the mean of the points below it) and a
    ``radius`` (the largest distance from that center to any of those
    points), so each subtree is enclosed by its node's ball.  Internal nodes
    split their points into two halves by rank along the coordinate with the
    widest spread; leaves hold at most ``leaf_size`` points.

    Parameters
    ----------
    data : array_like, shape (n_points, n_dims)
        Points to index.
    leaf_size : int, default 10
        Maximum number of points per leaf (values below 1 act as 1).
    """

    def __init__(self, data, leaf_size=10):
        self.leaf_size = leaf_size
        self.root = self._build_tree(np.asarray(data))

    def _build_tree(self, data):
        """Build and return the subtree indexing the rows of ``data``."""
        n_points = len(data)
        if n_points == 0:
            # Only an empty dataset gets here: splits never yield empty halves.
            return BallTreeNode(center=None, radius=0.0, points=data)

        center = np.mean(data, axis=0)
        # Leaves need their true radius too: the query prunes on child balls.
        radius = float(np.max(np.linalg.norm(data - center, axis=1)))
        if n_points <= max(self.leaf_size, 1):
            return BallTreeNode(center=center, radius=radius, points=data)

        # Split by rank, not by value, so both halves are non-empty even when
        # points are duplicated or equidistant (a value threshold could put
        # everything on one side and recurse forever).
        spread = np.max(data, axis=0) - np.min(data, axis=0)
        order = np.argsort(data[:, np.argmax(spread)], kind="stable")
        half = n_points // 2
        return BallTreeNode(center=center, radius=radius,
                            left=self._build_tree(data[order[:half]]),
                            right=self._build_tree(data[order[half:]]))

    def query(self, point, k=1):
        """Return the ``k`` indexed points nearest to ``point``, closest first.

        Parameters
        ----------
        point : array_like, shape (n_dims,)
            Query location.
        k : int, default 1
            Number of neighbors wanted.  Fewer are returned when the tree
            holds fewer points; ``k <= 0`` returns an empty list.

        Returns
        -------
        list of numpy.ndarray
            Rows of the indexed data ordered by increasing distance; points
            at equal distance keep the order in which they were visited.
        """
        point = np.asarray(point)
        if k <= 0:
            return []
        # Max-heap of the k best hits as (-distance, tiebreak, point); the
        # unique tiebreak keeps heapq from ever comparing ndarrays.
        heap = []
        self._query(self.root, point, k, heap, itertools.count())
        heap.sort(key=lambda entry: (-entry[0], entry[1]))
        return [entry[2] for entry in heap]

    def _query(self, node, point, k, heap, counter):
        """Search ``node``'s subtree, keeping the ``k`` best hits in ``heap``."""
        if node.is_leaf():
            if len(node.points) == 0:
                return
            dists = np.linalg.norm(node.points - point, axis=1)
            for dist, p in zip(dists, node.points):
                entry = (-float(dist), next(counter), p)
                if len(heap) < k:
                    heapq.heappush(heap, entry)
                elif entry[0] > heap[0][0]:  # strictly closer than the k-th best
                    heapq.heapreplace(heap, entry)
            return

        # Lower bound on the distance from ``point`` to anything inside each
        # child's ball; visit the more promising child first.
        candidates = []
        for child in (node.left, node.right):
            if child is not None:
                gap = np.linalg.norm(child.center - point) - child.radius
                candidates.append((max(float(gap), 0.0), child))
        candidates.sort(key=lambda pair: pair[0])

        for bound, child in candidates:
            if len(heap) == k and bound >= -heap[0][0]:
                continue  # nothing in this ball can beat the current k-th best
            self._query(child, point, k, heap, counter)


In [3]:
class BallTreeNode:
    """One node of a ``BallTree``.

    A node counts as a leaf exactly when ``points`` is not ``None``; nodes
    built without ``points`` are internal and reach their data through the
    ``left`` / ``right`` children.

    Attributes:
        center: center of the node's ball.
        radius: radius of the node's ball.
        left: left child node, or ``None``.
        right: right child node, or ``None``.
        points: points stored directly in this node, or ``None``.
    """

    def __init__(self, center, radius, left=None, right=None, points=None):
        self.center, self.radius = center, radius
        self.left, self.right = left, right
        self.points = points

    def is_leaf(self):
        """Return True when this node stores its points directly."""
        stores_points = self.points is not None
        return stores_points