In [None]:
from collections import namedtuple, Counter
from textwrap import dedent

import numpy as np
import numpy_indexed as npi

## for test only
from sklearn import datasets

##vis
import matplotlib.pyplot as plt
plt.style.use('ggplot')



def class_probability(labels):
    '''Compute the empirical class distribution of binary labels.

    Parameters
    ----------
    labels : array-like of {0, 1}
        Binary labels (numpy array or plain Python sequence). Any value other
        than 1 is counted toward class 0.

    Returns
    -------
    np.ndarray, shape (2,)
        ``[p0, p1]``: the fraction of labels not equal to 1 and equal to 1.

    Raises
    ------
    ValueError
        If ``labels`` is empty (the distribution is undefined).
    '''
    # asarray lets plain lists work: `list == 1` would be a single bool
    labels = np.asarray(labels)
    total = labels.size
    if total == 0:
        raise ValueError("class_probability() needs at least one label")
    p1 = np.count_nonzero(labels == 1) / total
    return np.array([1 - p1, p1])


class Interval(namedtuple('Interval', ['attr', 'interval', 'loss', 'inside', 'outside'])):
    # One scored candidate split on attribute `attr`: `interval` holds the two
    # bounds, `loss` the split's score, and `inside` / `outside` the two
    # resulting groups (each exposing `.data`).
    def __repr__(self):
        low, high = self.interval[0], self.interval[1]
        fields = (
            f"attr={self.attr}",
            f"interval=[{low:.3f}, {high:.3f}]",
            f"loss={self.loss:.4f}",
            f"inside={self.inside.data.shape}",
            f"outside={self.outside.data.shape}",
        )
        return "Interval(" + ", ".join(fields) + ")"
    
class Dataset(namedtuple('Dataset', ['data', 'target'])):
    '''Feature matrix `data` paired with its label vector `target`.'''
    def __repr__(self):
        shape = self.data.shape
        return "Dataset({})".format(shape)
    
class LeafNode:
    '''Terminal node of a predicate tree.

    Parameters
    ----------
    prediction : int or None
        Label reported for every point routed to this leaf.
    members : Dataset or None
        Points that reached this leaf (anything exposing ``.data`` with a
        ``.shape``). ``None`` means "not recorded" and renders as size 0.
    level : int
        Depth of the leaf in the tree (root is 0).
    '''
    def __init__(self, prediction=None, members=None, level=0):
        self.prediction = prediction
        # None instead of a shared mutable `[]` default; a list also has no
        # `.data`, so the old default could never be rendered by __repr__.
        self.members = members
        self.level = level

    def __repr__(self):
        size = 0 if self.members is None else self.members.data.shape[0]
        return f'Leaf(prediction={self.prediction}, size={size})'


class IntervalNode:
    '''Internal node of a predicate tree: routes a point by one attribute.

    Parameters
    ----------
    interval : Interval
        The chosen split; ``interval.attr`` is the attribute index and
        ``interval.interval`` the ``[low, high]`` bounds of the "inside" range.
    inside, outside : IntervalNode or LeafNode
        Subtrees for points inside / outside the interval.
    level : int or None
        Depth of this node (root is 0).
    '''
    def __init__(
        self,
        interval=None,
        inside=None,
        outside=None,
        level=None,
    ):
        self.interval = interval
        self.inside = inside
        self.outside = outside
        self.level = level

    def _child_depth(self):
        '''Indentation depth used for the child lines of __repr__.'''
        # Prefer the child's recorded level; fall back to this node's level + 1
        # so a hand-built node (or one without a leveled child) still renders.
        depth = getattr(self.inside, 'level', None)
        if depth is None:
            depth = (self.level or 0) + 1
        return depth

    def __repr__(self, indent=2):
        low, high = self.interval.interval[0], self.interval.interval[1]
        pad = " " * (indent * self._child_depth())
        # bounds formatted like Interval.__repr__ for consistency
        return (
            f"Interval(attr={self.interval.attr}, interval=[{low:.3f}, {high:.3f}])\n"
            f"{pad}I={self.inside}\n"
            f"{pad}O={self.outside}"
        )
    

In [None]:
## Alternative dataset: iris reduced to a binary problem (class 0 vs. rest).
## Note: the plotting cell further down assumes 2-D data.
# iris = datasets.load_iris()
# iris_dataset = Dataset(iris.data, (iris.target==0).astype(np.int64))##simulate binary classes
# dataset = iris_dataset
# print(iris.data.shape, iris.target.shape)

