In [6]:
#!/usr/bin/env python
# References- 
# https://en.wikipedia.org/wiki/K-d_tree
# https://gist.github.com/rhigdon/199174
from __future__ import print_function

import argparse
import math
import multiprocessing
import operator
import os
import random
import sys
import time
from collections import deque
from collections import namedtuple
from functools import wraps
from operator import itemgetter
from pprint import pformat

import cv2
import dill
import matplotlib.pyplot as plt
import numpy as np

In [7]:
# Child position -> (compare, combine) operator pair for a hyperplane test.
#   position 0 is the left child:  (operator.le, operator.sub)
#   position 1 is the right child: (operator.ge, operator.add)
# combine(query_coord, radius) shifts the query coordinate toward that child;
# compare(shifted, split_value) is true when the shift reaches the child's side.
COMP_CHILD = dict(enumerate([
    (operator.le, operator.sub),
    (operator.ge, operator.add),
]))

def _plane(x):
    """ Check if the object of the function has axis """
    @wraps(x)
    def _w(self, *args, **kwargs):
        if None in (self.axis):
            raise ValueError('%(func_name) requires the node %(node)s '
                             'to have an axis' %
                             dict(func_name=x.__name__, node=repr(self)))

        return x(self, *args, **kwargs)

    return _w

In [8]:
class kd_node:
    """
    Node of a kd-tree; the node a caller holds is the root of its (sub)tree.

    Points are dicts mapping an integer axis index to a coordinate; an axis
    missing from a point reads as 0. A node whose ``data`` is None is an
    empty node and is falsy.
    """
    def __init__(self, data=None, parent=None, left=None, right=None,
                 axis=None, sel_axis=None, dimensions=None):
        """
        Creates a new node for a kd-tree.

        data       -- the point stored at this node, or None for an empty node.
        parent     -- the parent node; None only for the root node.
        left/right -- child nodes. add() sends a point whose coordinate on
                      ``axis`` is smaller than this node's to the left and
                      everything else to the right.
        axis       -- index of the axis this node splits on.
        sel_axis   -- sel_axis(axis) returns the axis used by this node's
                      children; add() and create_subnode() require it.
        dimensions -- number of axes; when set, add() rejects points that use
                      an axis index >= dimensions.
        """
        self.data = data
        self.parent = parent
        self.left = left
        self.right = right
        self.axis = axis
        self.sel_axis = sel_axis
        self.dimensions = dimensions

    def axis_distance(self, point, axis):
        """
        Returns the squared distance at the given axis between the current
        Node and the query point.
        """
        return math.pow(self.data.get(axis, 0.) - point.get(axis, 0.), 2)

    def distance(self, point):
        """
        Returns the squared Euclidean distance between this node and point.

        Every axis present in either dict is summed, so sparse points (with
        missing axes reading as 0) are handled. The original summed over
        range(len(self.data)), which skipped axes of sparse dicts.
        """
        axes = sorted(set(self.data) | set(point))
        return sum(self.axis_distance(point, axis) for axis in axes)

    def children(self):
        """Yield (child, pos) for each non-empty child: pos 0 = left, 1 = right."""
        if self.left:
            yield self.left, 0
        if self.right:
            yield self.right, 1

    def _require_axis(self, func_name):
        """Raise ValueError unless this node has both axis and sel_axis set."""
        if self.axis is None or self.sel_axis is None:
            raise ValueError('%(func_name)s requires the node %(node)s '
                             'to have an axis and a sel_axis function' %
                             dict(func_name=func_name, node=repr(self)))

    def _check_dimensionality(self, point):
        """Raise ValueError if point uses an axis index outside self.dimensions."""
        if self.dimensions is not None and point and max(point) >= self.dimensions:
            raise ValueError('Point %r does not fit in %d dimensions'
                             % (point, self.dimensions))

    @staticmethod
    def _kth_distance(results, k):
        """Return the k-th smallest distance in results, or inf if fewer than k."""
        if len(results) < k:
            return float('inf')
        return sorted(results.values())[k - 1]

    def node_search(self, point, k, results, examined, get_dist, prune=True):
        """
        Visit this node and, where they may hold closer points, its subtrees.

        point    -- the query point.
        k        -- number of nearest neighbours wanted (k < 1 adds nothing).
        results  -- dict {node: distance} of the candidates so far, updated in
                    place to keep every visited node whose distance is <= the
                    k-th smallest distance (ties at that distance are kept).
        examined -- set of nodes already visited; updated in place.
        get_dist -- get_dist(node) returns the node's distance to point.
        prune    -- skip a subtree when the squared gap to its splitting plane
                    exceeds the current k-th distance. That bound is only
                    valid when get_dist is squared Euclidean distance.
        Returns None; the result is accumulated in ``results``.
        """
        examined.add(self)
        if k < 1:
            return

        node_dist = get_dist(self)
        if node_dist <= self._kth_distance(results, k):
            results[self] = node_dist
            bound = self._kth_distance(results, k)
            # Evict candidates this node pushed beyond the k-th distance.
            for node in [n for n, d in results.items() if d > bound]:
                del results[node]

        split = self.data.get(self.axis, 0.)
        coord = point.get(self.axis, 0.)
        for child, pos in self.children():
            if child in examined:
                continue
            examined.add(child)

            # The child's subtree may hold a closer point if the query lies
            # on the child's side of the splitting plane, or if the plane is
            # no farther (squared) than the current k-th best distance. The
            # bound is re-read here because searching the first child can
            # tighten it.
            on_child_side = coord <= split if pos == 0 else coord >= split
            if (not prune or on_child_side or
                    self.axis_distance(point, self.axis) <=
                    self._kth_distance(results, k)):
                child.node_search(point, k, results, examined, get_dist, prune)

    def add(self, point):
        """
        Insert point into this (sub)tree and return the node that holds it.

        The point fills the first empty node reached on the way down.
        Otherwise it becomes a new leaf created by create_subnode().
        Raises ValueError if this node lacks axis/sel_axis, or if the point
        uses an axis index >= dimensions of a node on the insertion path.
        """
        self._require_axis('add')
        current = self
        while True:
            current._check_dimensionality(point)

            # Adding has hit an empty node, store the point here
            if current.data is None:
                current.data = point
                return current

            # split on current.axis, continue either left or right
            if (point.get(current.axis, 0.) <
                    current.data.get(current.axis, 0.)):
                if current.left is None:
                    current.left = current.create_subnode(point)
                    return current.left
                current = current.left
            else:
                if current.right is None:
                    current.right = current.create_subnode(point)
                    return current.right
                current = current.right

    def create_subnode(self, data):
        """Return a new child of this node holding data, splitting on sel_axis(axis)."""
        self._require_axis('create_subnode')
        return self.__class__(data, parent=self,
                              axis=self.sel_axis(self.axis),
                              sel_axis=self.sel_axis,
                              dimensions=self.dimensions)

    def search_knn(self, point, k, dist=None):
        """
        Return the k nearest neighbours of point within this (sub)tree.

        point -- the query, a dict in the same axis space as the stored data
                 (not a node).
        k     -- number of neighbours wanted. Fewer are returned if the tree
                 holds fewer points. More are returned if several points tie
                 with the k-th distance. k < 1 returns [].
        dist  -- optional dist(a, b) taking two points and returning a
                 comparable distance. By default distance() (squared
                 Euclidean) is used. With a custom dist no subtree is pruned,
                 because the per-axis bound is only valid for the default
                 metric.

        Returns a list of (node, distance) tuples sorted by distance.
        """
        if dist is None:
            get_dist = lambda node: node.distance(point)
        else:
            get_dist = lambda node: dist(node.data, point)

        # Go down the tree as we would for inserting.
        prev = None
        current = self
        while current:
            prev = current
            if (point.get(current.axis, 0.) <
                    current.data.get(current.axis, 0.)):
                current = current.left
            else:
                current = current.right

        if prev is None or k < 1:
            return []

        examined = set()
        results = {}

        # Go back up to (and including) self, searching the unexplored side
        # of each node. Stopping at self keeps the search inside this subtree.
        current = prev
        while current is not None:
            current.node_search(point, k, results, examined, get_dist,
                                prune=dist is None)
            if current is self:
                break
            current = current.parent

        return sorted(results.items(), key=lambda item: item[1])

    def height(self):
        """Returns height of the (sub)tree."""
        # If self is not empty, its height is at least 1
        min_height = int(bool(self))
        return max([min_height] + [c.height() + 1 for c, p in self.children()])

    def __nonzero__(self):
        return self.data is not None

    __bool__ = __nonzero__
    

