In [117]:
import scipy.io
import pandas as pd
import numpy as np

from math import log2
from dataclasses import dataclass, field

from functools import partial
from queue import deque
from enum import Enum
from collections import deque

from typing import *
from operator import *

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import *
from sklearn.datasets import *

import graphviz
#mat = scipy.io.loadmat('/Users/scott/projects/research-projects/tree_diff/input/twitter/influenza_outbreak_dataset.mat')

## Decision Tree Features
* ~Handle multiclass/binary~
* Binary or multple splits: Binary only 
* Choice of stopping criteria
* Choice of evaluation measure 
* Choice of splitting criteria 
* Choice of pruning strategy 
* Opaque rules

In [16]:
a = load_iris()
X = a.data
y = a.target

X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.5)

# Simple Decision tree

## Data classes 

In [265]:
class Operator(str, Enum):
    EQ = "=="
    LT = "<"
    GT = ">"
    LE = "<="
    GE = ">="
    NE = "!="
    
    @property
    def op(self):
        return {self.EQ: eq, 
                self.LT: lt, 
                self.LE: le, 
                self.GE: ge,
                self.NE: ne,                
                self.GT: gt}[self]

@dataclass
class Condition:
    attribute: str
    attribute_pos: int
    operator: Operator
    threshold: float

    def __repr__(self):
        return str(self)

    def __str__(self):
        return f"{self.attribute} {self.operator} {self.threshold}"
            
    def fire(self, x):        
        return self.operator.op(x[self.attribute_pos], self.threshold)
    
@dataclass
class Split:
    score: float 
    attribute_pos: int
    ids: Tuple[float]
    operations: List[Tuple[Operator, float]]
        
@dataclass            
class DecisionNode:
    label: str
    node_id: int
    value: List[int]    
    impurity: float = 0.0
    parent: 'Node' = None
    children: List['Node'] = field(default_factory=list)
    conditions: List[Condition] = field(default_factory=list)    
        
    def walk(self, callback):
        callback(self)
        for n in self.children:
            n.walk(callback)        
        
    def add_child(self, condition, node):
        self.conditions.append(condition)
        node.parent = self
        self.children.append(node)
    
    def predict(self, x):
        stack = deque()
        stack.append(self)
        node = self

        while not node.is_leaf():
            node = stack.pop()

            for i, cond in enumerate(node.conditions):
                if cond.fire(x):
                    stack.append(node.children[i])

        return node.label            
    
    def is_leaf(self):
        return not self.children
    
    def is_root(self):
        return not self.parent
    
    def plot(self):
        dot = graphviz.Digraph('tree', comment='Decision Tree')  

        def update_dot(dot, node):
            dot.node(f"{node.node_id}", f"Node_{node.node_id}\nImpurity: {node.impurity:0.3f}\nLabel: {node.label}\nValue: {node.value}\nSamples: {sum(node.value)}")
            if node.parent:
                cond = node.find_to_condition()
                dot.edge(f"{node.parent.node_id}" , f"{node.node_id}", str(cond))
        
        update_dot_partial = partial(update_dot, dot)        
        self.walk(update_dot_partial)
        return dot
    
    def find_to_condition(self):
        if self.is_root():
            return None
        else:            
            index = -1
            for i, node in enumerate(self.parent.children):
                if node.node_id == self.node_id:
                    index = i
            if index < 0:
                raise ValueError("Incorrect tree")
            return self.parent.conditions[index]
                             
    def __str__(self):
        return f"Node_{self.node_id}"        

## Utility functions

In [252]:
def stopping_criteria(tree_depth, **kwargs):
    max_depth = kwargs.pop('max_depth', -1)
    if tree_depth >= max_depth:
        return True        
    return False    


def gini_impurity(y):
    counts = Counter(y)
    total = sum(counts.values())
    return round(1 - sum(map(lambda x: (x / total) ** 2, counts.values())), 3)


def entropy_impurity(y):
    counts = Counter(y)
    total = sum(counts.values())
    return - sum(map(lambda x: (x / total) * log2 (x / total), counts.values()))    
    

def evaluation_measure(groups: Tuple, measure):
    N = sum(map(len,groups))    
    return sum(map(lambda x: len(x) / N * measure(x), groups))
    
    
def calculate_current_depth(current_node):
    depth_counter = 0
    parent_node = current_node.parent    
    while parent_node:
        depth_counter += 1
        parent_node = parent_node.parent
    return depth_counter

def count_values(array, values):
    return [np.count_nonzero(array == i) for i in sorted(values)]        
        
def find_best_split(X, y, **kwargs):     
    split = Split(1,-1,(),())    
    measure = kwargs.pop("measure", gini_impurity)
    
    # Loop over attributes
    for i in range(0, X.shape[1]):
        x_s = X[:, i]

        # Try each unique value (inefficient for numerical values)
        # TODO: All split conditions are in the dataset unlike in CART       
        for threshold in np.unique(x_s):
            
            # TODO: Support non binary splits
            ids = (x_s <= threshold, x_s > threshold)
            operations = [(Operator.LE, threshold), (Operator.GT, threshold)]
            
            y_values = [y[i] for i in ids]            
            score = evaluation_measure(y_values, measure)

            # Find smallest gain, use 
            if score < split.score:
                split = Split(score, i, ids, operations)

    return split

## Grow a decision tree

