In [28]:
'''
Program implementing a basic decision tree in python
It was created using the tutorial article found on:
https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
with some personal modifications
'''

'\nProgram implementing a basic decision tree in python\nIt was created using the tutorial article found on:\nhttps://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/\nwith some personal modifications\n'

In [29]:
'''
The gini index will score the value of a split and whether or not is is optimal
Ideally we are looking for a 50:50 split which will result in a 0.00 gini score
'''

#Calculate the Gini index
def gini_index(groups, classes):
    n_instances = float(sum([len(group) for group in groups]))
    gini = 0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        score = 0.0
        for class_val in classes:
            p = [row[-1] for row in group].count(class_val) / size
            score += p**2
        
        gini += (1.0 - score) * (size/n_instances)
        
    return gini

In [30]:
'''
Now we implement a splitting fuction
'''
#Returns a tuple comprised of two lists with datapoints
#Here the index equal to the split value goes to the right
def split_set(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

In [31]:
'''
This function evalueates each attribute except the value for split potential
'''
def find_split(dataset):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):
        for row in dataset:
            groups = split_set(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    
    return {'index':b_index, 'value':b_value, 'groups':b_groups}

In [32]:
'''
We introduce a random dataset to test out the algorithm
'''
import random
dataset = []
datapoints = 15
classes = 2
attributes = 3

for x in range(datapoints):
    row = []
    for j in range(attributes):
        row.append(random.uniform(0.5, 10.0))
        
    #Adding class value
    row.append(random.randint(0, classes-1))
    dataset.append(row)

#Below is the generated dataset
#print(dataset)
dataset

[[3.428984294533204, 7.385083710483256, 5.417881167941312, 0],
 [5.616638138199577, 6.554369248324911, 2.6846458607758312, 1],
 [6.172563587534302, 8.362288111406883, 5.632005019980435, 0],
 [1.8230502672629383, 6.344209080416231, 8.843893986598468, 0],
 [9.15487200436044, 7.4896112524251, 9.009086770033141, 0],
 [2.7834909038819564, 4.0460027131722995, 9.377116692427581, 0],
 [3.7752896031615406, 9.890420015225411, 5.898878869994587, 1],
 [5.12735876526717, 4.313469728779825, 7.139712916904053, 1],
 [7.861306884290304, 8.631127235615136, 5.364843049028224, 1],
 [6.469457527556306, 6.540916474836575, 1.3843074930067685, 0],
 [8.19385794351623, 9.950020442882531, 3.618234954023822, 1],
 [4.910023923146414, 0.9246171523618356, 1.8468885669330375, 1],
 [5.539796488401742, 5.068874023995933, 2.060027944965388, 0],
 [1.2653549176677839, 5.903478092121263, 4.422688186647855, 0],
 [6.035117321743658, 7.0377429608677975, 4.413442382208564, 0]]

In [33]:
'''
We perform the split
Our algorithm will parse every 'axis' of the dataset to find the best solution
'''
split = find_split(dataset)
print('Split: [X%d < %.3f]' % ((split['index']+1), split['value']))

Split: [X2 < 8.631]


In [34]:
'''
Now we move on to creating the tree
We will look out for two important factors:
A. Respect the maximum tree depth to not overfit the data
B. Set a minimum amount of training patterns for each node, once at or below, we will stop expanding it
'''

'\nNow we move on to creating the tree\nWe will look out for two important factors:\nA. Respect the maximum tree depth to not overfit the data\nB. Set a minimum amount of training patterns for each node, once at or below, we will stop expanding it\n'

In [35]:
'''
Since we have to do some rounding up on terminal nodes (the last nodes in a branch)
This function approximates the common class value for a list of rows
'''
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

In [44]:
'''
We use recursive splitting to populate our nodes
'''
def rec_split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del(node['groups'])
    
    #If no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left+right)
        return
    
    #If max depth reached
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    
    #process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = find_split(left)
        rec_split(node['left'], max_depth, min_size, depth+1)
        
    #process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = find_split(right)
        rec_split(node['right'], max_depth, min_size, depth+1)

In [45]:
'''
Build and print tree method
'''
def build_tree(train, max_depth, min_size):
    root = find_split(train)
    rec_split(root, max_depth, min_size, 1)
    return root

def print_tree(node, depth=0):
    if isinstance(node, dict):
        print('%s[X%d < %.3f]' % ((depth*' ', (node['index']+1), node['value'])))
        print_tree(node['left'], depth+1)
    else:
        print('%s[%s]' % ((depth*' ', node)))

In [46]:
'''
Now we can build a decision tree and get better splits
'''
tree = build_tree(dataset, 5, 1)
print_tree(tree)

[X2 < 8.631]
 [X2 < 5.069]
  [X1 < 4.910]
   [0]
   [X1 < 5.127]
    [1]
    [1]
  [X3 < 4.413]
   [X2 < 6.554]
    [X1 < 6.469]
     [0]
     [0]
    [1]
   [X1 < 3.429]
    [X1 < 1.823]
     [0]
     [0]
    [X1 < 3.429]
     [0]
     [0]
 [X1 < 3.775]
  [1]
  [1]
