In [1]:
import numpy as np
import math
import pandas as pd
from sklearn import datasets
from sklearn.metrics import log_loss
import statistics
from pprint import pprint

In [3]:
X = datasets.load_iris().data
y = datasets.load_iris().target
from sklearn.model_selection import train_test_split
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.2,random_state=22)

In [10]:
def entropy_of_class(class1, number_of_samples):
    return -(class1*1.0/number_of_samples)*math.log(class1*1.0/number_of_samples, 2)

def calc_entropy(class1, class2):
    
    if class1== 0 or class2 == 0:  
        return 0
    return entropy_of_class(class1, class1+class2) + entropy_of_class(class2, class1+class2)


def entropy_of_subset(subset): 
    
    s = 0
    num_sample = len(subset)
    classes = set(subset)
    for cla in classes:   
        n_c = sum(subset==cla)
        e = n_c*1.0/num_sample * calc_entropy(sum(subset==cla), sum(subset!=cla)) 
        s += e
    return s, num_sample


def get_entropy(y_predict, y_real):
    
    if len(y_predict) != len(y_real):
        print('They have to be the same length')
        return None
    n = len(y_real)
    s_true, n_true = entropy_of_subset(y_real[y_predict]) 
    s_false, n_false = entropy_of_subset(y_real[~y_predict]) 
    s = n_true*1.0/n * s_true + n_false*1.0/n * s_false 
    return s

In [11]:
class DecisionTreeClassifier(object):
    def __init__(self, max_depth):
        self.depth = 0
        self.max_depth = max_depth
    
    def fit(self, x, y, par_node={}, depth=0):
        if par_node is None: 
            return None
        elif len(y) == 0:
            return None
        elif self.all_same(y):
            return {'val':y[0]}
        elif depth >= self.max_depth:
            return None
        else: 
            col, cutoff, entropy = self.find_best_split_of_all(x, y)    
            y_left = y[x[:, col] < cutoff]
            y_right = y[x[:, col] >= cutoff]
            par_node = {'col': iris.feature_names[col], 'index_col':col,
                        'cutoff':cutoff,
                       'val': np.round(np.mean(y))}
            par_node['left'] = self.fit(x[x[:, col] < cutoff], y_left, {}, depth+1)
            par_node['right'] = self.fit(x[x[:, col] >= cutoff], y_right, {}, depth+1)
            self.depth += 1 
            self.trees = par_node
            return par_node
    
    def find_best_split_of_all(self, x, y):
        col = None
        min_entropy = 1
        cutoff = None
        for i, c in enumerate(x.T):
            entropy, cur_cutoff = self.find_best_split(c, y)
            if entropy == 0:    
                return i, cur_cutoff, entropy
            elif entropy <= min_entropy:
                min_entropy = entropy
                col = i
                cutoff = cur_cutoff
        return col, cutoff, min_entropy
    
    def find_best_split(self, col, y):
        min_entropy = 10
        n = len(y)
        for value in set(col):
            y_predict = col < value
            my_entropy = get_entropy(y_predict, y)
            if my_entropy <= min_entropy:
                min_entropy = my_entropy
                cutoff = value
        return min_entropy, cutoff
    
    def all_same(self, items):
        return all(x == items[0] for x in items)

In [12]:
clf = DecisionTreeClassifier(max_depth = 3)
m = clf.fit(X_train,y_train)
pprint(m)


{'col': 'petal width (cm)',
 'cutoff': 1.0,
 'index_col': 3,
 'left': {'val': 0},
 'right': {'col': 'petal width (cm)',
           'cutoff': 1.8,
           'index_col': 3,
           'left': {'col': 'petal length (cm)',
                    'cutoff': 5.1,
                    'index_col': 2,
                    'left': None,
                    'right': None,
                    'val': 1.0},
           'right': {'col': 'petal length (cm)',
                     'cutoff': 4.9,
                     'index_col': 2,
                     'left': {'val': 1},
                     'right': {'val': 2},
                     'val': 2.0},
           'val': 1.0},
 'val': 1.0}
