In [None]:
import numpy as np
import numpy_indexed as npi
from collections import namedtuple, Counter

from sklearn import datasets
from textwrap import dedent


# def argmax(distribution):
#     return max(distribution.items(), key=lambda x:x[1])[0]

# def class_probablity(values):
#     counts = Counter(values)
#     total = sum(counts.values())
#     distribution = {value:count/total for value,count in counts.items()}
#     return distribution


def all_splits(sorted_values):
    '''Candidate thresholds for an already-sorted sequence of attribute values.

    Returns the midpoint of every adjacent pair, i.e. n-1 values for n inputs
    (an empty list when fewer than two values are given). A repeated value
    produces a midpoint equal to that value.
    '''
    midpoints = []
    for i in range(1, len(sorted_values)):
        lower, upper = sorted_values[i - 1], sorted_values[i]
        midpoints.append((lower + upper) / 2)
    return midpoints

def class_probability(values):
    '''Relative frequency of each distinct class label.

    Parameters
    ----------
    values : 1-D array-like of class labels

    Returns
    -------
    numpy.ndarray of float
        One entry per distinct label, ordered by sorted label value, summing to 1.
        An empty input yields an empty array instead of dividing by zero.
    '''
    labels = np.asarray(values)
    if labels.size == 0:
        return np.zeros(0)
    # np.unique returns the distinct labels sorted, with counts aligned to them.
    _, counts = np.unique(labels, return_counts=True)
    return counts / labels.size

def gini_index(target):
    '''Gini impurity of a collection of class labels: 1 - sum_k p_k**2.

    0 means every label is identical; larger values mean a more mixed set.
    '''
    proportions = class_probability(target)
    return 1 - np.sum(proportions ** 2)

In [None]:
iris = datasets.load_iris()
# The tree code below indexes data and target with the same boolean masks,
# so they must describe the same number of samples.
assert iris.data.shape[0] == iris.target.shape[0]
# Bare last expression: Jupyter renders it as the cell output (no print needed).
iris.data.shape, iris.target.shape

In [None]:
# One split of a node: column index `attr`, threshold `value`, its `loss`, and the
# two child Datasets (rows with data[:, attr] < value are placed in `left`).
Split = namedtuple('Split', 'attr value loss left right')
# A feature matrix together with its matching class labels.
Dataset = namedtuple('Dataset', 'data target')
    
class LeafNode:
    '''Terminal decision-tree node holding a single class prediction.

    Attributes
    ----------
    prediction : label returned for every sample routed to this leaf (may be None).
    members : the Dataset of training rows that reached this leaf, or None.
    level : depth of the leaf in the tree (root is 0); SplitNode's repr uses it
        to indent children.
    '''

    def __init__(self, prediction=None, members=None, level=0):
        self.prediction = prediction
        # None rather than a shared mutable [] default; __repr__ reports it as size 0.
        self.members = members
        self.level = level

    def __repr__(self):
        size = 0 if self.members is None else self.members.data.shape[0]
        return f'Leaf(prediction={self.prediction}, size={size})'


class SplitNode:
    '''Internal decision-tree node that routes samples on one attribute threshold.

    Attributes
    ----------
    split : the Split chosen for this node (attribute index, threshold, loss, children data).
    prediction : optional value stored on the node (the fitting code leaves it None).
    left / right : child nodes; find_leaf sends a sample left when
        sample[split.attr] < split.value and right otherwise.
    level : depth of this node in the tree (root is 0).
    '''

    def __init__(
        self,
        split=None,
        prediction=None,
        left=None,
        right=None,
        level=None,
    ):
        self.split = split
        self.prediction = prediction
        self.left = left
        self.right = right
        self.level = level

    def __repr__(self):
        indent = 2
        # getattr tolerates a node built with the defaults (split=None, level=None).
        attr = getattr(self.split, 'attr', None)
        threshold = getattr(self.split, 'value', None)
        lines = [f'Split(attr={attr}, threshold={threshold})']
        for tag, child in (('L', self.left), ('R', self.right)):
            # Children are indented by their own absolute depth, so nested reprs line up.
            depth = getattr(child, 'level', None) or 0
            lines.append(' ' * (indent * depth) + f'{tag}={child}')
        return '\n'.join(lines)
        

