In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as lines
import matplotlib.patches as patches
from ipywidgets import interact, fixed
import random
import statistics

In [2]:
n_objects = 100

# Seed the generator so that Restart & Run All reproduces the same point cloud
# (the value itself is arbitrary).
random.seed( 42 )

# Data layout: one point per row of a numpy array,
# with the x coordinate at [n,0] and the y coordinate at [n,1].
# Coordinates are drawn uniformly from [-10, 10]; x is drawn before y for each point.
data = np.array(
    [ [ random.uniform( -10, 10 ), random.uniform( -10, 10 ) ] for _ in range( n_objects ) ],
    dtype=np.float32,
).reshape( n_objects, 2 )  # keeps the (n, 2) shape even when n_objects == 0

## Tree Node
inner_node: ( "inner", ( direction, split_value, left_child, right_child, active_region ) )

leaf_node:  ( "leaf", ( data, active_region ) )

In [3]:
def compute_extend_of_data( data ):
    """Return the axis-aligned bounding box of a set of 2D points.

    data is indexed as data[:,0] (x coordinates) and data[:,1] (y coordinates).
    The result is a float32 array laid out as [[min_x, max_x], [min_y, max_y]].
    """
    xs = data[:,0]
    ys = data[:,1]
    bounds = [ [ np.min( xs ), np.max( xs ) ],
               [ np.min( ys ), np.max( ys ) ] ]
    return np.array( bounds, dtype=np.float32 )

# Bounding box of the generated points; row 1 holds [min_y, max_y],
# which is printed as a quick sanity check.
active_region = compute_extend_of_data(data)
print( active_region[1] )

[-9.652392   9.8428135]


In [4]:
def choose_split_direction( active_region ):
    """Pick the axis along which the region is widest.

    active_region is indexed as [axis, 0] (lower bound) and [axis, 1] (upper bound).
    Returns "x" when the x extent is at least as large as the y extent, otherwise "y".
    """
    x_extent = active_region[0,1] - active_region[0,0]
    y_extent = active_region[1,1] - active_region[1,0]
    return "x" if x_extent >= y_extent else "y"

# Because of the data layout (see above), x is always at index 0 and y at index 1.
def index_of_direction( direction ):
    """Map a split direction to its column index: "x" -> 0, anything else -> 1."""
    return 0 if direction == "x" else 1

def make_split_decision( data, direction ):
    """Return the split position: the median of the points' coordinates along `direction`."""
    coordinates = data[ :, index_of_direction( direction ) ]
    return np.median( coordinates )

def perform_split( data, direction, split_value ):
    """Partition the points by their coordinate along `direction`.

    Returns (left, right): rows whose coordinate is <= split_value go left,
    rows whose coordinate is > split_value go right.
    """
    coordinates = data[ :, index_of_direction( direction ) ]
    goes_left  = coordinates <= split_value
    goes_right = coordinates >  split_value
    return data[ goes_left, : ], data[ goes_right, : ]

def split_active_regions( active_region, direction, split_value ):
    """Cut a region into two sub-regions at `split_value` along `direction`.

    Both results are copies of active_region; the input is left untouched.
    The left copy gets split_value as its upper bound on the split axis,
    the right copy gets it as its lower bound.
    """
    axis = index_of_direction( direction )

    lower_half = active_region.copy()
    lower_half[ axis, 1 ] = split_value

    upper_half = active_region.copy()
    upper_half[ axis, 0 ] = split_value

    return lower_half, upper_half

def build_tree( data, active_region, max_leaf_size=4, depth = 0 ):
    """Recursively build the BSP tree for `data` inside `active_region`.

    Parameters:
        data          -- points, one per row (x at [n,0], y at [n,1])
        active_region -- bounds of the current cell, [[min_x, max_x], [min_y, max_y]]
        max_leaf_size -- a node holding at most this many points becomes a leaf
        depth         -- recursion depth of this node (0 at the root)

    Returns a nested tuple structure:
        ( "leaf",  ( data, active_region ) )
        ( "inner", ( direction, split_value, left_child, right_child, active_region ) )
    """
    if len(data) <= max_leaf_size:
        return ( "leaf", ( data, active_region ) )

    direction   = choose_split_direction( active_region )
    split_value = make_split_decision( data, direction )

    left_children, right_children = perform_split( data, direction, split_value )

    # Degenerate split: when the median equals the largest coordinate (e.g. many
    # duplicate points) everything lands on one side, and recursing on the
    # unchanged point set would never terminate. Keep such a node as a leaf,
    # even though it is larger than max_leaf_size.
    if len(left_children) == 0 or len(right_children) == 0:
        return ( "leaf", ( data, active_region ) )

    left_region,right_region = split_active_regions( active_region, direction, split_value )

    # Pass max_leaf_size down so that a caller-supplied value applies to the whole tree,
    # not just to the root.
    left_child  = build_tree( left_children,  left_region,  max_leaf_size=max_leaf_size, depth=depth+1 )
    right_child = build_tree( right_children, right_region, max_leaf_size=max_leaf_size, depth=depth+1 )
    return ( "inner", ( direction, split_value, left_child, right_child, active_region ) )

