In [None]:
import numpy as np
from collections import namedtuple, Counter

from sklearn import datasets
from textwrap import dedent

def argmax(distribution):
    """Return the key of `distribution` that maps to the largest value.

    Ties resolve to the earliest such key in iteration order. An empty
    mapping raises ValueError (from ``max``).
    """
    return max(distribution, key=distribution.get)

def discrete_distribution(values):
    """Empirical probability mass function of `values`.

    Returns a dict mapping each distinct value (in first-seen order) to its
    relative frequency, i.e. its count divided by the number of values.
    An empty input yields an empty dict.
    """
    tally = Counter(values)
    n_observations = sum(tally.values())
    pmf = {}
    for outcome, occurrences in tally.items():
        pmf[outcome] = occurrences / n_observations
    return pmf

def all_splits(sorted_values):
    """Candidate thresholds: the midpoint of every adjacent pair of values.

    For ``n`` values this yields ``n - 1`` midpoints (an empty list when
    ``n < 2``). Two equal neighbours produce a midpoint equal to that value.
    """
    return [
        (sorted_values[i] + sorted_values[i + 1]) / 2
        for i in range(len(sorted_values) - 1)
    ]

def gini_index(target):
    """Gini impurity of a collection of class labels.

    Computes ``1 - sum_k p_k**2``, where ``p_k`` is the relative frequency of
    class ``k`` in `target`. Returns 0 for a pure collection, and 0.0 for an
    empty one (an empty partition contributes no impurity).
    """
    if len(target) == 0:
        # Without this guard the empty case evaluates to 1 - 0 = 1.0, which
        # exceeds any achievable impurity (at most 1 - 1/k for k classes).
        return 0.0
    probabilities = np.array(list(discrete_distribution(target).values()))
    impurity = 1 - (probabilities ** 2).sum()
    return impurity

In [None]:
iris = datasets.load_iris()
# Peek at the first sample: its feature vector and its class label.
iris.data[0], iris.target[0]

In [None]:
# Candidate split produced by find_best_split: column index `attr`, threshold
# `value`, its `loss`, and the two partitions (`left` holds rows whose attr
# value is below the threshold, `right` the rest).
Split = namedtuple('Split', 'attr value loss left right')

class Dataset(namedtuple('Dataset', 'data target')):
    """Immutable (data, target) pair whose repr deliberately omits the payload.

    A Split holds two Datasets, so eliding the arrays keeps a displayed Split
    to a single readable line.
    """

    def __repr__(self):
        return 'Dataset(data=<>, target=<>)'
    
    
class LeafNode:
    """Terminal node of the decision tree.

    Attributes:
        prediction: class label this leaf predicts.
        members: the training subset that reached this leaf (an object with a
            ``.data`` array, e.g. a ``Dataset``), or None.
        level: depth of the node in the tree (root = 0).
    """

    def __init__(self, prediction=None, members=None, level=0):
        self.prediction = prediction
        # None rather than a `[]` default: a default list would be shared by
        # every LeafNode created without `members`.
        self.members = members
        self.level = level

    def __repr__(self):
        # Show only the shape of the member data, never the rows themselves;
        # falls back to None when there is no array-like `.data` to inspect.
        shape = getattr(getattr(self.members, 'data', None), 'shape', None)
        return f'Leaf(prediction={self.prediction}, members={shape})'


class InternalNode:
    """Decision node that routes samples to `left` or `right` via a `Split`.

    Attributes:
        split: chosen split, exposing ``attr`` (column index) and ``value``
            (threshold).
        prediction: optional label for this node; stored but not computed here.
        left: subtree for rows whose ``attr`` value is below ``split.value``.
        right: subtree for the remaining rows.
        level: depth of the node in the tree (root = 0).
    """

    def __init__(
        self,
        split=None,
        prediction=None,
        left=None,
        right=None,
        level=None,
    ):
        self.split = split
        self.prediction = prediction
        self.left = left
        self.right = right
        self.level = level

    def __repr__(self):
        indent = 2
        if self.split is None:
            header = 'Internal(split=None)'
        else:
            header = f'Internal(attr={self.split.attr}, split={self.split.value})'
        lines = [header]
        for tag, child in (('L', self.left), ('R', self.right)):
            # Children are indented by their *absolute* depth, so a nested
            # child's own multi-line repr already lines up and needs no
            # re-indenting; a missing child/level is treated as depth 0.
            depth = getattr(child, 'level', None) or 0
            lines.append(f'{" " * indent * depth}{tag}={child}')
        return '\n'.join(lines)
        

