In [1]:
from sklearn import datasets
# Load the iris dataset bundled with scikit-learn.
# NOTE(review): `iris` is not referenced by any of the cells visible in this
# notebook section -- confirm whether it is still needed.
iris = datasets.load_iris()



In [2]:
def uniqueCounts(rows):
    """Tally how many rows carry each class label.

    The label of a row is taken from its final position.

    Args:
        rows: Iterable of indexable rows that support ``len()``.

    Returns:
        A plain dict mapping each distinct label to its number of
        occurrences, keyed in order of first appearance.
    """
    tally = {}
    for record in rows:
        label = record[len(record) - 1]
        # Start unseen labels at zero, then count this occurrence.
        tally[label] = tally.get(label, 0) + 1
    return tally

In [3]:
def entropy(rows):
    """Base-2 Shannon entropy of the class labels found in ``rows``.

    Labels are counted with ``uniqueCounts`` (last element of each row).

    Args:
        rows: Sized collection of rows.

    Returns:
        float: ``-sum(p * log2(p))`` over the label frequencies ``p``;
        ``0.0`` when no labels are present.
    """
    from math import log

    label_counts = uniqueCounts(rows)
    total = 0.0
    for count in label_counts.values():
        share = float(count) / len(rows)
        # log(x)/log(2) is the base-2 logarithm of x.
        total = total - share * (log(share) / log(2))
    return total

In [5]:
class decisionNode:
    """A single node of a binary decision tree.

    Attributes:
        col: Index of the column tested at this node (default ``-1``).
        value: Value the column is compared against.
        results: Label counts stored on a leaf; ``None`` otherwise.
        tb: Subtree followed when the test succeeds.
        fb: Subtree followed when the test fails.
    """

    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        # The test performed at this node.
        self.col, self.value = col, value
        # Outcome counts; only populated for leaves.
        self.results = results
        # Children for the "true" and "false" outcomes of the test.
        self.tb, self.fb = tb, fb


In [15]:
def gini_impurity(rows):
    """Gini impurity of the class labels found in ``rows``.

    Labels are counted with ``uniqueCounts`` (last element of each row).
    The impurity is ``sum(p * (1 - p))`` over the label frequencies ``p``:
    0 when every row has the same label, larger as labels get more mixed.

    Args:
        rows: Sized collection of rows.

    Returns:
        float: The impurity; ``0.0`` for an empty collection.
    """
    total = len(rows)
    counts = uniqueCounts(rows)
    imp = 0.0
    for count in counts.values():
        p = float(count) / total
        imp += p * (1 - p)
    # The accumulated value was previously never returned, so callers
    # always received None.
    return imp

In [8]:
def divide_set(rows, column, value):
    """Partition ``rows`` in two according to one column.

    When ``value`` is an ``int`` or ``float`` a row matches if
    ``row[column] >= value``; for any other type it matches if
    ``row[column] == value``.

    Args:
        rows: Rows to partition.
        column: Index of the column to test.
        value: Threshold (numeric) or exact value (otherwise) to test against.

    Returns:
        tuple: ``(matching_rows, non_matching_rows)``, both lists in the
        original row order.
    """
    numeric = isinstance(value, (int, float))

    def matches(row):
        cell = row[column]
        return cell >= value if numeric else cell == value

    def rejects(row):
        return not matches(row)

    # Two independent passes over ``rows``, one per side of the split.
    matching = list(filter(matches, rows))
    non_matching = list(filter(rejects, rows))
    return matching, non_matching

In [10]:
def build_tree(rows, scoreFunction = entropy):
    """Recursively build a binary decision tree from labelled rows.

    Every distinct value of every feature column (all columns except the
    last, which holds the label) is tried as a split point via
    ``divide_set``. The split with the largest reduction in
    ``scoreFunction`` is kept and both halves are built recursively with
    the same ``scoreFunction``. When no split improves the score the node
    becomes a leaf holding the label counts of ``rows``.

    Args:
        rows: List of rows; the last element of each row is its label.
        scoreFunction: Callable taking a list of rows and returning an
            impurity score (lower means purer). Defaults to ``entropy``.

    Returns:
        decisionNode: The root of the (sub)tree. For empty ``rows`` this is
        a default-constructed node (``results`` is ``None`` and it has no
        branches).
    """
    if len(rows) == 0:
        return decisionNode()
    current_score = scoreFunction(rows)

    best_gain = 0.0
    best_criteria = None
    best_sets = None

    column_count = len(rows[0]) - 1
    for col in range(column_count):
        # Distinct values of this column. A dict (not a set) keeps them in
        # first-seen order, so ties between equally good splits resolve
        # the same way on every run.
        candidate_values = dict.fromkeys(row[col] for row in rows)

        for value in candidate_values:
            set1, set2 = divide_set(rows, col, value)
            # A split that leaves one side empty separates nothing; skip it
            # before scoring so scoreFunction never sees an empty set.
            if len(set1) == 0 or len(set2) == 0:
                continue
            p = float(len(set1)) / len(rows)
            gain = current_score - p*scoreFunction(set1) - (1-p) * scoreFunction(set2)
            if gain > best_gain:
                best_gain = gain
                best_criteria = (col, value)
                best_sets = (set1, set2)

    if best_gain > 0:
        # Propagate scoreFunction so the whole tree uses the caller's
        # criterion rather than silently falling back to the default.
        trueBranch = build_tree(best_sets[0], scoreFunction)
        falseBranch = build_tree(best_sets[1], scoreFunction)
        return decisionNode(col = best_criteria[0], value = best_criteria[1],
                            tb = trueBranch, fb = falseBranch)
    return decisionNode(results = uniqueCounts(rows))