# Build the tree over the full data set, starting from its bounding box
# (max_leaf_size is left at its default).
active_region = compute_extend_of_data( data ) 
bsp = build_tree( data, active_region )

def traverse_tree( node, point ):
    """Descend from `node` to the leaf whose cell contains `point`.

    Parameters:
        node  -- a tree node as produced by build_tree: ( "leaf", ... ) or ( "inner", ... )
        point -- query position, indexable as point[0] (x) and point[1] (y)

    Returns the leaf node tuple itself (not just its data).
    A point lying exactly on a split line goes to the left child, matching the
    "<=" rule used when the points were partitioned.
    """
    # Iterative descent: each step replaces `node` by the child on the point's side.
    while True:
        node_type, node_data = node
        if node_type == "leaf":
            return node

        direction, split_value, left_child, right_child, _ = node_data
        index = index_of_direction( direction )
        node = left_child if point[index] <= split_value else right_child

In [5]:
def draw_tree( axis, node, color="blue", depth = 0, draw_leaf=False ):
    """Draw the subtree rooted at `node` onto a matplotlib axis.

    Parameters:
        axis      -- matplotlib axes to draw on
        node      -- tree node tuple: ( "leaf", ... ) or ( "inner", ... )
        color     -- color for split lines and leaf rectangles
        depth     -- depth of `node`; deeper split lines are drawn thinner
        draw_leaf -- when True, every leaf reached is shaded as a translucent rectangle

    Inner nodes are drawn as a dashed line at the split position, spanning the
    node's region along the other axis.
    """
    node_type,node_data = node
    if node_type == "leaf":
        _, active_region = node_data
        if draw_leaf:
            # active_region is [[min_x, max_x], [min_y, max_y]]:
            # column 0 is the lower-left corner, column 1 - column 0 the size.
            xy = active_region[:,0]
            wh = active_region[:,1]-active_region[:,0]
            rect = patches.Rectangle( xy, wh[0], wh[1], color=color, linewidth=0, alpha=0.5 )
            axis.add_patch( rect )
    else:
        direction, split_value, left_child, right_child, active_region = node_data
        index = index_of_direction ( direction )

        # Two end points of the split line: constant along the split axis,
        # spanning the region's bounds along the other axis (1-index).
        polygon = np.zeros( (2,2), dtype=np.float32 )
        polygon[0,index] = split_value
        polygon[1,index] = split_value
        polygon[0,1-index] = active_region[1-index,0]
        polygon[1,1-index] = active_region[1-index,1]

        # Line width shrinks with depth but never drops below 1.
        width = max( 5-depth, 1 )
        line = lines.Line2D( polygon[:,0], polygon[:,1], color=color, linewidth=width )
        line.set_dashes( [ 2, 3 ] )
        line.set_dash_capstyle( "round")
        axis.add_line( line )

        # Forward draw_leaf so that leaves below an inner node are shaded as well.
        draw_tree( axis, left_child,  color=color, depth=depth+1, draw_leaf=draw_leaf )
        draw_tree( axis, right_child, color=color, depth=depth+1, draw_leaf=draw_leaf )

def visualize_bsp( x=5.0, y=0.0 ):
    """Plot the points, the tree's split lines, and the leaf cell containing (x, y).

    The leaf found for the query point is shaded red, the split lines and the
    data points are blue, and the query point itself is drawn in red.
    """
    fig,axis = plt.subplots( 1, figsize=(10,10) )
    axis.set_xlim( -10, 10 )
    axis.set_ylim( -10, 10 )

    # Locate the leaf cell the query point falls into.
    query = np.array( [ x, y], dtype=np.float32 )
    hit_leaf = traverse_tree( bsp, query )

    draw_tree( axis, hit_leaf, "red", draw_leaf=True )
    draw_tree( axis, bsp, "blue" )

    # Column 0 holds the x coordinates, column 1 the y coordinates.
    axis.scatter( data[:,0], data[:,1], color="blue")
    axis.scatter( [x], [y], color="red")

# Interactive view: two sliders move the query point across [-10, 10] on each axis.
# The trailing ';' keeps the notebook from echoing the function object that interact() returns.
interact(visualize_bsp, x=(-10.0, 10.0), y=(-10.0, 10.0));

interactive(children=(FloatSlider(value=5.0, description='x', max=10.0, min=-10.0), FloatSlider(value=0.0, des…

<function __main__.visualize_bsp(x=5.0, y=0.0)>