def find_best_split(dataset, attr_dimensions=range(0,3), criterion=gini_index):
    '''Search every candidate threshold on every listed attribute and return the best split.

    Candidate thresholds for attribute ``a`` are the midpoints between consecutive
    *distinct* values of ``data[:, a]``, so both children of every candidate are
    non-empty. Rows with ``data[:, a] < threshold`` go left, all others go right
    (the same rule ``find_leaf`` uses at prediction time).

    The loss of a split is the size-weighted mean of ``criterion`` over the two
    children (standard CART impurity). A small pure side therefore cannot make a
    split look good on its own.

    Parameters
    ----------
    dataset : Dataset
        ``data`` is a 2-D feature array; ``target`` holds one class label per row.
    attr_dimensions : iterable of int
        Column indices of ``data`` to consider. NOTE: the default ``range(0, 3)``
        only scans the first three columns; callers normally pass columns explicitly.
    criterion : callable
        Maps an array of labels to an impurity score (lower is better).

    Returns
    -------
    Split
        The candidate with the smallest loss (the first one wins on ties).

    Raises
    ------
    ValueError
        If no listed attribute has two distinct values, i.e. no split is possible.
    '''
    data = np.asarray(dataset.data)
    target = np.asarray(dataset.target)
    n_rows = target.shape[0]

    best = None  # (loss, attr, threshold, left_mask) of the best candidate so far
    for a in attr_dimensions:
        column = data[:, a]
        # np.unique sorts and de-duplicates, so every midpoint lies strictly between
        # two observed values (no empty child from a repeated minimum).
        for split_point in all_splits(np.unique(column)):
            goes_left = column < split_point
            n_left = int(goes_left.sum())
            if n_left == 0 or n_left == n_rows:
                continue  # degenerate threshold (e.g. NaN midpoint): one side empty
            loss = (
                n_left * criterion(target[goes_left])
                + (n_rows - n_left) * criterion(target[~goes_left])
            ) / n_rows
            if best is None or loss < best[0]:
                best = (loss, a, split_point, goes_left)

    if best is None:
        raise ValueError('no attribute in attr_dimensions has two distinct values to split on')

    # Materialise child datasets only for the winning candidate.
    loss, a, split_point, goes_left = best
    return Split(
        attr=a,
        value=split_point,
        loss=loss,
        left=Dataset(data[goes_left], target[goes_left]),
        right=Dataset(data[~goes_left], target[~goes_left]),
    )


def _majority_class(target):
    '''Most frequent label in ``target`` (smallest label on ties), or None if empty.'''
    labels, counts = np.unique(np.asarray(target), return_counts=True)
    if counts.size == 0:
        return None
    return labels[counts.argmax()]


def fit_decision_tree_classifier(dataset, criterion=gini_index, level=0, attr_dimensions=(),
                                 min_impurity=0.1, min_samples=4):
    '''Recursively grow a binary decision tree on ``dataset``.

    A node becomes a LeafNode that predicts its majority class when any of these hold:
      * no attributes are left to split on,
      * it has ``min_samples`` rows or fewer,
      * its own impurity ``criterion(target)`` is below ``min_impurity``,
      * no remaining attribute can separate its rows.
    Otherwise it becomes a SplitNode on the best split found by ``find_best_split``.
    The chosen attribute is not offered again anywhere below that node.

    Parameters
    ----------
    dataset : Dataset
        Training rows (``data``) and their class labels (``target``).
    criterion : callable
        Impurity function, used both for the stop test and for split search.
    level : int
        Depth of this node (root is 0).
    attr_dimensions : iterable of int
        Columns still available for splitting. It is copied, never mutated.
    min_impurity, min_samples :
        Stopping thresholds; the defaults match the original hard-coded 0.1 and 4.

    Returns
    -------
    LeafNode or SplitNode
        Root of the fitted (sub)tree.
    '''
    target = dataset.target
    n_rows = dataset.data.shape[0]
    attrs = list(attr_dimensions)  # private copy: the caller's sequence is left untouched

    best_split = None
    # Stop on the node's own impurity, not on the best split's loss. Otherwise a node
    # whose best split separates the classes perfectly would be collapsed into a leaf.
    if attrs and n_rows > min_samples and criterion(target) >= min_impurity:
        try:
            best_split = find_best_split(dataset, attr_dimensions=attrs, criterion=criterion)
        except ValueError:
            # Raised when no candidate threshold exists (e.g. remaining attributes constant).
            best_split = None

    if best_split is None:
        return LeafNode(prediction=_majority_class(target), members=dataset, level=level)

    # Both subtrees get the same fresh list without the used attribute. A single shared,
    # mutated list would let the left subtree's choices leak into the right subtree.
    remaining = [a for a in attrs if a != best_split.attr]
    child_kwargs = dict(
        criterion=criterion,
        level=level + 1,
        attr_dimensions=remaining,
        min_impurity=min_impurity,
        min_samples=min_samples,
    )
    left_node = fit_decision_tree_classifier(best_split.left, **child_kwargs)
    right_node = fit_decision_tree_classifier(best_split.right, **child_kwargs)
    return SplitNode(split=best_split, left=left_node, right=right_node, level=level)
    
    

def find_leaf(tree, data_point):
    '''Walk from ``tree`` down to the leaf that ``data_point`` falls into.

    At each SplitNode the point goes left when data_point[attr] < threshold and
    right otherwise, mirroring how training rows were partitioned.
    '''
    current = tree
    while type(current) is SplitNode:
        rule = current.split
        goes_left = data_point[rule.attr] < rule.value
        current = current.left if goes_left else current.right
    return current


def predict(tree, data):
    '''Predicted class label for every row of ``data``, as a numpy array.

    Each row is routed to its leaf with ``find_leaf``, and that leaf's stored
    prediction is reported.
    '''
    labels = []
    for row in data:
        leaf = find_leaf(tree, row)
        labels.append(leaf.prediction)
    return np.array(labels)
# data_point = dataset.data[0]
# find_leaf(tree, data_point)
# predict(tree, dataset.data)

In [None]:
n_attributes = iris.data.shape[1]
dataset = Dataset(iris.data, iris.target)

# Offer every feature column to the tree (derived from the data, not hard-coded).
tree = fit_decision_tree_classifier(dataset, attr_dimensions=list(range(n_attributes)))
tree

In [None]:
predictions = predict(tree, dataset.data)
# The tree was fit on this same data, so this is training (resubstitution)
# accuracy, not an estimate of performance on unseen samples.
training_accuracy = np.mean(predictions == dataset.target)
training_accuracy