In [9]:
    def create(point_list, dimensions, axis=0, sel_axis=None, parent=None):
        """
        Build a kd-tree of ``kd_node`` objects from a list of points.

        point_list -- points as dicts mapping axis index -> coordinate (a
                      missing axis reads as 0.). May be empty or None, giving
                      an empty tree. The list itself is not modified.
        dimensions -- number of dimensions. Required when point_list is
                      empty; otherwise it is passed with point_list to
                      check_dimensionality(), whose return value is used.
        axis       -- the axis on which the root node splits.
        sel_axis   -- sel_axis(parent_axis) returns the axis of a child node.
                      It defaults to cycling through 0..dimensions-1 and is
                      passed on to every node of the tree.
        parent     -- parent node of the returned (sub)tree root.

        Raises ValueError when neither point_list nor dimensions is given.
        """
        if not point_list and not dimensions:
            raise ValueError('either point_list or dimensions must be provided')

        elif point_list:
            dimensions = check_dimensionality(point_list, dimensions)

        # by default cycle through the axes
        sel_axis = sel_axis or (lambda prev_axis: (prev_axis + 1) % dimensions)

        if not point_list:
            # Empty leaf: keep the parent link so a point added here later is
            # still reachable when a search walks back up the tree.
            return kd_node(parent=parent, sel_axis=sel_axis, axis=axis,
                           dimensions=dimensions)

        # Sort a copy on the split axis and choose the median as pivot.
        points = sorted(point_list, key=lambda point: point.get(axis, 0.))
        median = len(points) // 2  # `/` would give a float index in Python 3

        root = kd_node(points[median], parent, left=None, right=None,
                       axis=axis, sel_axis=sel_axis, dimensions=dimensions)
        child_axis = sel_axis(axis)
        # Pass sel_axis down so a custom axis selector applies to every level.
        root.left = create(points[:median], dimensions, child_axis,
                           sel_axis=sel_axis, parent=root)
        root.right = create(points[median + 1:], dimensions, child_axis,
                            sel_axis=sel_axis, parent=root)
        return root


    def check_dimensionality(point_list, dimensions):
        """
        Validate that every point fits in ``dimensions`` axes and return it.

        Points are dicts keyed by axis index, valid indices being
        0..dimensions-1. Points with no coordinates are accepted. When
        ``dimensions`` is None it is inferred as one more than the largest
        axis index used by any point.

        Raises ValueError if a point uses an axis index >= dimensions, or if
        dimensions is None and no point has a coordinate to infer it from.
        """
        # Largest axis index of each non-empty point (max() of a dict = max key).
        used_axes = [max(p) for p in point_list if p]

        if dimensions is None:
            if not used_axes:
                raise ValueError(
                    'cannot infer dimensions from points without coordinates')
            return max(used_axes) + 1

        # Axis indices run 0..dimensions-1, so an index equal to dimensions
        # is already out of range (the original `>` let it through).
        if used_axes and max(used_axes) >= dimensions:
            raise ValueError(
                'All Points in point_list must have the same dimensionality')

        return dimensions


    def level_order(tree, include_all=False):
        """
        Yield the nodes of ``tree`` breadth-first, one level at a time.

        Within a node, the left child is queued before the right one. Falsy
        children are skipped unless include_all is True. In that case a
        missing child is replaced by a fresh ``node.__class__()`` placeholder,
        so every level is complete and the generator never ends.
        """
        pending = deque([tree])
        while pending:
            node = pending.popleft()
            yield node

            for child in (node.left, node.right):
                if include_all or child:
                    pending.append(child or node.__class__())


    def visualize(tree, max_level=100, node_width=10, left_padding=5):
        """
        Print ``tree`` to stdout, one level per text row.

        Rows cover depths 0..min(max_level, tree.height() - 1). Each node's
        data is centred in a cell. The root's cell is node_width * 2**depth
        characters wide, and cell width halves on each deeper level. Empty
        positions print as blank cells.
        """
        depth = min(max_level, tree.height() - 1)
        full_width = pow(2, depth)
        row_capacity, row_filled, row = 1, 0, 0

        for node in level_order(tree, include_all=True):
            if not row_filled:
                # start of a new row: two line breaks, then the left margin
                print()
                print()
                print(' ' * left_padding, end=' ')

            cell = int(full_width * node_width / row_capacity)
            label = str(node.data) if node else ''
            print(label.center(cell), end=' ')

            row_filled += 1
            if row_filled == row_capacity:
                row_filled, row_capacity, row = 0, row_capacity * 2, row + 1

            if row > depth:
                break

        print()
        print()