In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

%matplotlib inline

In [None]:
class Node(object):
    """ A node in a binary tree

    A tree is represented by its root node, and every node represents
    the subtree rooted at it.

    The payload is stored in ``data``. ``val`` is kept as a read/write
    alias because the constructor parameter carries that name.
    """

    def __init__(self, val=None, left=None, right=None):
        # Every other method of this class reads ``self.data``, so the
        # payload has to live under that name.
        self.data = val
        self.left = left
        self.right = right

    @property
    def val(self):
        """ Alias for ``data`` """
        return self.data

    @val.setter
    def val(self, value):
        self.data = value

    @property
    def children(self):
        """ Returns an iterator over the non-empty children of the node

        Yields ``(child, index)`` pairs, where ``index`` is 0 for the left
        child and 1 for the right child. Children that are missing or that
        hold no data are skipped.
        """
        if self.left and self.left.data is not None:
            yield self.left, 0
        if self.right and self.right.data is not None:
            yield self.right, 1

    def set_child(self, index, child):
        """ Sets one of the node's children

        ``index`` 0 replaces the left child, any other value replaces the
        right child.
        """
        if index == 0:
            self.left = child
        else:
            self.right = child

    def get_child_pos(self, child):
        """ Returns the position (0 = left, 1 = right) of the given child

        Children are compared with ``==`` (i.e. by data). Returns None when
        no non-empty child matches.
        """
        for c, pos in self.children:
            if child == c:
                return pos
        return None

    @property
    def is_leaf(self):
        """ True if the node is empty or has no non-empty children """
        # Emptiness is tested with ``is None`` so that it agrees with
        # __bool__ and does not depend on the truthiness of the payload.
        return (self.data is None) or all(not bool(c) for c, _ in self.children)

    def height(self):
        """ Returns the height of the (sub)tree, ignoring empty nodes

        An empty node has height 0, a node without non-empty children has
        height 1.
        """
        min_height = int(bool(self))
        return max([min_height] + [c.height() + 1 for c, _ in self.children])

    def __repr__(self):
        return '<%(cls)s - %(data)s>' % \
            dict(cls=self.__class__.__name__, data=repr(self.data))

    def __nonzero__(self):
        # A node counts as non-empty as soon as it carries any data.
        return self.data is not None

    # Python 3 spelling of the truthiness hook.
    __bool__ = __nonzero__

    def __eq__(self, other):
        # A node compares equal to a bare tuple holding the same data, or
        # to another node with equal data.
        if isinstance(other, tuple):
            return self.data == other
        else:
            return self.data == other.data

    def __hash__(self):
        # Identity-based hash: nodes stay usable in sets/dicts although
        # __eq__ compares by data.
        return id(self)

def require_axis(f):
    """ A decorator that checks the object has axis and sel_axis members

    The wrapped method raises ValueError when either ``self.axis`` or
    ``self.sel_axis`` is None; otherwise the call is forwarded unchanged
    and its result returned.
    """
    # Imported here so the decorator is self-contained.
    from functools import wraps

    @wraps(f)
    def _wrapper(self, *args, **kwargs):
        if self.axis is None or self.sel_axis is None:
            # Each mapping key needs an explicit conversion type (``s``);
            # without it the following characters are parsed as the
            # conversion and the message comes out garbled.
            raise ValueError('%(func_name)s requires the node %(node)s '
                    'to have an axis and a sel_axis function' %
                    dict(func_name=f.__name__, node=repr(self)))
        return f(self, *args, **kwargs)
    return _wrapper

