In [22]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import tree

In [23]:
zoo = pd.read_csv('zoo.csv')

In [24]:
zoo.head()

Unnamed: 0,animal_name,hair,feathers,eggs,milk,airborne,aquatic,predator,toothed,backbone,breathes,venomous,fins,legs,tail,domestic,catsize,class_type
0,aardvark,1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1
1,antelope,1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,1
2,bass,0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,4
3,bear,1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1
4,boar,1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,1


In [35]:
def gini_index(groups, classes):
    instances = float(sum([len(group) for group in groups]))
    gini = 0.0
    for group in groups:
        n = float(len(group))
        score = 0.0
        if(n==0): continue
        for class_val in classes:
            class_type = []
            for row in group:
                class_type.append(row[-1])
            proportion = class_type.count(class_val) / n
            score += proportion * proportion
        # weight the group score by its relative size
        gini += (1.0 - score) * (n / instances)
    return gini

In [36]:
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

In [59]:
def best_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    best_index, best_value, best_gini, best_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            print('X%d < %.3f Gini=%.3f' % (int(index), row[index], gini))
        if gini < best_gini:
            best_index, best_value, best_gini, best_groups = index, row[index], gini, groups
    return [best_index, best_value, best_groups]

In [60]:
dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]
index, value, groups = best_split(dataset)
print('Best: [X%d < %.3f]' % (int(index), value))

X0 < 2.771 Gini=0.444
X0 < 1.729 Gini=0.500
X0 < 3.678 Gini=0.286
X0 < 3.961 Gini=0.167
X0 < 2.999 Gini=0.375
X0 < 7.498 Gini=0.286
X0 < 9.002 Gini=0.375
X0 < 7.445 Gini=0.167
X0 < 10.125 Gini=0.444
X0 < 6.642 Gini=0.000
X1 < 1.785 Gini=0.500
X1 < 1.170 Gini=0.444
X1 < 2.813 Gini=0.320
X1 < 2.620 Gini=0.417
X1 < 2.209 Gini=0.476
X1 < 3.163 Gini=0.167
X1 < 3.339 Gini=0.444
X1 < 0.477 Gini=0.500
X1 < 3.235 Gini=0.286
X1 < 3.320 Gini=0.375
Best: [X0 < 6.642]


In [61]:
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

In [62]:
def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del(node['groups'])
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth+1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth+1)

In [63]:
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root

In [64]:
def predict(node, row):
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']

In [67]:
def print_tree(node, depth=0):
	if isinstance(node, dict):
		print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
		print_tree(node['left'], depth+1)
		print_tree(node['right'], depth+1)
	else:
		print('%s[%s]' % ((depth*' ', node)))
tree = build_tree(dataset, 3, 1)
print_tree(tree)

[X1 < 6.642]
 [X1 < 2.771]
  [0]
  [X1 < 2.771]
   [0]
   [0]
 [X1 < 7.498]
  [X1 < 7.445]
   [1]
   [1]
  [X1 < 7.498]
   [1]
   [1]


In [68]:
stump = {'index': 0, 'right': 1, 'value': 6.642287351, 'left': 0}
for row in dataset:
	prediction = predict(stump, row)
	print('Expected=%d, Got=%d' % (row[-1], prediction))

Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=0, Got=0
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
Expected=1, Got=1