def find_best_split(dataset, attr_dimensions=range(0,3), criterion=gini_index):
    """Search axis-aligned thresholds and return the lowest-loss `Split`.

    For every column index ``a`` in `attr_dimensions`, each midpoint between
    consecutive *distinct* values of ``dataset.data[:, a]`` is tried as a
    threshold ``t``. Rows with ``value < t`` go left and rows with
    ``value >= t`` go right. A candidate's loss is the size-weighted mean of
    `criterion` over the two partitions:
    ``(n_left * c(left) + n_right * c(right)) / (n_left + n_right)``.

    Args:
        dataset: object with a 2-D ``data`` array and a ``target`` array
            (assumed aligned row-for-row).
        attr_dimensions: column indices to consider. NOTE(review): the
            default only covers columns 0-2; pass it explicitly to search
            every column.
        criterion: impurity function applied to a target array.

    Returns:
        Split with the attribute index, threshold, loss and the two
        ``Dataset`` partitions (both non-empty). Ties keep the first
        candidate found.

    Raises:
        ValueError: if no threshold separates the rows, e.g. every
            considered column is constant or there are fewer than two rows.
    """
    data = dataset.data
    target = dataset.target

    splits = []
    for a in attr_dimensions:
        column = data[:, a]
        # Midpoints of the *unique* sorted values: duplicates would yield a
        # threshold equal to a data value and possibly an empty partition.
        for split_point in all_splits(np.unique(column)):
            left_mask = column < split_point
            right_mask = column >= split_point
            target_left = target[left_mask]
            target_right = target[right_mask]
            n_left, n_right = len(target_left), len(target_right)
            if n_left == 0 or n_right == 0:
                # Float rounding can collapse a midpoint onto a value.
                continue
            # Weight each side by its share of the rows (CART-style); an
            # unweighted sum rewards peeling off tiny pure partitions.
            loss = (
                n_left * criterion(target_left) + n_right * criterion(target_right)
            ) / (n_left + n_right)
            splits.append(Split(
                attr=a,
                value=split_point,
                loss=loss,
                left=Dataset(data[left_mask], target_left),
                right=Dataset(data[right_mask], target_right),
            ))

    if not splits:
        raise ValueError(
            'find_best_split: no threshold separates the rows on the given attributes'
        )
    return min(splits, key=lambda s: s.loss)


def fit_decision_tree_classifier(
    dataset,
    criterion=gini_index,
    level=0,
    attr_dimensions=None,
    all_dimensions=None,
    max_leaf_size=4,
):
    """Recursively grow a decision-tree classifier over `dataset`.

    A node becomes a ``LeafNode`` (predicting the most frequent label among
    its rows) when any of these hold:
      * it has at most `max_leaf_size` rows,
      * no attributes remain in `attr_dimensions`,
      * all its labels are identical, or
      * every remaining attribute is constant on its rows (no threshold
        could separate them).
    Otherwise the best split from ``find_best_split`` becomes an
    ``InternalNode`` and both partitions are fitted recursively.

    Candidate attributes alternate with depth. Below an even-level node, the
    children may only split on the attribute just used. Below an odd-level
    node, that attribute is dropped. Starting from level 0, this caps every
    root-to-leaf path at two splits.

    Args:
        dataset: object with a 2-D ``data`` array and an aligned ``target``.
        criterion: impurity function forwarded to ``find_best_split``.
        level: depth of this node (root = 0).
        attr_dimensions: column indices this node may split on. None (the
            default) means none, so the node becomes a leaf.
        all_dimensions: not read by the fitting logic; forwarded unchanged
            to the children for backward compatibility.
        max_leaf_size: nodes with this many rows or fewer become leaves.

    Returns:
        The root ``LeafNode`` or ``InternalNode`` of the fitted (sub)tree.
    """
    data = dataset.data
    target = dataset.target
    # Fresh copies: this call never mutates a list owned by the caller or
    # shared with a sibling subtree.
    attr_dimensions = [] if attr_dimensions is None else list(attr_dimensions)
    all_dimensions = [] if all_dimensions is None else list(all_dimensions)

    is_leaf = (
        data.shape[0] <= max_leaf_size
        or len(attr_dimensions) == 0
        or np.unique(target).size <= 1
        or all(np.unique(data[:, a]).size < 2 for a in attr_dimensions)
    )
    if is_leaf:
        # An empty partition has no majority label to predict.
        prediction = argmax(discrete_distribution(target)) if len(target) else None
        return LeafNode(prediction=prediction, members=dataset, level=level)

    split = find_best_split(dataset, attr_dimensions=attr_dimensions, criterion=criterion)
    if level % 2 == 1:
        # Build a new list: removing in place would also empty the list the
        # sibling subtree receives, turning it into a leaf prematurely.
        child_dimensions = [d for d in attr_dimensions if d != split.attr]
    else:
        child_dimensions = [split.attr]

    left_node, right_node = (
        fit_decision_tree_classifier(
            part,
            criterion=criterion,
            level=level + 1,
            attr_dimensions=child_dimensions,
            all_dimensions=all_dimensions,
            max_leaf_size=max_leaf_size,
        )
        for part in (split.left, split.right)
    )
    return InternalNode(split=split, left=left_node, right=right_node, level=level)

In [None]:



# Fit on every feature column. Wrapping the data in `Dataset` gives the root
# the same (data, target) type that its child partitions use.
n_attributes = iris.data.shape[1]
tree = fit_decision_tree_classifier(
    Dataset(iris.data, iris.target),
    attr_dimensions=list(range(n_attributes)),
)
tree