class KDNode(Node):
    """ A node that contains kd-tree specific data and methods """
    
    def __init__(self, data=None, left=None, right=None, axis=None, sel_axis=None, dimensions=None):
        """ Creates a kd-tree node

        ``data``, ``left`` and ``right`` are handed to the base class
        constructor. ``axis``, ``sel_axis`` and ``dimensions`` are stored
        on the node exactly as supplied.
        """
        super(KDNode, self).__init__(data, left, right)
        # kd-tree specific state
        self.axis, self.sel_axis, self.dimensions = axis, sel_axis, dimensions
    
    @staticmethod
    def check_dimensionality(point_list, dimensions=None):
        """ Checks that all points have the same number of coordinates

        ``dimensions`` is the expected length of every point; when it is
        omitted (or 0) the length of the first point is used instead.

        Returns the dimensionality that was checked against.
        Raises ValueError if any point has a different length.
        """
        # Declared static: the function takes no ``self``, so without the
        # decorator it could not be called through an instance.
        dimensions = dimensions or len(point_list[0])
        for p in point_list:
            if len(p) != dimensions:
                raise ValueError('All points in point_list must have the same dimensionality')
        return dimensions
    
    @require_axis
    def add(self, point):
        """ Adds a point to the subtree rooted at the current node

        Starting at this node, the search goes left when the point's
        coordinate on the visited node's axis is strictly smaller than the
        node's own coordinate and right otherwise. It stops at the first
        empty node or missing child, where the point is stored.

        Returns the node that now holds ``point``.
        Raises ValueError if the point's length does not match the
        ``dimensions`` of a visited node.
        """
        current = self
        while True:
            # The helper is defined in the class body and takes no
            # ``self``; look it up on the class instead of as a bare name,
            # which is not in scope inside a method.
            type(current).check_dimensionality([point], dimensions=current.dimensions)

            # Adding has hit an empty leaf-node, add here
            if current.data is None:
                current.data = point
                return current

            # Split on the visited node's axis and keep descending until a
            # free child slot is found.
            if point[current.axis] < current.data[current.axis]:
                if current.left is None:
                    current.left = current.create_node(point)
                    return current.left
                current = current.left
            else:
                if current.right is None:
                    current.right = current.create_node(point)
                    return current.right
                current = current.right
    
    @require_axis
    def create_node(self, data):
        """ Builds a new node of the same class to hang below this one

        The new node holds ``data``, takes its axis from
        ``self.sel_axis(self.axis)`` and inherits this node's ``sel_axis``
        and ``dimensions``. It is returned without being attached.
        """
        node_cls = self.__class__
        child_axis = self.sel_axis(self.axis)
        return node_cls(
            data,
            axis=child_axis,
            sel_axis=self.sel_axis,
            dimensions=self.dimensions,
        )
    
    @require_axis
    def find_replacement(self):
        """ Finds a replacement for the current node

        The replacement is looked up below this node: the minimum along
        ``self.axis`` in the right subtree when there is one, otherwise the
        maximum along ``self.axis`` in the left subtree.

        Returns a ``(replacement, parent)`` tuple. Intended for nodes that
        have at least one non-empty child.

        NOTE(review): relies on ``extreme_child(sel_func, axis)``, defined
        outside this view, returning ``(node, parent)`` with ``parent``
        set to None when the node it was called on is itself selected --
        confirm against its definition.
        """
        # The search must start at a child. Started at ``self`` it could
        # select ``self``, and the ``parent is None`` fallback below would
        # then name the node as its own parent.
        if self.right:
            child, parent = self.right.extreme_child(min, self.axis)
        else:
            child, parent = self.left.extreme_child(max, self.axis)
        # A None parent means the direct child was selected, so this node
        # is its parent.
        return (child, parent if parent is not None else self)
    
    def should_remove(self, point, node):
        """ Tells whether this node is the one a removal is aimed at

        The node qualifies when its data equals ``point`` and, if ``node``
        is given (not None), when ``node`` is this very object.
        """
        same_point = self.data == point
        if same_point:
            return (node is None) or (node is self)
        return False
    
    @require_axis
    def remove(self, point, node=None):
        """ Removes the node with the given point from the subtree

        When ``node`` is given, only that exact node object is removed
        (see ``should_remove``); otherwise nodes whose data equals
        ``point`` are removed.

        Returns the root of the subtree after the removal, which callers
        should use in place of this node. Returns None when this node is
        empty.
        """
        if not self:
            return

        # This node itself may be the target. Without this check the node
        # the call starts on (e.g. the tree root) could never be removed.
        if self.should_remove(point, node):
            return self._remove(point)

        if self.left and self.left.should_remove(point, node):
            self.left = self.left._remove(point)

        elif self.right and self.right.should_remove(point, node):
            self.right = self.right._remove(point)

        # Recurse to subtrees
        if point[self.axis] <= self.data[self.axis]:
            if self.left:
                self.left = self.left.remove(point, node)
        # ``>=`` rather than ``>``: ``add`` places points whose coordinate
        # equals this node's into the right subtree, so equal coordinates
        # have to be searched there as well.
        if point[self.axis] >= self.data[self.axis]:
            if self.right:
                self.right = self.right.remove(point, node)
        return self
    
    @require_axis
    def _remove(self, point):
        if self.is_leaf:
            self.data = None
            return self
        root, max_p = self.find_replacement()
        