## Planar dataset: points uniform on the unit square; positives are the
## points in the lower-left quadrant (0, 0.5) x (0, 0.5).
SEED = 0  # fixed seed so Restart & Run All reproduces the same data and tree
rng = np.random.default_rng(SEED)
data = rng.random((150, 2))
target = (
    (0 < data[:, 0]) & (data[:, 0] < 0.5)
    & (0 < data[:, 1]) & (data[:, 1] < 0.5)
).astype(np.int64)
dataset = Dataset(data, target)

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=target)
ax.set_aspect('equal')
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_title('Planar dataset (positives in lower-left quadrant)')
ax.set_xlabel('attribute 0')
ax.set_ylabel('attribute 1');


In [None]:
def all_intervals(sorted_values, min_interval=0, pad=1e-3):
    '''Enumerate candidate open intervals over one attribute's values.

    Split points are the midpoints between consecutive distinct values, plus
    one point just below the minimum and one just above the maximum (offset by
    ``pad``), so every yielded interval contains at least one observed value.

    Parameters
    ----------
    sorted_values : array-like of float
        Attribute values. They are de-duplicated and sorted internally, so
        unsorted input is handled as well.
    min_interval : float
        Only intervals of width ``>= min_interval`` are yielded.
    pad : float
        Positive offset of the two outer end points from the min / max value.

    Yields
    ------
    list of two floats
        ``[v0, v1]`` with ``v0 < v1``, in increasing order of ``v0`` then ``v1``.
        Nothing is yielded for empty input.
    '''
    # np.unique both de-duplicates and sorts
    unique_values = np.unique(np.asarray(sorted_values, dtype=float))
    if unique_values.size == 0:
        return
    ## add two end points outside the observed range
    padded = np.concatenate(([unique_values[0] - pad], unique_values, [unique_values[-1] + pad]))
    split_points = (padded[:-1] + padded[1:]) / 2
    for i, v0 in enumerate(split_points):
        for v1 in split_points[i + 1:]:
            if v1 - v0 >= min_interval:
                yield [v0, v1]


def find_best_interval(dataset, attr_dimensions=None, positive_weight=10, min_interval_frac=0):
    '''Search the given attributes for the open interval that best isolates positives.

    For each attribute ``a`` and each candidate ``[x0, x1]`` from
    ``all_intervals``, points with ``x0 < value < x1`` form the *inside* group
    and the rest the *outside* group. A split is scored by

        inside_purity  = p0_in**2 + w * p1_in**2
        outside_purity = w * p0_out**2 + p1_out**2
        loss           = 1 - inside_purity * outside_purity

    with ``w = positive_weight``, rewarding positives inside and negatives
    outside. Candidates with an empty outside group (the whole range) are skipped.

    Parameters
    ----------
    dataset : Dataset
        ``data`` is indexed column-wise as ``data[:, a]``; ``target`` holds
        binary labels.
    attr_dimensions : list of int, optional
        Attribute columns to search; defaults to all columns.
    positive_weight : float
        ``w`` in the scores above.
    min_interval_frac : float
        Minimum interval width as a fraction of the attribute's value range
        (0 keeps every candidate).

    Returns
    -------
    Interval
        The lowest-loss candidate; on ties the first one found is kept.

    Raises
    ------
    ValueError
        If no candidate exists (e.g. every searched attribute is constant,
        the dataset is empty, or ``attr_dimensions`` is empty).
    '''
    data = dataset.data
    target = dataset.target
    if attr_dimensions is None:
        attr_dimensions = list(range(data.shape[1]))

    ## track only the running best instead of materializing Dataset copies
    ## for every candidate interval
    best = None
    for a in attr_dimensions:
        column = data[:, a]
        attr_values = np.sort(column)
        if attr_values.size == 0:
            continue
        min_interval = min_interval_frac * (attr_values[-1] - attr_values[0])
        for x0, x1 in all_intervals(attr_values, min_interval=min_interval):
            ## split data by split point
            inside = (x0 < column) & (column < x1)
            outside = (column < x0) | (column > x1)
            target_inside = target[inside]
            target_outside = target[outside]

            ## do not consider entire range as a split
            ## (only the outside group can end up empty)
            if len(target_outside) == 0:
                continue

            p_in = class_probability(target_inside)
            inside_purity = p_in[0]**2 + positive_weight * p_in[1]**2
            p_out = class_probability(target_outside)
            outside_purity = positive_weight * p_out[0]**2 + p_out[1]**2
            loss = 1 - inside_purity * outside_purity

            # strict `<` keeps the first minimum, like min() did
            if best is None or loss < best[0]:
                best = (loss, a, x0, x1, inside, outside)

    if best is None:
        raise ValueError("find_best_interval: no candidate interval with a non-empty outside group")

    loss, a, x0, x1, inside, outside = best
    return Interval(
        attr=a,
        interval=[x0, x1],
        loss=loss,
        inside=Dataset(data[inside], target[inside]),
        outside=Dataset(data[outside], target[outside]),
    )
