In [1]:
%load_ext Cython
%load_ext memory_profiler

In [2]:
# Change Working Directory To Allow knn Imports
import os 
os.chdir('../')

In [93]:
%%cython -a
import cython
import numpy as np
cimport numpy as np
import math
from libcpp cimport bool
from libc.math cimport sqrt
from knn.distance_metrics import euclidean

ctypedef fused int_or_float:
    short
    int
    long
    float
    double
    

cdef class BallTree:
    
    cdef double[:, ::1] data_view
    cdef long[::1] data_inds_view
    cdef np.ndarray data
    cdef np.ndarray data_inds
    
    cdef double[:,::1] query_data_view 
    
    cdef np.ndarray node_data_inds
    cdef np.ndarray node_radius
    cdef np.ndarray node_is_leaf
    cdef np.ndarray node_center
    cdef long[:,::1] node_data_inds_view
    cdef double[::1] node_radius_view
    cdef double[::1] node_is_leaf_view
    cdef double[:,::1] node_center_view
    
    cdef int leaf_size
    cdef int node_count
    cdef int tree_height
    
    
    cdef public np.ndarray heap
    cdef double[:,::1] heap_view
    cdef public np.ndarray heap_inds
    cdef long[:,::1] heap_inds_view
    
    def __init__(self, data, leaf_size):
        
        # Data
        self.data = np.asarray(data, dtype=np.float, order='C')
        self.data_view = memoryview(self.data)
        self.data_inds = np.arange(data.shape[0], dtype=np.int)
        self.data_inds_view = memoryview(self.data_inds)
        
        
        
        
        # Tree Shape
        self.leaf_size = leaf_size
        leaf_count = self.data.shape[0] / leaf_size
        self.tree_height = math.ceil(np.log2(leaf_count)) + 1
        self.node_count = int(2 ** self.tree_height) - 1 
        
        print("Leaf Count: " + str(leaf_count))
        print("Tree Height: " + str(self.tree_height))
        print("Node Count: " + str(self.node_count))
        
        # Node Data
        self.node_data_inds = np.zeros((self.node_count, 2), dtype=np.int, order='C')
        self.node_radius = np.zeros(self.node_count, order='C')
        self.node_is_leaf = np.zeros(self.node_count, order='C')
        self.node_center = np.zeros((self.node_count, data.shape[1]), order='C')
        self.node_data_inds_view = memoryview(self.node_data_inds)
        self.node_radius_view = memoryview(self.node_radius)
        self.node_is_leaf_view = memoryview(self.node_is_leaf)
        self.node_center_view = memoryview(self.node_center)
        
        
    def build_tree(self):
        self._build(0, 0, self.data.shape[0]-1)
    
    
    def _build(self, long node_index, long node_data_start, long node_data_end):
        #print("Node Index: " + str(node_index))
        #print("Node Data Start: " + str(node_data_start))
        #print("Node Data End: " + str(node_data_end))
        
        # Current Node Is A Leaf
        if (node_data_end-node_data_start+1) <= self.leaf_size:
            #print("Leaf Node")
            
            self.node_center[node_index] = np.mean(self.data[self.data_inds[node_data_start:node_data_end+1]], axis=0)
            #print("Node Center:" + str(np.asarray(self.node_center[node_index,:])))
        
            self.node_radius[node_index] = np.max(euclidean(self.data[self.data_inds[node_data_start:node_data_end+1]], 
                                                            self.node_center[node_index,  :][np.newaxis, :]))
            #print("Radius: " + str(self.node_radius[node_index]))
            
            self.node_data_inds[node_index, 0] = node_data_start
            self.node_data_inds[node_index, 1] = node_data_end
            
            self.node_is_leaf[node_index] = True
            return None
        
        # Current Node Is Internal Node
        #print("Internal Node")
        
        # Random Point x0
        
        rand_index = np.random.choice(node_data_end-node_data_start+1, 1, replace=False)
        rand_point = self.data[self.data_inds[rand_index], :]
        
        #print(rand_point)
        
        # Find Maximal Point x1
        distances = euclidean(self.data[self.data_inds[node_data_start:node_data_end+1]], rand_point)
        ind_of_max_dist = np.argmax(distances)
        max_vector_1 = self.data[ind_of_max_dist]
        #print(max_vector_1)
        
        
        

        # Find Maximal Point x2
        distances = euclidean(self.data[self.data_inds[node_data_start:node_data_end+1]], max_vector_1[np.newaxis, :])
        ind_of_max_dist = np.argmax(distances)
        max_vector_2 = self.data[ind_of_max_dist]
        #print(max_vector_2)
        
        
        # Project Data
        proj_data = np.dot(self.data[self.data_inds[node_data_start:node_data_end+1]], max_vector_1-max_vector_2)
        
        #print(proj_data)
        
        # Find Median And Split Data
        median = np.partition(proj_data, proj_data.size//2)[proj_data.size//2]
        #print(median)
        
        #print("Here?")
        
        
        # Hoare Partitioning
        low = node_data_start
        high = node_data_end
        pivot = median
        
        #print(low)
        #print(high)
        #print(proj_data.shape)
        
                
        self._hoare_partition(pivot, low, high, proj_data)
        
        #print(proj_data)
        #print(self.data_inds)
        
        
        # Create Circle
        center = np.mean(self.data[self.data_inds[node_data_start:node_data_end+1]], axis=0)
        radius = np.max(euclidean(self.data[self.data_inds[node_data_start:node_data_end+1]], center[np.newaxis, :]))
        #print(center)
        #print(radius)
        #print(self.data[self.data_inds[node_data_start:node_data_end+1]])
        
        self.node_data_inds[node_index, 0] = node_data_start
        self.node_data_inds[node_index, 1] = node_data_end
        
        self.node_radius[node_index] = radius
        self.node_center[node_index] = center
        
        left_index = 2 * node_index + 1
        right_index = left_index + 1
        
        self.node_is_leaf[node_index] = False
        
        
        self._build(left_index, node_data_start,  node_data_start+ (proj_data.size//2)-1 )
        self._build(right_index, node_data_start+(proj_data.size//2),   node_data_end)
        
        
    cdef int _hoare_partition(self, pivot, low, high, projected_data):
        
        i = low - 1
        j = high + 1
        i2 = -1
        j2 = projected_data.shape[0]
        
        while True:
        
            # Scan From Left To Find Value Greater Than Pivot 
            condition = True
            while condition:
                i += 1
                i2 += 1
                condition = projected_data[i2] < pivot
            
            # Scan From Right To Find Value Less Than Pivot 
            condition = True
            while(condition):
                j -= 1
                j2 -= 1
                condition = projected_data[j2] > pivot

            # Time To End Algorithm
            if (i >= j):
                return j 
            
            # Swap Values
            projected_data[i2], projected_data[j2] = projected_data[j2], projected_data[i2]
            self.data_inds[i], self.data_inds[j] = self.data_inds[j], self.data_inds[i]
            
    @cython.boundscheck(False)
    @cython.wraparound(False)  
    @cython.initializedcheck(False)
    cdef inline double _euclid(self, double[::1] vector1, double[::1] vector2):

        cdef double distance = 0.0
        cdef int dims = vector1.shape[0]
        cdef double temp
        cdef size_t i

        
        for i in range(0, dims):
            temp = vector1[i] - vector2[i]
            distance += (temp*temp)

        return sqrt(distance)
            
            
    @cython.boundscheck(False)
    @cython.wraparound(False) 
    @cython.initializedcheck(False)
    def query(self, query_data, k):
        
        cdef size_t i
        cdef double[::1] query_vector
        cdef double[::1] initial_center
        
        self.heap = np.full((query_data.shape[0], k), np.inf, order='C')
        self.heap_view = memoryview(self.heap)
        self.heap_inds = np.zeros((query_data.shape[0], k), dtype=np.int, order='C')
        self.heap_inds_view = memoryview(self.heap_inds)
        
        self.query_data_view = memoryview(query_data)
        
        #inds = 0
#         dist = self.euclid(self.node_center_view[0], self.query_data_view[0, :])
#         self._query(0, dist, 0, self.query_data_view[0, :])

        initial_center = self.node_center_view[0]
        cdef int numb_query_vectors = query_data.shape[0]
        cdef double dist
        for i in range(0, numb_query_vectors):
            query_vector = self.query_data_view[i]
            dist = self._euclid(initial_center, query_vector)
            self._query(i, dist, 0, query_vector)
        
    

    
        return None
    
    @cython.boundscheck(False)
    @cython.wraparound(False)  
    @cython.initializedcheck(False)
    cdef int _query(self, size_t query_vect_ind, double dist_to_cent, size_t curr_node, double[::1] query_data):
        
        cdef size_t i, child1, child2, lower_index, upper_index, curr_index
        cdef double child1_dist, child2_dist, dist
        cdef double[::1] curr_vect, child1_center, child2_center

        # Prune This Ball
        if dist_to_cent - self.node_radius_view[curr_node] >= self._heap_peek_head(query_vect_ind):
            return 0

        # Currently A Leaf Node
        if self.node_is_leaf_view[curr_node]:
            lower_index = self.node_data_inds_view[curr_node][0]
            upper_index = self.node_data_inds_view[curr_node][1] + 1
            for i in range(lower_index, upper_index):
                curr_index = self.data_inds_view[i]
                curr_vect = self.data_view[curr_index]
                dist = self._euclid(curr_vect, query_data) 
                if dist < self._heap_peek_head(query_vect_ind):
                    self._heap_pop_push(query_vect_ind, dist, self.data_inds_view[i])

        # Not Leaf So Explore Children
        else:
            child1 = 2 * curr_node + 1
            child2 = child1 + 1
            
            child1_center = self.node_center_view[child1]
            child2_center = self.node_center_view[child2]

            child1_dist = self._euclid(child1_center, query_data)
            child2_dist = self._euclid(child2_center, query_data)

            if child1_dist < child2_dist:
                self._query(query_vect_ind, child1_dist, child1, query_data)
                self._query(query_vect_ind, child2_dist, child2, query_data)
            else:
                self._query(query_vect_ind, child2_dist, child2, query_data)
                self._query(query_vect_ind, child1_dist, child1, query_data)
                
        return 0
    
    @cython.initializedcheck(False)
    @cython.boundscheck(False)
    cdef inline double _heap_peek_head(self, size_t level):
        return self.heap_view[level, 0]
    
    
    @cython.boundscheck(False)
    @cython.wraparound(False) 
    @cython.initializedcheck(False)
    cdef int _heap_pop_push(self, int level, double value, int index):
        
        # Put New Value At Head And Remove Old Value
        self.heap_view[level, 0] = value
        self.heap_inds_view[level, 0] = index
        
        # Update Heap Structure
        cdef int left_ind, right_ind
        
        cdef int i
        
        cdef double temp_value
        cdef int temp_index
        
        i = 0
        while(True):
            left_ind = 2 * i + 1
            right_ind = left_ind + 1
            
            # Catch Edge Of Array
            if left_ind >= self.heap.shape[1]:
                break
            elif right_ind >= self.heap.shape[1]:
                if self.heap_view[level, left_ind] > self.heap_view[level, i]:
                    temp_value = self.heap_view[level, i]
                    self.heap_view[level, i] = self.heap_view[level, left_ind]
                    self.heap_view[level, left_ind] = temp_value
#                     self.heap[level, i] = self.heap_view[level, left_ind]
#                     self.heap[level, left_ind] = temp_value
                    
                    temp_index = self.heap_inds_view[level, i]
                    self.heap_inds_view[level, i] = self.heap_inds_view[level, left_ind]
                    self.heap_inds_view[level, left_ind] = temp_index
#                     self.heap_inds[level, i] = self.heap_inds_view[level, left_ind]
#                     self.heap_inds[level, left_ind] = temp_index
                    

                    
                break
                
            # Determine If We Should Explore Left or Right
            
            if self.heap_view[level, left_ind] > self.heap_view[level, right_ind]:
                if self.heap_view[level, left_ind] > self.heap_view[level, i]:
                    temp_value = self.heap_view[level, i]
                    self.heap_view[level, i] = self.heap_view[level, left_ind]
                    self.heap_view[level, left_ind] = temp_value
#                     self.heap[level, i] = self.heap_view[level, left_ind]
#                     self.heap[level, left_ind] = temp_value
                    
                    temp_index = self.heap_inds_view[level, i]
                    self.heap_inds_view[level, i] = self.heap_inds_view[level, left_ind]
                    self.heap_inds_view[level, left_ind] = temp_index
#                     self.heap_inds[level, i] = self.heap_inds_view[level, left_ind]
#                     self.heap_inds[level, left_ind] = temp_index
                    i = left_ind
                else:
                    break
            
            else:
                if self.heap_view[level, right_ind] > self.heap_view[level, i]:
                    temp_value = self.heap_view[level, i]
                    self.heap_view[level, i] = self.heap_view[level, right_ind]
                    self.heap_view[level, right_ind] = temp_value
                    
#                     self.heap[level, i] = self.heap_view[level, right_ind]
#                     self.heap[level, right_ind] = temp_value
                    
                    temp_index = self.heap_inds_view[level, i]
                    self.heap_inds_view[level, i] = self.heap_inds_view[level, right_ind]
                    self.heap_inds_view[level, right_ind] = temp_index
                    
                    
#                     self.heap_inds[level, i] = self.heap_inds_view[level, right_ind]
#                     self.heap_inds[level, right_ind] = temp_index
                    i = right_ind
                else:
                    break
        
        return 0
    
    
    
            


Error compiling Cython file:
------------------------------------------------------------
...
    float
    double

cdef class BallTree:
    
    cdef int_or_float[:, ::1] data_view
        ^
------------------------------------------------------------

/Users/Johnny/.ipython/cython/_cython_magic_11a69c8017e2aced1e2bff304937badf.pyx:18:9: Fused types not allowed here

Error compiling Cython file:
------------------------------------------------------------
...
    cdef int_or_float[:, ::1] data_view
    cdef long[::1] data_inds_view
    cdef np.ndarray data
    cdef np.ndarray data_inds
    
    cdef int_or_float[:,::1] query_data_view 
        ^
------------------------------------------------------------

/Users/Johnny/.ipython/cython/_cython_magic_11a69c8017e2aced1e2bff304937badf.pyx:23:9: Fused types not allowed here


TypeError: object of type 'NoneType' has no len()

In [59]:
# Test Data
sample_data = np.array([ [5, 5], [8, 7], [-6, -1], [-1, -3], [-4,-8], [4, -2]], dtype=np.float)
sample_label= np.array([1, 1, 1, 0, 0, 0])

tree = BallTree(sample_data, 3)
tree.build_tree()


Leaf Count: 2.0
Tree Height: 2
Node Count: 3


In [127]:
euclidean(sample_data, np.array([4, -2])[np.newaxis, :])
# print(tree.node_radius)
# print(tree.node_center)
# print(tree.node_data_inds)
# print(tree.node_is_leaf)
# print(np.asarray(tree.data_inds_view))

array([[ 7.07106781,  9.8488578 , 10.04987562,  5.09901951, 10.        ,
         0.        ]])

In [6]:
%timeit test = tree.build_tree()

261 µs ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [129]:
# Load Data
mnist_data = np.load('./sample_data/mnist/mnist_data.npz')
train_data = mnist_data['train_data']
test_data = mnist_data['test_data']

# Subset Data If Desired
test_labels = test_data[:1000, 0]
test_data = test_data[:1000, 1:].astype(np.float)
train_labels = train_data[:1000, 0]
train_data = train_data[:1000, 1:].astype(np.float)


tree = BallTree(train_data, 10)
tree.build_tree()

Leaf Count: 100.0
Tree Height: 8
Node Count: 255


In [130]:
%timeit test = tree.build_tree()

116 ms ± 3.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
tree.query(np.zeros((8, 2)) ,3)
print(tree.heap.shape)
print(tree.heap_inds.shape)

inf
(8, 3)
(8, 3)


In [74]:
sample_data = np.array([ [5, 5], [8, 7], [-6, -1], [-1, -3], [-4,-8], [4, -2]], dtype=np.float)
sample_label= np.array([1, 1, 1, 0, 0, 0])

tree = BallTree(sample_data, 3)
tree.build_tree()

# print(tree.node_center)
# print(tree.data_inds)
# print(tree.node_data_inds)
# print(tree.node_radius)
# print(tree.node_is_leaf)



tree.query(np.array([5., 4.])[np.newaxis, :] ,3)
print(tree.heap)
print(tree.heap_inds)
print(euclidean(np.array([5, 4])[np.newaxis, :], sample_data))

Leaf Count: 2.0
Tree Height: 2
Node Count: 3
[[6.08276253 4.24264069 1.        ]]
[[5 1 0]]
[[ 1.        ]
 [ 4.24264069]
 [12.08304597]
 [ 9.21954446]
 [15.        ]
 [ 6.08276253]]


In [111]:
print(euclidean(np.array([5, 4])[np.newaxis, :], tree.node_center))

[[ 5.89726867]
 [10.38160767]
 [ 1.41421356]
 [ 6.08276253]
 [13.12440475]
 [ 9.21954446]
 [ 2.5       ]]


In [161]:
%timeit test = tree.query(np.array([5, 4])[np.newaxis, :] ,3)

54.3 µs ± 2.58 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [88]:
# Load Data
mnist_data = np.load('./sample_data/mnist/mnist_data.npz')
train_data = mnist_data['train_data']
test_data = mnist_data['test_data']

# Subset Data If Desired
test_labels = test_data[:, 0]
test_data = test_data[:1000, 50:54].astype(np.float)
train_labels = train_data[:, 0]
train_data = train_data[:1000, 50:54].astype(np.float)


tree = BallTree(train_data, 10)
tree.build_tree()


Leaf Count: 100.0
Tree Height: 8
Node Count: 255


In [91]:
%timeit  test = tree.query(test_data , 3)

650 µs ± 4.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [86]:
%timeit test = euclidean(train_data, test_data)

6.05 ms ± 45.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
test = tree.query(test_data , 3)