In [27]:
def printTree(tree, indent = ''):
    """Print a text rendering of a decision tree to standard output.

    A leaf prints its ``results``. Any other node prints
    ``<col>:<value>?`` followed by its true branch (prefixed ``T->``) and
    its false branch (prefixed ``F->``).

    Args:
        tree: Node exposing ``results``, ``col``, ``value``, ``tb``, ``fb``.
        indent: Prefix written before the branch markers of this node.
    """
    # Leaf: show the stored label counts and stop.
    if tree.results is not None:
        print(str(tree.results))
        return

    print(str(tree.col) + ":" + str(tree.value) + "?")

    # (marker, attribute holding the child, extra indent for that child).
    # The true branch deepens the indent by five spaces, the false branch
    # by six.
    branches = (
        ("T->", "tb", "     "),
        ("F->", "fb", "      "),
    )
    for marker, attr, padding in branches:
        print(indent + marker, end="")
        printTree(getattr(tree, attr), indent + padding)

In [28]:
def classify(observation, tree):
    """Walk a decision tree and return the leaf reached by ``observation``.

    At each non-leaf node the observation's value in column ``col`` is
    tested against the node's ``value``: ``>=`` when the observed value is
    an ``int`` or ``float``, ``==`` otherwise. A passing test follows
    ``tb``; a failing one follows ``fb``.

    Args:
        observation: Indexable sequence of feature values.
        tree: Root node exposing ``results``, ``col``, ``value``, ``tb``, ``fb``.

    Returns:
        The ``results`` stored on the leaf that was reached.
    """
    node = tree
    # Descend until a node carrying results (a leaf) is found.
    while node.results is None:
        observed = observation[node.col]
        if isinstance(observed, (int, float)):
            passed = observed >= node.value
        else:
            passed = observed == node.value
        node = node.tb if passed else node.fb
    return node.results

In [31]:
# Sample data set. Each row is a list whose last element is the class label
# (the counting and tree-building helpers read the label from the final
# position); the preceding four elements are the features.
# NOTE(review): the columns look like referrer, country, a yes/no flag, a
# numeric count and a subscription tier -- confirm against the data's origin.
my_data=[['slashdot','USA','yes',18,'None'],
        ['google','France','yes',23,'Premium'],
        ['digg','USA','yes',24,'Basic'],
        ['kiwitobes','France','yes',23,'Basic'],
        ['google','UK','no',21,'Premium'],
        ['(direct)','New Zealand','no',12,'None'],
        ['(direct)','UK','no',21,'Basic'],
        ['google','USA','no',24,'Premium'],
        ['slashdot','France','yes',19,'None'],
        ['digg','USA','no',18,'None'],
        ['google','UK','no',18,'None'],
        ['kiwitobes','UK','no',19,'None'],
        ['digg','New Zealand','yes',12,'Basic'],
        ['slashdot','UK','no',21,'None'],
        ['google','UK','yes',18,'Basic'],
        ['kiwitobes','France','yes',19,'Basic']]


# Split the rows on column 2 being equal to 'yes'; the returned pair of lists
# is not assigned or printed here.
divide_set(my_data, 2, 'yes')

# Impurity of the full data set; the value is not assigned or printed here.
gini_impurity(my_data)

# Build the tree with the default score function (entropy).
tree = build_tree(my_data)

printTree(tree)
# Classify an unlabelled observation: four feature values, no label element.
classify(['google', 'France', 'yes', 23], tree)

0:google?
T->3:21?
     T->{'Premium': 3}
     F->2:no?
           T->{'None': 1}
           F->{'Basic': 1}
F->0:slashdot?
      T->{'None': 3}
      F->2:yes?
            T->{'Basic': 4}
            F->3:21?
                  T->{'Basic': 1}
                  F->{'None': 3}


{'Premium': 3}