In [40]:
import numpy as np
from sklearn import datasets
from sklearn.datasets.samples_generator import make_classification

In [41]:
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

In [42]:
def gini_index(groups, classes):
    n_instances = float(sum([len(group) for group in groups]))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        score = 0.0
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        gini += (1.0 - score) * (size / n_instances)
    return gini

In [43]:
def get_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}

In [44]:
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)


In [45]:
def split(node):
    left, right = node['groups']
    del(node['groups'])
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    gini1, gini2, gini3 = 999, 999, 999 
    class_values = list(set(row[-1] for row in left + right))
    for index in range(len((left + right)[0])-1):
        for row in left + right:
            groups = test_split(index, row[index], left + right)
            gini = gini_index(groups, class_values)
            if gini < gini1:
                gini1 = gini
    class_values = list(set(row[-1] for row in left))
    for index in range(len(left[0])-1):
        for row in left:
            groups = test_split(index, row[index], left)
            gini = gini_index(groups, class_values)
            if gini < gini2:
                gini2 = gini
    class_values = list(set(row[-1] for row in right))
    for index in range(len(right[0])-1):
        for row in right:
            groups = test_split(index, row[index], right)
            gini = gini_index(groups, class_values)
            if gini < gini3:
                gini3 = gini
    groups = test_split(index, row[index], right)
    gini3 = gini_index(groups, class_values)
    if gini1 <= 0:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    if gini2 <= 0:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'])
    if gini3 <= 0:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'])


In [46]:
def build_tree(train):
    root = get_split(train)
    split(root)
    return root

def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
        print_tree(node['left'], depth+1)
        print_tree(node['right'], depth+1)
    else:
        print('%s[%s]' % ((depth*' ', node)))