## test
bi = find_best_interval(dataset)
bi

In [None]:
def fit_decision_tree_predicate(dataset, level=0, attr_dimensions=None, branch='inside'):
    '''Fit a "predicate" tree: nested intervals that box in the positive class.

    At each level the best interval over the remaining attributes is chosen
    with ``find_best_interval``. Only the *inside* branch is refined further,
    with that attribute removed. The *outside* branch becomes a leaf
    predicting 0. Recursion stops when the dataset is empty, no attributes
    remain, or no remaining attribute has two distinct values. In the last
    case no interval split exists.

    Assumptions:
        1. Only consider binary classes
        2. [TODO] Only try to bound positive examples
            (which corresponds to selected data points, e.g., through brush)
        I.e., most positive examples will appear _inside_ of some box as opposed to possibly outside.

    Parameters
    ----------
    dataset : Dataset
    level : int
        Depth of the node being built (root is 0).
    attr_dimensions : list of int, optional
        Attributes still available for splitting; defaults to all columns.
        The caller's list is not modified.
    branch : {'inside', 'outside'}
        Which side of the parent interval this node is on; an 'outside'
        node is always a leaf.

    Returns
    -------
    IntervalNode or LeafNode
    '''
    data = dataset.data
    if attr_dimensions is None:
        attr_dimensions = list(range(data.shape[1]))

    ## an interval split needs some remaining attribute with >= 2 distinct values
    is_leaf = (
        branch == 'outside'
        or data.shape[0] == 0
        or not any(np.unique(data[:, a]).size > 1 for a in attr_dimensions)
    )
    if is_leaf:
        prediction = int(branch == 'inside')
        return LeafNode(prediction=prediction, members=dataset, level=level)

    best_interval = find_best_interval(dataset, attr_dimensions=attr_dimensions)
    # build a new list rather than .remove(), which mutated the caller's argument
    remaining = [a for a in attr_dimensions if a != best_interval.attr]
    inside_node = fit_decision_tree_predicate(
        best_interval.inside,
        level=level + 1,
        attr_dimensions=remaining,
        branch='inside',
    )
    # the outside leaf holds the points outside the interval, not the whole parent dataset
    outside_node = LeafNode(prediction=0, members=best_interval.outside, level=level + 1)
    return IntervalNode(
        interval=best_interval,
        inside=inside_node,
        outside=outside_node,
        level=level,
    )
    
    
## test
tree = fit_decision_tree_predicate(dataset)
tree

In [None]:
## Collect the interval chosen for each attribute along the inside path.
## An attribute that was never split keeps the full [0, 1] range, so this also
## works when the tree is shallower than two interval levels.
box = {}
node = tree
while isinstance(node, IntervalNode):
    box[node.interval.attr] = node.interval.interval
    node = node.inside
x_bounds = box.get(0, [0, 1])
y_bounds = box.get(1, [0, 1])

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=target)
ax.vlines(x_bounds, 0, 1)
ax.hlines(y_bounds, 0, 1)
ax.set_aspect('equal')
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_title('Fitted predicate intervals')
ax.set_xlabel('attribute 0')
ax.set_ylabel('attribute 1');


In [None]:
def find_leaf(tree, data_point):
    '''Route a single data point down `tree` and return the leaf it reaches.

    At every IntervalNode the point moves to `inside` when its value for the
    node's attribute lies strictly between the interval bounds, and to
    `outside` otherwise. The first node that is not exactly an IntervalNode
    is returned.
    '''
    current = tree
    while type(current) is IntervalNode:
        low, high = current.interval.interval
        value = data_point[current.interval.attr]
        current = current.inside if low < value < high else current.outside
    return current


def predict(tree, data):
    '''Predict a label for every row of `data` with `tree`.

    Each row is routed to its leaf via `find_leaf`; the leaves' `prediction`
    values are gathered into a numpy array with one entry per row.
    '''
    leaves = map(lambda row: find_leaf(tree, row), data)
    return np.array([leaf.prediction for leaf in leaves])

## test
data_point = dataset.data[0]
find_leaf(tree, data_point)
predict(tree, dataset.data)

## Test