In [290]:
def grow_tree(X, y, **kwargs):
    attribute_types = list(map(str, X.dtypes))
    column_name = X.columns
    X = X.to_numpy()
    node_counter = 1    
    stack = deque()
    
    # Set up decision tree
    classes = np.unique(y)
    counts = count_values(y, classes)
    tree = DecisionNode(max(y), node_counter, counts, gini_impurity(y))
    stack.append((tree, X, y))

    while len(stack) != 0:
        current_node, parent_X, parent_y = stack.pop()            

        # Stop once reached max depth branching
        current_depth = calculate_current_depth(current_node)
        if stopping_criteria(current_depth, **kwargs):
            continue
        
        # Stop branching if node contains a single class
        values = np.unique(parent_y)            
        if len(values) < 2:
            continue
        
        
        # Determine best attribute and split
        split = find_best_split(parent_X, parent_y, **kwargs)

        # Update tree with new split
        for cond, split_ids in zip(split.operations, split.ids):        
            new_y = parent_y[split_ids]
            new_X = parent_X[split_ids]            

            # Ensure new node is processed later
            counts = count_values(new_y, classes)
            if len(counts) == 0:
                continue 
            
            condition = Condition(column_name[split.attribute_pos], 
                                  split.attribute_pos, 
                                  cond[0],
                                  cond[1])
             
            label = classes[np.argmax(counts)]
            node_counter += 1
            
            score = gini_impurity(new_y)
 
            new_node = DecisionNode(label, node_counter, counts, score)
            current_node.add_child(condition, new_node)

            stack.append((new_node, new_X, new_y))
            
    return tree        

tree = grow_tree(pd.DataFrame(X_train, columns=["F1", "F2", "F3", 'F4']), y_train, max_depth=2, min_samples_split=2)
#tree.plot()

# Sample datasets

In [309]:
datasets = [load_iris, load_breast_cancer, load_digits, load_wine]

for func_data in datasets:
    d = func_data()
    X = d.data
    y = d.target
    X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.5)
    
    tree = grow_tree(pd.DataFrame(X_train), y_train, max_depth=2)            
    y_pred_cust = [tree.predict(X_test[i]) for i in range(0, X_test.shape[0])]
    
    sklearn_dt = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)    
    y_pred = sklearn_dt.predict(X_test)
    
    print(f"Sklearn Decision Tree:{accuracy_score(y_test, y_pred):0.2f}, Custom Tree:{accuracy_score(y_test, y_pred_cust):0.2f}, {func_data.__name__}")
    

Sklearn Decision Tree:0.95, Custom Tree:0.96, load_iris
Sklearn Decision Tree:0.94, Custom Tree:0.94, load_breast_cancer
Sklearn Decision Tree:0.30, Custom Tree:0.30, load_digits
Sklearn Decision Tree:0.91, Custom Tree:0.91, load_wine


# Regrowth algorithm

In [369]:
d = load_digits()
X = d.data
y = d.target
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.5)

tree = grow_tree(pd.DataFrame(X_train), y_train, max_depth=3)            
y_pred_cust_train = [tree.predict(X_train[i]) for i in range(0, X_train.shape[0])]
y_pred_cust = [tree.predict(X_test[i]) for i in range(0, X_test.shape[0])]

f"Train: {accuracy_score(y_train, y_pred_cust_train):0.2f}, Validation: {accuracy_score(y_test, y_pred_cust):0.2f}" 

'Train: 0.48, Validation: 0.46'

In [380]:
budget = 5
base_accuracy = accuracy_score(y_train, y_pred_cust_train)
base_tree = tree

leafs = []
tree.walk(lambda x: leafs.append(x))
N = sum(tree.value)
for x in sorted(leafs, key=lambda x: x.impurity, reverse=False):
    print(x.node_id, x.impurity, x.label, x.value, sum(x.value)/N)

15 0.067 9 [0, 0, 0, 1, 0, 0, 0, 0, 0, 28] 0.03229398663697105
13 0.105 0 [87, 0, 1, 0, 2, 0, 1, 0, 0, 1] 0.10244988864142539
14 0.142 5 [0, 0, 1, 0, 0, 12, 0, 0, 0, 0] 0.014476614699331848
11 0.209 0 [87, 0, 1, 0, 2, 2, 3, 0, 2, 1] 0.1091314031180401
8 0.448 7 [0, 5, 2, 4, 3, 9, 0, 85, 8, 0] 0.1291759465478842
10 0.473 9 [0, 0, 1, 1, 0, 12, 0, 0, 0, 28] 0.0467706013363029
2 0.56 0 [87, 0, 2, 1, 2, 14, 3, 0, 2, 29] 0.155902004454343
12 0.667 5 [0, 0, 0, 0, 0, 2, 2, 0, 2, 0] 0.0066815144766146995
6 0.703 3 [0, 25, 81, 89, 1, 1, 2, 0, 27, 5] 0.2572383073496659
9 0.718 4 [1, 6, 0, 0, 44, 33, 1, 11, 2, 10] 0.12026726057906459
4 0.73 7 [1, 11, 2, 4, 47, 42, 1, 96, 10, 10] 0.24944320712694878
7 0.827 6 [1, 55, 7, 4, 29, 36, 81, 1, 52, 37] 0.3374164810690423
5 0.86 3 [1, 80, 88, 93, 30, 37, 83, 1, 79, 42] 0.5946547884187082
3 0.887 3 [2, 91, 90, 97, 77, 79, 84, 97, 89, 52] 0.844097995545657
1 0.9 9 [89, 91, 92, 98, 79, 93, 87, 97, 91, 81] 1.0