In [47]:
dataset = [[5.1, 3.5, 1.4, 0.2, 0],
        [4.9, 3. , 1.4, 0.2, 0],
        [4.7, 3.2, 1.3, 0.2, 0],
        [4.6, 3.1, 1.5, 0.2, 0],
        [5. , 3.6, 1.4, 0.2, 0],
        [5.4, 3.9, 1.7, 0.4, 0],
        [4.6, 3.4, 1.4, 0.3, 0],
        [5. , 3.4, 1.5, 0.2, 0],
        [4.4, 2.9, 1.4, 0.2, 0],
        [4.9, 3.1, 1.5, 0.1, 0],
        [5.4, 3.7, 1.5, 0.2, 0],
        [4.8, 3.4, 1.6, 0.2, 0],
        [4.8, 3. , 1.4, 0.1, 0],
        [4.3, 3. , 1.1, 0.1, 0],
        [5.8, 4. , 1.2, 0.2, 0],
        [5.7, 4.4, 1.5, 0.4, 0],
        [5.4, 3.9, 1.3, 0.4, 0],
        [5.1, 3.5, 1.4, 0.3, 0],
        [5.7, 3.8, 1.7, 0.3, 0],
        [5.1, 3.8, 1.5, 0.3, 0],
        [5.4, 3.4, 1.7, 0.2, 0],
        [5.1, 3.7, 1.5, 0.4, 0],
        [4.6, 3.6, 1. , 0.2, 0],
        [5.1, 3.3, 1.7, 0.5, 0],
        [4.8, 3.4, 1.9, 0.2, 0],
        [5. , 3. , 1.6, 0.2, 0],
        [5. , 3.4, 1.6, 0.4, 0],
        [5.2, 3.5, 1.5, 0.2, 0],
        [5.2, 3.4, 1.4, 0.2, 0],
        [4.7, 3.2, 1.6, 0.2, 0],
        [4.8, 3.1, 1.6, 0.2, 0],
        [5.4, 3.4, 1.5, 0.4, 0],
        [5.2, 4.1, 1.5, 0.1, 0],
        [5.5, 4.2, 1.4, 0.2, 0],
        [4.9, 3.1, 1.5, 0.2, 0],
        [5. , 3.2, 1.2, 0.2, 0],
        [5.5, 3.5, 1.3, 0.2, 0],
        [4.9, 3.6, 1.4, 0.1, 0],
        [4.4, 3. , 1.3, 0.2, 0],
        [5.1, 3.4, 1.5, 0.2, 0],
        [5. , 3.5, 1.3, 0.3, 0],
        [4.5, 2.3, 1.3, 0.3, 0],
        [4.4, 3.2, 1.3, 0.2, 0],
        [5. , 3.5, 1.6, 0.6, 0],
        [5.1, 3.8, 1.9, 0.4, 0],
        [4.8, 3. , 1.4, 0.3, 0],
        [5.1, 3.8, 1.6, 0.2, 0],
        [4.6, 3.2, 1.4, 0.2, 0],
        [5.3, 3.7, 1.5, 0.2, 0],
        [5. , 3.3, 1.4, 0.2, 1],
        [7. , 3.2, 4.7, 1.4, 1],
        [6.4, 3.2, 4.5, 1.5, 1],
        [6.9, 3.1, 4.9, 1.5, 1],
        [5.5, 2.3, 4. , 1.3, 1],
        [6.5, 2.8, 4.6, 1.5, 1],
        [5.7, 2.8, 4.5, 1.3, 1],
        [6.3, 3.3, 4.7, 1.6, 1],
        [4.9, 2.4, 3.3, 1. , 1],
        [6.6, 2.9, 4.6, 1.3, 1],
        [5.2, 2.7, 3.9, 1.4, 1],
        [5. , 2. , 3.5, 1., 1 ],
        [5.9, 3. , 4.2, 1.5, 1],
        [6. , 2.2, 4. , 1. , 1],
        [6.1, 2.9, 4.7, 1.4, 1],
        [5.6, 2.9, 3.6, 1.3, 1],
        [6.7, 3.1, 4.4, 1.4, 1],
        [5.6, 3. , 4.5, 1.5, 1],
        [5.8, 2.7, 4.1, 1. , 1],
        [6.2, 2.2, 4.5, 1.5, 1],
        [5.6, 2.5, 3.9, 1.1, 1],
        [5.9, 3.2, 4.8, 1.8, 1],
        [6.1, 2.8, 4. , 1.3, 1],
        [6.3, 2.5, 4.9, 1.5, 1],
        [6.1, 2.8, 4.7, 1.2, 1],
        [6.4, 2.9, 4.3, 1.3, 1],
        [6.6, 3. , 4.4, 1.4, 1],
        [6.8, 2.8, 4.8, 1.4, 1],
        [6.7, 3. , 5. , 1.7, 1],
        [6. , 2.9, 4.5, 1.5, 1],
        [5.7, 2.6, 3.5, 1. , 1],
        [5.5, 2.4, 3.8, 1.1, 1],
        [5.5, 2.4, 3.7, 1. , 1],
        [5.8, 2.7, 3.9, 1.2, 1],
        [6. , 2.7, 5.1, 1.6, 1],
        [5.4, 3. , 4.5, 1.5, 1],
        [6. , 3.4, 4.5, 1.6, 1],
        [6.7, 3.1, 4.7, 1.5, 1],
        [6.3, 2.3, 4.4, 1.3, 1],
        [5.6, 3. , 4.1, 1.3, 1],
        [5.5, 2.5, 4. , 1.3, 1],
        [5.5, 2.6, 4.4, 1.2, 1],
        [6.1, 3. , 4.6, 1.4, 1],
        [5.8, 2.6, 4. , 1.2, 1],
        [5. , 2.3, 3.3, 1. , 1],
        [5.6, 2.7, 4.2, 1.3, 1],
        [5.7, 3. , 4.2, 1.2, 1],
        [5.7, 2.9, 4.2, 1.3, 1],
        [6.2, 2.9, 4.3, 1.3, 1],
        [5.1, 2.5, 3. , 1.1, 1],
        [5.7, 2.8, 4.1, 1.3, 2],
        [6.3, 3.3, 6. , 2.5, 2],
        [5.8, 2.7, 5.1, 1.9, 2],
        [7.1, 3. , 5.9, 2.1, 2],
        [6.3, 2.9, 5.6, 1.8, 2],
        [6.5, 3. , 5.8, 2.2, 2],
        [7.6, 3. , 6.6, 2.1, 2],
        [4.9, 2.5, 4.5, 1.7, 2],
        [7.3, 2.9, 6.3, 1.8, 2],
        [6.7, 2.5, 5.8, 1.8, 2],
        [7.2, 3.6, 6.1, 2.5, 2],
        [6.5, 3.2, 5.1, 2. , 2],
        [6.4, 2.7, 5.3, 1.9, 2],
        [6.8, 3. , 5.5, 2.1, 2],
        [5.7, 2.5, 5. , 2. , 2],
        [5.8, 2.8, 5.1, 2.4, 2],
        [6.4, 3.2, 5.3, 2.3, 2],
        [6.5, 3. , 5.5, 1.8, 2],
        [7.7, 3.8, 6.7, 2.2, 2],
        [7.7, 2.6, 6.9, 2.3, 2],
        [6. , 2.2, 5. , 1.5, 2],
        [6.9, 3.2, 5.7, 2.3, 2],
        [5.6, 2.8, 4.9, 2. , 2],
        [7.7, 2.8, 6.7, 2. , 2],
        [6.3, 2.7, 4.9, 1.8, 2],
        [6.7, 3.3, 5.7, 2.1, 2],
        [7.2, 3.2, 6. , 1.8, 2],
        [6.2, 2.8, 4.8, 1.8, 2],
        [6.1, 3. , 4.9, 1.8, 2],
        [6.4, 2.8, 5.6, 2.1, 2],
        [7.2, 3. , 5.8, 1.6, 2],
        [7.4, 2.8, 6.1, 1.9, 2],
        [7.9, 3.8, 6.4, 2. , 2],
        [6.4, 2.8, 5.6, 2.2, 2],
        [6.3, 2.8, 5.1, 1.5, 2],
        [6.1, 2.6, 5.6, 1.4, 2],
        [7.7, 3. , 6.1, 2.3, 2],
        [6.3, 3.4, 5.6, 2.4, 2],
        [6.4, 3.1, 5.5, 1.8, 2],
        [6. , 3. , 4.8, 1.8, 2],
        [6.9, 3.1, 5.4, 2.1, 2],
        [6.7, 3.1, 5.6, 2.4, 2],
        [6.9, 3.1, 5.1, 2.3, 2],
        [5.8, 2.7, 5.1, 1.9, 2],
        [6.8, 3.2, 5.9, 2.3, 2],
        [6.7, 3.3, 5.7, 2.5, 2],
        [6.7, 3. , 5.2, 2.3, 2],
        [6.3, 2.5, 5. , 1.9, 2],
        [6.5, 3. , 5.2, 2. , 2],
        [6.2, 3.4, 5.4, 2.3, 2],
        [5.9, 3. , 5.1, 1.8, 2]]
tree = build_tree(dataset)
print_tree(tree)

[X3 < 3.000]
 [X2 < 3.400]
  [X2 < 3.300]
   [0]
   [X1 < 5.100]
    [1]
    [0]
  [0]
 [X4 < 1.800]
  [X3 < 5.000]
   [X4 < 1.700]
    [X3 < 4.200]
     [X3 < 4.100]
      [1]
      [X1 < 5.800]
       [1]
       [1]
     [1]
    [2]
   [X4 < 1.600]
    [2]
    [X1 < 7.200]
     [1]
     [2]
  [X3 < 4.900]
   [2]
   [2]
