In [1]:
%load_ext Cython
%load_ext memory_profiler

In [2]:
# Change Working Directory To Allow knn Imports
# NOTE(review): the path is relative, so re-running this cell without restarting
# the kernel moves one directory further up each time. Run it exactly once per
# session; later cells resolve `knn` and `./sample_data/...` from the new cwd.
import os 
os.chdir('../')

In [391]:
%%cython
import cython
import numpy as np
cimport numpy as np
import math
from libcpp cimport bool
from knn.distance_metrics import euclidean

cdef class BallTree:
    """Ball tree over the rows of a 2-D point array, stored as an implicit binary tree.

    Nodes live in flat arrays indexed by node number: node ``i`` has children
    ``2*i + 1`` and ``2*i + 2``.  Every node owns a contiguous range
    ``[start, end]`` (inclusive) of ``data_inds``, a permutation of the row
    indices of ``data`` that is reordered in place while the tree is built.

    Parameters
    ----------
    data : array-like, shape (n_points, n_features)
        Points to index.  Converted to a C-contiguous float64 array.
    leaf_size : int
        A node holding at most this many points is not split further.

    Raises
    ------
    ValueError
        If ``leaf_size`` is smaller than 1 or ``data`` is not a non-empty
        2-D array.
    """

    # Point matrix and the permutation of its row indices (plus typed views).
    cdef readonly double[:, ::1] data_view
    cdef readonly long[::1] data_inds_view
    cdef np.ndarray data
    cdef np.ndarray data_inds

    # Per-node storage; row/entry i describes node i.
    #   node_data_inds[i] = (start, end) inclusive range into data_inds
    #   node_radius[i]    = largest distance from node_center[i] to a node point
    #   node_is_leaf[i]   = 1.0 for leaves, 0.0 otherwise (stored as float)
    #   node_center[i]    = mean of the node's points
    cdef public np.ndarray node_data_inds
    cdef public np.ndarray node_radius
    cdef public np.ndarray node_is_leaf
    cdef public np.ndarray node_center
    cdef public long[:,::1] node_data_inds_view
    cdef public double[::1] node_radius_view
    cdef public double[::1] node_is_leaf_view
    cdef public double[:,::1] node_center_view

    cdef int leaf_size
    cdef int node_count
    cdef int tree_height

    def __init__(self, data, leaf_size):
        """Allocate node storage for ``data``; call ``build_tree`` to fill it."""
        if leaf_size < 1:
            raise ValueError("leaf_size must be a positive integer")

        # Data
        # np.float / np.int were removed in NumPy 1.24; float64 / int_ are the
        # dtypes those aliases resolved to.
        self.data = np.asarray(data, dtype=np.float64, order='C')
        if self.data.ndim != 2 or self.data.shape[0] == 0:
            raise ValueError("data must be a non-empty 2-D array")
        self.data_view = memoryview(self.data)
        # Shapes are read from the converted array so list input works too.
        self.data_inds = np.arange(self.data.shape[0], dtype=np.int_)
        self.data_inds_view = memoryview(self.data_inds)

        # Tree Shape
        self.leaf_size = leaf_size
        leaf_count = self.data.shape[0] / leaf_size
        # Clamp to one level: with fewer points than leaf_size the log is
        # negative and the tree would otherwise be allocated with zero nodes.
        self.tree_height = max(1, math.ceil(np.log2(leaf_count)) + 1)
        self.node_count = int(2 ** self.tree_height) - 1

        print("Leaf Count: " + str(leaf_count))
        print("Tree Height: " + str(self.tree_height))
        print("Node Count: " + str(self.node_count))

        # Node Data
        self.node_data_inds = np.zeros((self.node_count, 2), dtype=np.int_, order='C')
        self.node_radius = np.zeros(self.node_count, order='C')
        self.node_is_leaf = np.zeros(self.node_count, order='C')
        self.node_center = np.zeros((self.node_count, self.data.shape[1]), order='C')
        self.node_data_inds_view = memoryview(self.node_data_inds)
        self.node_radius_view = memoryview(self.node_radius)
        self.node_is_leaf_view = memoryview(self.node_is_leaf)
        self.node_center_view = memoryview(self.node_center)

    def build_tree(self):
        """Build the whole tree, starting from the root (node 0) over all points."""
        self._build(0, 0, self.data.shape[0]-1)

    def _build(self, long node_index, long node_data_start, long node_data_end):
        """Fill in node ``node_index`` and recurse into its children.

        The node owns ``data_inds[node_data_start : node_data_end + 1]``.
        Its centre, radius and index range are always recorded.  If it holds
        more than ``leaf_size`` points, those points are projected onto the
        line through two far-apart points, partitioned around the median
        projection, and the two halves become the children.
        """
        cdef long node_size = node_data_end - node_data_start + 1

        # Rows of this node, in the current data_inds order.  Fancy indexing
        # copies, so this stays aligned with the pre-partition order below.
        node_points = self.data[self.data_inds[node_data_start:node_data_end + 1]]

        # Bounding ball: centre is the mean, radius the farthest point from it.
        # euclidean() is assumed to return the distance of every row of its
        # first argument to the single row of its second -- TODO confirm.
        center = np.mean(node_points, axis=0)
        self.node_center[node_index] = center
        self.node_radius[node_index] = np.max(euclidean(node_points, center[np.newaxis, :]))
        self.node_data_inds[node_index, 0] = node_data_start
        self.node_data_inds[node_index, 1] = node_data_end

        # Current Node Is A Leaf
        if node_size <= self.leaf_size:
            self.node_is_leaf[node_index] = True
            return None

        # Current Node Is Internal Node
        self.node_is_leaf[node_index] = False

        # Random Point x0, drawn from THIS node's points (the position is
        # relative to node_points, not to the full data array).
        rand_point = node_points[np.random.randint(node_size)][np.newaxis, :]

        # Find Maximal Point x1: the node point farthest from x0.  argmax is a
        # position inside node_points, so it must index node_points.
        max_vector_1 = node_points[np.argmax(euclidean(node_points, rand_point))]

        # Find Maximal Point x2: the node point farthest from x1.
        max_vector_2 = node_points[np.argmax(euclidean(node_points, max_vector_1[np.newaxis, :]))]

        # Project Data onto the x1 - x2 direction.
        proj_data = np.dot(node_points, max_vector_1 - max_vector_2)

        # Find Median And Split Data
        half = proj_data.size // 2
        median = np.partition(proj_data, half)[half]

        # Hoare Partitioning: reorders proj_data and the matching slice of
        # data_inds so small projections come first.
        self._hoare_partition(median, node_data_start, node_data_end, proj_data)

        left_index = 2 * node_index + 1
        right_index = left_index + 1

        self._build(left_index, node_data_start, node_data_start + half - 1)
        self._build(right_index, node_data_start + half, node_data_end)

    def _hoare_partition(self, pivot, low, high, projected_data):
        """Hoare-partition ``projected_data`` around ``pivot`` in place.

        ``projected_data[k]`` must be the projection of the point referenced
        by ``data_inds[low + k]``; every swap is mirrored on ``data_inds`` so
        the two stay aligned.

        Returns the position in ``data_inds`` (absolute, i.e. offset by
        ``low``) at which the right-hand cursor stopped.

        Raises ValueError if ``projected_data`` does not have exactly
        ``high - low + 1`` entries.
        """
        if projected_data.shape[0] != high - low + 1:
            raise ValueError("projected_data length must equal high - low + 1")

        # Cursors are positions inside projected_data; add `low` to address
        # the corresponding entry of data_inds.
        left = -1
        right = projected_data.shape[0]

        while True:

            # Scan From Left To Find Value Not Less Than Pivot
            left += 1
            while projected_data[left] < pivot:
                left += 1

            # Scan From Right To Find Value Not Greater Than Pivot
            right -= 1
            while projected_data[right] > pivot:
                right -= 1

            # Time To End Algorithm
            if left >= right:
                return low + right

            # Swap Values
            projected_data[left], projected_data[right] = projected_data[right], projected_data[left]
            self.data_inds[low + left], self.data_inds[low + right] = (
                self.data_inds[low + right], self.data_inds[low + left])
            
            

