In [1]:
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw

# Seed NumPy's global RNG so any random draws are reproducible.
# NOTE(review): no cell visible in this notebook draws random numbers (the
# train/test split is positional), so this currently has no effect.
np.random.seed(0)

In [2]:
def entropy(y):
    """Shannon entropy, in bits, of the label distribution in ``y``.

    ``y`` is a pandas Series of labels. Each distinct value's share is its
    count divided by ``len(y)``, and the result is ``-sum(p * log2(p))``.
    """
    counts = y.value_counts().to_numpy()
    shares = counts / len(y)
    return -(shares * np.log2(shares)).sum()

In [3]:
class Node:
    """Internal split node of a binary decision tree.

    Holds the column name being tested (``feature``), the value it is compared
    against (``feature_value``), and the two subtrees. Each branch is either
    another node or a bare leaf value.
    """

    def __init__(self, false_branch=None, true_branch=None, feature=None, feature_value=None):
        # Subtrees (or leaf values) for samples failing / passing the test.
        self.false_branch, self.true_branch = false_branch, true_branch
        # Column tested by this split and the value it compares against.
        self.feature, self.feature_value = feature, feature_value

In [4]:
class DecisionTree(Node):
    """Binary classification tree grown greedily by information gain.

    ``build`` chooses this node's split (stored in the inherited ``Node``
    attributes) and recursively builds both children. ``predict`` routes a
    single sample down to a leaf. String-valued split points are tested with
    ``==``; all other split points are tested with ``<``. Leaves are stored as
    bare class labels rather than as ``Node`` objects.
    """

    @staticmethod
    def _passes(x, value):
        """Evaluate the split predicate on a scalar or on a whole column.

        Returns ``x == value`` when ``value`` is a string, else ``x < value``.
        For a pandas Series this is a boolean mask. ``build`` and ``predict``
        both use this helper, so training and inference always agree.
        """
        if isinstance(value, str):
            return x == value
        return x < value

    def build(self, X, y, score=entropy):
        """Grow the (sub)tree for features ``X`` and labels ``y``.

        Parameters
        ----------
        X : pandas.DataFrame
            Feature columns. Assumed to share its index with ``y``, because
            the masks built from ``X`` are used to index ``y``.
        y : pandas.Series
            Class labels.
        score : callable, default ``entropy``
            Impurity measure ``score(labels) -> float``. It is also used for
            every recursive child build.

        Returns
        -------
        DecisionTree or label
            ``self`` once a split has been chosen. Otherwise a bare class
            label: the single label when ``y`` is pure, or the most frequent
            label when no split has positive gain.

        Raises
        ------
        ValueError
            If ``y`` is empty.
        """
        if len(y) == 0:
            raise ValueError("cannot build a decision tree from an empty label set")

        labels = pd.unique(y)
        if len(labels) == 1:
            return labels[0]  # pure node -> leaf

        n_samples = len(y)
        parent_score = score(y)
        best_gain = 0
        best_mask = None
        best_feature = best_value = None

        for column in X:
            for value in pd.unique(X[column]):
                if pd.isna(value):
                    # `< nan` / `== nan` never holds, so the split is degenerate
                    # anyway. Skipping it also avoids a TypeError from comparing
                    # strings with a float NaN in object columns.
                    continue
                mask = self._passes(X[column], value)
                n_true = int(mask.sum())
                if n_true == 0 or n_true == n_samples:
                    # One side is empty. The gain is 0 in exact arithmetic but
                    # can round to a tiny positive float, so never pick this split.
                    continue
                gain = parent_score - (
                    score(y[mask]) * n_true + score(y[~mask]) * (n_samples - n_true)
                ) / n_samples
                if gain > best_gain:
                    best_gain, best_mask = gain, mask
                    best_feature, best_value = column, value

        if best_mask is None:
            # No split separates the labels (e.g. identical feature rows with
            # different labels). Fall back to the majority label; idxmax returns
            # the label itself, whereas argmax would return its position.
            return y.value_counts().idxmax()

        # Forward `score` so custom impurity measures apply at every depth.
        self.false_branch = DecisionTree().build(X[~best_mask], y[~best_mask], score=score)
        self.true_branch = DecisionTree().build(X[best_mask], y[best_mask], score=score)
        self.feature = best_feature
        self.feature_value = best_value
        return self

    def predict(self, x):
        """Return the predicted class label for one sample.

        ``x`` is a single row indexable by feature name (e.g. ``X.iloc[i]``).
        The tree is walked iteratively until a leaf (any non-``Node`` value)
        is reached, and that value is returned.
        """
        node = self
        while isinstance(node, Node):
            if self._passes(x[node.feature], node.feature_value):
                node = node.true_branch
            else:
                node = node.false_branch
        return node

