In [6]:
import numpy as np

In [69]:
def gini_index(groups, attributes):
    n = sum([len(group) for group in groups])

    gini = 0.0
    for group in groups:
        size = len(group)
        
        # Avoid zero division.
        if size == 0:
            continue
        score = 0
        
        for attr in attributes:
            p = [row[-1] for row in group].count(attr) / size
            score += p * p
        gini += (1 - score) * (size / n)
    return gini

In [70]:
gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1])

0.5

In [71]:
gini_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1])

0.0

In [28]:
def test_split(index, value, dataset):
    result = [], []
    for row in dataset:
        result[0 if row[index] < value else 1].append(row)
    return result

In [53]:
def get_split(dataset):
    """Select the best split point for a dataset."""
    attributes = list(set(row[-1] for row in dataset))
    best_index, best_value, best_score, best_groups = 999, 999, 999, None
    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, attributes)
            print(f'X{index+1} < {row[index]:.3f} Gini={gini:.3f}')
            if gini < best_score:
                best_index, best_value, best_score, best_groups = index, row[index], gini, groups
    return {
        'index': best_index,
        'value': best_value,
        'groups': best_groups,
    }

In [54]:
data = """2.771244718		1.784783929		0
1.728571309		1.169761413		0
3.678319846		2.81281357		0
3.961043357		2.61995032		0
2.999208922		2.209014212		0
7.497545867		3.162953546		1
9.00220326		3.339047188		1
7.444542326		0.476683375		1
10.12493903		3.234550982		1
6.642287351		3.319983761		1
"""
dataset = []
for row in data.split('\n'):
    cols = row.strip().split('\t\t')
    if len(cols) > 1:
        dataset.append(list(map(float, cols)))
dataset

[[2.771244718, 1.784783929, 0.0],
 [1.728571309, 1.169761413, 0.0],
 [3.678319846, 2.81281357, 0.0],
 [3.961043357, 2.61995032, 0.0],
 [2.999208922, 2.209014212, 0.0],
 [7.497545867, 3.162953546, 1.0],
 [9.00220326, 3.339047188, 1.0],
 [7.444542326, 0.476683375, 1.0],
 [10.12493903, 3.234550982, 1.0],
 [6.642287351, 3.319983761, 1.0]]

In [55]:
result = get_split(dataset)
print(f'Split at [X{result["index"]+1} < {result["value"]}]')

X1 < 2.771 Gini=0.444
X1 < 1.729 Gini=0.500
X1 < 3.678 Gini=0.286
X1 < 3.961 Gini=0.167
X1 < 2.999 Gini=0.375
X1 < 7.498 Gini=0.286
X1 < 9.002 Gini=0.375
X1 < 7.445 Gini=0.167
X1 < 10.125 Gini=0.444
X1 < 6.642 Gini=0.000
X2 < 1.785 Gini=0.500
X2 < 1.170 Gini=0.444
X2 < 2.813 Gini=0.320
X2 < 2.620 Gini=0.417
X2 < 2.209 Gini=0.476
X2 < 3.163 Gini=0.167
X2 < 3.339 Gini=0.444
X2 < 0.477 Gini=0.500
X2 < 3.235 Gini=0.286
X2 < 3.320 Gini=0.375
Split at [X1 < 6.642287351]


In [56]:
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

In [57]:
def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del(node['groups'])
    
    # Check for a no split.
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    
    # Check for max depth.
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return

    # Process left child.
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth+1)
    
    # Process right child.
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth + 1)

In [58]:
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root

In [59]:
def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
        print_tree(node['left'], depth+1)
        print_tree(node['right'], depth+1)
    else:
        print('%s[%s]' % ((depth*' ', node)))

In [65]:
tree = build_tree(dataset, 3, 1)
print_tree(tree)

X1 < 2.771 Gini=0.444
X1 < 1.729 Gini=0.500
X1 < 3.678 Gini=0.286
X1 < 3.961 Gini=0.167
X1 < 2.999 Gini=0.375
X1 < 7.498 Gini=0.286
X1 < 9.002 Gini=0.375
X1 < 7.445 Gini=0.167
X1 < 10.125 Gini=0.444
X1 < 6.642 Gini=0.000
X2 < 1.785 Gini=0.500
X2 < 1.170 Gini=0.444
X2 < 2.813 Gini=0.320
X2 < 2.620 Gini=0.417
X2 < 2.209 Gini=0.476
X2 < 3.163 Gini=0.167
X2 < 3.339 Gini=0.444
X2 < 0.477 Gini=0.500
X2 < 3.235 Gini=0.286
X2 < 3.320 Gini=0.375
X1 < 2.771 Gini=0.000
X1 < 1.729 Gini=0.000
X1 < 3.678 Gini=0.000
X1 < 3.961 Gini=0.000
X1 < 2.999 Gini=0.000
X2 < 1.785 Gini=0.000
X2 < 1.170 Gini=0.000
X2 < 2.813 Gini=0.000
X2 < 2.620 Gini=0.000
X2 < 2.209 Gini=0.000
X1 < 2.771 Gini=0.000
X1 < 3.678 Gini=0.000
X1 < 3.961 Gini=0.000
X1 < 2.999 Gini=0.000
X2 < 1.785 Gini=0.000
X2 < 2.813 Gini=0.000
X2 < 2.620 Gini=0.000
X2 < 2.209 Gini=0.000
X1 < 7.498 Gini=0.000
X1 < 9.002 Gini=0.000
X1 < 7.445 Gini=0.000
X1 < 10.125 Gini=0.000
X1 < 6.642 Gini=0.000
X2 < 3.163 Gini=0.000
X2 < 3.339 Gini=0.000
X2 < 0.4

In [66]:
def predict(node, row):
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']

In [67]:
stump = {'index': 0,
         'right': 1,
         'value': 6.642287351,
         'left': 0}

In [68]:
for row in dataset:
    prediction = predict(stump, row)
    print(f'expected={row[-1]}, got={prediction}')

expected=0.0, got=0
expected=0.0, got=0
expected=0.0, got=0
expected=0.0, got=0
expected=0.0, got=0
expected=1.0, got=1
expected=1.0, got=1
expected=1.0, got=1
expected=1.0, got=1
expected=1.0, got=1