In [392]:
# Test Data: six 2-D points with binary labels, small enough to inspect by hand.
# np.float was removed in NumPy 1.24; np.float64 is the dtype it resolved to.
sample_data = np.array([ [5, 5], [8, 7], [-6, -1], [-1, -3], [-4,-8], [4, -2]], dtype=np.float64)
sample_label= np.array([1, 1, 1, 0, 0, 0])

# leaf_size=3 over 6 points gives a root with two leaf children.
tree = BallTree(sample_data, 3)
tree.build_tree()

Leaf Count: 2.0
Tree Height: 2
Node Count: 3


In [386]:

# Dump the per-node arrays (radius, centre, index range, leaf flag), one per
# print, followed by the point-index permutation produced by the build.
for node_field in (tree.node_radius, tree.node_center, tree.node_data_inds, tree.node_is_leaf):
    print(node_field)
print(np.asarray(tree.data_inds_view))

[10.13793755  5.58768487  4.01386486]
[[ 1.         -0.33333333]
 [ 5.66666667  3.33333333]
 [-3.66666667 -4.        ]]
[[0 5]
 [0 2]
 [3 5]]
[0. 1. 1.]
[0 1 5 3 4 2]


In [327]:
# Time a full rebuild of the current `tree`. build_tree() has no return value,
# so `test` is None; the index permutation is reordered in place and not reset,
# so each timed run starts from the order left by the previous one.
%timeit test = tree.build_tree()

1.07 ms ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [389]:
# Load Data
# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context manager closes it once both arrays have been read into memory.
with np.load('./sample_data/mnist/mnist_data.npz') as mnist_data:
    train_data = mnist_data['train_data']
    test_data = mnist_data['test_data']

# Subset Data If Desired
# Column 0 holds the label and the remaining columns the features.
# np.float was removed in NumPy 1.24; np.float64 is the dtype it resolved to.
test_labels = test_data[:1, 0]
test_data = test_data[:1, 1:].astype(np.float64)
train_labels = train_data[:10000, 0]
train_data = train_data[:10000, 1:].astype(np.float64)


tree = BallTree(train_data, 100)
tree.build_tree()

Leaf Count: 100.0
Tree Height: 8
Node Count: 255


In [390]:
# Time a full rebuild of the current `tree`. build_tree() has no return value,
# so `test` is None; the index permutation is reordered in place and not reset,
# so each timed run starts from the order left by the previous one.
%timeit test = tree.build_tree()

1.05 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
