In [1]:
#based on http://kldavenport.com/pure-python-decision-trees/
from math import log

In [2]:
def get_result_assignments(data):
    """last column is the class assignment"""
    result = {}
    for row in data:
        if row[-1] in result:
            result[row[-1]] += 1
        else:
            result[row[-1]] = 1
    return result

def get_entropy(data):
    """input is a list of class assignments in a group"""
    assignments = get_result_assignments(data)
    entropy = 0
    for number_assigned in assignments.values():
        entropy -= (number_assigned / len(data)) * log((number_assigned / len(data)))
    return entropy

In [58]:
class DecisionTree:
    def __init__(self):
        self.node = DecisionTreeNode()
        
    def create_tree(self, data):
        self.node = self.train(data)
        
    def train(self, data):
#         print('---- input ----')
#         for row in data:
#             print(row)
        print('----       ----')
        if len(data) == 0: return DecisionTreeNode()
        current_entropy = get_entropy(data)
        best_info_gain = 0
        separating_var = None
        separating_value = None
        best_left = None
        best_right = None
        for variable in range(len(data[0])-1):
            values = set([row[variable] for row in data])
            for value in values:
                left, right = self.try_split(variable, value, data)
                p = len(left) / len(data)
                info_gain = current_entropy - p*get_entropy(left) - (1-p)*get_entropy(right)
                if info_gain > best_info_gain and len(left) > 0 and len(right) > 0:
                    best_info_gain = info_gain
                    separating_var = variable
                    separating_value = value
                    best_left = left
                    best_right = right
        if best_info_gain > 0:
#             print('best info gain:' + str(best_info_gain))
            print('separating variable:' + str(separating_var))
            print('separating value:' + str(separating_value))
#             print('left')
#             for row in best_left:
#                 print(row)
#             print('right')
#             for row in best_right:
#                 print(row)
            left = self.train(best_left)
            right = self.train(best_right)
            return DecisionTreeNode(separating_var, separating_value, left=left, right=right)
        else:
            return DecisionTreeNode(result=get_result_assignments(data))
        
    def try_split(self, variable, split_value, data):
        left = []
        right = []
        for row in data:
            if isinstance(row[variable], int) or isinstance(row[variable], float):
                if row[variable] < split_value: 
                    left.append(row)
                else: 
                    right.append(row)
            else:
                if not row[variable] == split_value: 
                    left.append(row)
                else: 
                    right.append(row)
        return left, right
       
class DecisionTreeNode:
    """Trees are made up recursively using these nodes"""
    def __init__(self, variable=None, value=None, result=None, left=None, right=None):
        self.variable = variable #the variable used to split the tree
        self.value = value #the value that splits the tree
        self.left = left #branch
        self.right = right #branch
        self.result = result #data assignments (only in leaf node)
        self.indent = ''
        
    def print_tree(self, indent=''):
        if self.result:
            print(indent + str(self.result))
        else:
            print(indent + 'Variable ' + str(self.variable)+' : '+str(self.value)+'?')
            print(indent + 'False')
            self.left.print_tree(indent+'   ')
            print(indent + 'True')
            self.right.print_tree(indent+'   ')

In [59]:
my_data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['reddit','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['reddit','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['reddit','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]

In [60]:
dt = DecisionTree()
dt.create_tree(my_data)
print()
dt.node.print_tree()

----       ----
separating variable:0
separating value:google
----       ----
separating variable:0
separating value:slashdot
----       ----
separating variable:2
separating value:no
----       ----
----       ----
separating variable:3
separating value:21
----       ----
----       ----
----       ----
----       ----
separating variable:3
separating value:21
----       ----
separating variable:2
separating value:no
----       ----
----       ----
----       ----

Variable 0 : google?
False
   Variable 0 : slashdot?
   False
      Variable 2 : no?
      False
         {'Basic': 4}
      True
         Variable 3 : 21?
         False
            {'None': 3}
         True
            {'Basic': 1}
   True
      {'None': 3}
True
   Variable 3 : 21?
   False
      Variable 2 : no?
      False
         {'Basic': 1}
      True
         {'None': 1}
   True
      {'Premium': 3}
