In [1]:
k = 2 # dimensions
points= [(11,10),(4,7),(16,10),(9,4), (7,13), (15,3), (14,11), (1,1)]

def build_kdtree(points, depth = 0, parent = None):
    n = len(points)
    if n <= 0:
        return None
    
    axis = depth % k
    
    sorted_points = sorted(points, key=lambda point: point[axis])
    
    d = {
        "point": sorted_points[n//2],
        "parent": parent
    }
    
    d["left"] = build_kdtree(sorted_points[:n//2], depth +1, d)
    d["right"]= build_kdtree(sorted_points[n//2+1:], depth +1, d)
    return d

tree = build_kdtree(points)


In [2]:
from math import sqrt

def distance(p1, p2, axis=None):
    if axis is None:
        return sqrt(sum([(x - y)**2 for x, y in zip(p1, p2)]))
    else:
        return sqrt((p1[axis] - p2[axis])**2)

def naive_closest_point(root, target, depth=0, best=None):
    axis = depth % k
    
    # if there is no tree return None
    if root is None:
        return best
    
    # best is None the first time
    if best is None:
        best = root['point']

    # is new point close to target       best match so far
    if distance(target, root['point']) < distance(target, best):
        best = root['point']
        
    # What way are we going
    if target[axis] < root['point'][axis]:
        next_branch = root['left']
    else:
        next_branch = root['right']

    return naive_closest_point(next_branch, target, depth + 1, best)

In [3]:
def closer(target, p1, p2, axis = None):
    if p1 is None:
        return p2
    if p2 is None:
        return p1
    
    d1 = distance(target, p1, axis)
    d2 = distance(target, p2, axis)
    
    if d1 < d2:
        return p1
    else:
        return p2

def closest_point(root, target, depth=0, best=None):
    axis = depth % k
    
    # if there is no tree return None
    if root is None:
        return None
    
    # best is None the first time
    if best is None:
        best = root['point']
        
    # What way are we going
    if target[axis] < root['point'][axis]:
        next_branch = root['left']
        oppo_branch = root['right']
    else:
        next_branch = root['right']
        oppo_branch = root['left']
        
    # finds the best from what we think is best branch
    best = closer(target, closest_point(next_branch, target, depth+1), best)
    
    # if distance from point to target is less than the distance from best so far to target
    # we wanna inspect values from the opposit branch to make sure no point there is better that curent best
    if distance(target, root['point'], axis) <= distance(target, best):
        best = closer(target, closest_point(oppo_branch, target, depth+1), best)

    return best

In [4]:
def points_in_reach(root, target, reach):
    
    vals = []
    
    def helper(root, target, reach, depth=0):

        axis = depth % k
        # if there is no tree return None
        if root is None:
            return None
        
        dist = distance(root['point'], target, axis)
        
        # if we are in reach, we check current point and go both left and right
        if dist <= reach:
            if distance(root['point'], target) <= reach:
                vals.append(root['point'])

            helper(root['right'], target, reach, depth + 1)
            helper(root['left'], target, reach, depth + 1)
        
        # if we are out of reach, we go either left or right depending on
        # wheter we we have to increase or decrease the axis to get closer
        if dist > reach:
            # if root is too small we go right
            if root['point'][axis] < target[axis]:
                helper(root['right'], target, reach, depth + 1)
            # if root is too big we go left
            else:
                helper(root['left'], target, reach, depth + 1)
        
    helper(root, target, reach)
    return vals

In [11]:
def closer2(target, point1, point2, axis = None):
    
    if point1 is None:
        return point2
    if point2 is None:
        return point1
    
    p1 = point1['point']
    p2 = point2['point']
    
    d1 = distance(target, p1, axis)
    d2 = distance(target, p2, axis)
    
    if d1 < d2:
        return point1
    else:
        return point2

def closest_point2(root, target, depth=0, best=None):
    axis = depth % k
    
    # if there is no tree return None
    if root is None:
        return best
    
    # best is None the first time
    if best is None:
        best = root
        
    # What way are we going
    if target[axis] < root['point'][axis]:
        next_branch = root['left']
        oppo_branch = root['right']
    else:
        next_branch = root['right']
        oppo_branch = root['left']
        
    # finds the best from what we think is best branch
    best = closer2(target, closest_point2(next_branch, target, depth+1), best)
    
    # if distance from point to target is less than the distance from best so far to target
    # we wanna inspect values from the opposit branch to make sure no point there is better that curent best
    if distance(target, root['point'], axis) <= distance(target, best['point']):
        best = closer2(target, closest_point2(oppo_branch, target, depth+1), best)
    
    return best

def cut_down_tree(root):
    vals = []
    
    def helper(root, first=False):
        if root is not None:
            if not first:
                vals.append(root['point'])
            helper(root['left'])
            helper(root['right'])
            
    helper(root, True)
    return vals


def remove_point(root, target):
    
    if root is None:
        return None
    
    best = closest_point2(root, target)
    
    print(target, "  \t", best['point'], "  \t", target == best['point'])
    
    if best['point'] == target:
        parent = best['parent']
        
        vals = cut_down_tree(best)
        new_tree = build_kdtree(vals)
        
        if parent is None:
            root = new_tree
        elif parent['left'] is best:
            parent['left'] = new_tree
        else:
            parent['right'] = new_tree
        
        return True
    
    return False

In [44]:
def insert_point(root, point, depth = 0, parent = None):
    axis = depth % k
    if root is None:
        new_point = {
            'parent': parent,
            'point': point,
            'left': None,
            'right': None
        }
        if parent is None:
            root = new_point
        else:
            if point[axis] < parent["point"][axis]:
                parent['left'] = new_point
            else:
                parent['right'] = new_point
            
    elif point[axis] < root["point"][axis]:
        insert_point(root["left"], point, depth+1, root)
    else:
        insert_point(root["right"], point, depth+1, root)

In [45]:
print(naive_closest_point(tree, (14,9)))
print(closest_point(tree, (14,9)))


res = points_in_reach(tree, (0,0), 20)
res2 = sorted([(x,distance((0,0),x))for x in res], key=lambda x: x[1])


print()
for x in res2:
    print(*x, sep="  \t")
    

print()
    
for point in points:
    tree = build_kdtree(points)
    remove_point(tree, point)
    
    
print()


import pprint
tree = build_kdtree(points[:3])
pprint.pprint(tree)
print("\n\n\n\n\n\n")
tree = build_kdtree(points[:3])
insert_point(tree, (5,8))
insert_point(tree, (6,9))
insert_point(tree, (7,10))
pprint.pprint(tree)

(16, 10)
(16, 10)

(4, 7)  	8.06225774829855
(5, 8)  	9.433981132056603
(6, 9)  	10.816653826391969
(7, 10)  	12.206555615733702
(11, 10)  	14.866068747318506
(16, 10)  	18.867962264113206

(11, 10)   	 (11, 10)   	 True
(4, 7)   	 (4, 7)   	 True
(16, 10)   	 (16, 10)   	 True
(9, 4)   	 (9, 4)   	 True
(7, 13)   	 (7, 13)   	 True
(15, 3)   	 (15, 3)   	 True
(14, 11)   	 (14, 11)   	 True
(1, 1)   	 (1, 1)   	 True

{'left': {'left': None,
          'parent': <Recursion on dict with id=2129771848976>,
          'point': (4, 7),
          'right': None},
 'parent': None,
 'point': (11, 10),
 'right': {'left': None,
           'parent': <Recursion on dict with id=2129771848976>,
           'point': (16, 10),
           'right': None}}







{'left': {'left': None,
          'parent': <Recursion on dict with id=2129772606664>,
          'point': (4, 7),
          'right': {'left': None,
                    'parent': <Recursion on dict with id=2129772606808>,
                    'point