In [5]:
def getwidth(tree):
    """Number of leaves under ``tree``; anything that is not a ``Node`` counts as one leaf."""
    if not isinstance(tree, Node):
        return 1
    return sum(getwidth(child) for child in (tree.false_branch, tree.true_branch))

def getdepth(tree):
    """Number of levels in ``tree``, counting the leaf level; a bare leaf has depth 1."""
    if not isinstance(tree, Node):
        return 1
    deepest_child = max(getdepth(tree.false_branch), getdepth(tree.true_branch))
    return deepest_child + 1

In [6]:
def drawtree(tree, path='tree.jpg'):
    """Render ``tree`` to an image file at ``path``, always JPEG-encoded.

    The white canvas allots 100 px per leaf horizontally and 100 px per level
    vertically. The root is drawn horizontally centred, 20 px from the top.
    """
    canvas_size = (getwidth(tree) * 100, getdepth(tree) * 100)
    img = Image.new('RGB', canvas_size, (255, 255, 255))
    drawnode(ImageDraw.Draw(img), tree, canvas_size[0] / 2, 20)
    img.save(path, 'JPEG')
    
def drawnode(draw, tree, x, y):
    """Recursively draw ``tree`` with its root at pixel (``x``, ``y``).

    An internal node shows its split predicate and gets red edges to both
    children, which sit one level (100 px) lower. The false branch is placed
    on the left and the true branch on the right, each centred over a
    horizontal span proportional to its leaf count. A leaf shows its value.
    """
    if not isinstance(tree, Node):
        # str() so non-string leaf labels (e.g. numeric classes) can be rendered.
        draw.text((x - 20, y), str(tree), (0, 0, 0))
        return

    shift = 100  # pixels per leaf horizontally and per level vertically
    width_false = getwidth(tree.false_branch) * shift
    width_true = getwidth(tree.true_branch) * shift
    left = x - (width_false + width_true) / 2
    right = x + (width_false + width_true) / 2
    child_y = y + shift
    false_x = left + width_false / 2
    true_x = right - width_true / 2

    i, v = tree.feature, tree.feature_value
    # String split points are equality tests; everything else is `<`.
    op = '==' if isinstance(v, str) else '<'
    predicate = "{feature} {op} {value}?".format(feature=i, op=op, value=v)

    draw.text((x - 20, y - 10), predicate, (0, 0, 0))
    draw.line((x, y, false_x, child_y), fill=(255, 0, 0))
    draw.line((x, y, true_x, child_y), fill=(255, 0, 0))
    drawnode(draw, tree.false_branch, false_x, child_y)
    drawnode(draw, tree.true_branch, true_x, child_y)

In [7]:
data = pd.read_csv('halloween.csv')
# The 'type' column is the target; every other column is a feature.
X = data.drop(columns='type')
y = data['type']

In [8]:
# Fraction of rows used for training; the remainder forms the test set.
TRAIN_FRACTION = 0.8
n = int(len(y) * TRAIN_FRACTION)

In [9]:
# Positional split: the first n rows train, the rest test (.iloc keeps this
# positional regardless of the index type).
# NOTE(review): rows are not shuffled, so if halloween.csv is ordered (e.g.
# by 'type') the test set will not be representative — confirm the file order.
X_train, y_train = X.iloc[:n], y.iloc[:n]
X_test, y_test = X.iloc[n:], y.iloc[n:]

In [10]:
# Root of the tree; dt.build(...) fills in its split and children.
dt = DecisionTree()

In [11]:
# Grow the tree on the training split. build() returns dt itself after a
# split, but returns a bare label if y_train holds a single class; in that
# case dt stays unsplit, so check the displayed return value.
dt.build(X_train, y_train)

<__main__.DecisionTree at 0x7fcd5c0>

In [12]:
# Count training rows whose prediction differs from the true label.
# .iloc makes the label lookup positional, so it does not depend on y_train
# having a 0-based integer index.
err_train = 0
for i in range(len(y_train)):
    if dt.predict(X_train.iloc[i]) != y_train.iloc[i]:
        err_train += 1

In [40]:
# Count test rows whose prediction differs from the true label.
# .iloc replaces the manual `i + y_test.index[0]` offset, which only worked
# for a contiguous integer index.
err_test = 0
for i in range(len(y_test)):
    if dt.predict(X_test.iloc[i]) != y_test.iloc[i]:
        err_test += 1

In [42]:
# (training error rate, test error rate)
(err_train / len(y_train), err_test / len(y_test))

(0.0, 0.3466666666666667)

In [44]:
# Writes the rendered tree to tree.jpg (drawtree's default path).
drawtree(dt)

In [38]:
# Number of levels in the tree, counting the leaf level.
getdepth(dt)

12

In [39]:
# Number of leaves in the tree.
getwidth(dt)

62