# Decision tree implementation


Import the libraries and define the `DecisionTree` node class: a binary decision tree in which every internal node has a true branch and a false branch.

In [86]:

import numpy as np
import matplotlib.pyplot as plt
import collections
import csv
from collections import defaultdict

class DecisionTree:
    """One node of a binary decision tree.

    Internal nodes test column ``col`` against ``value`` and route a row to
    ``trueBranch`` or ``falseBranch``. Leaf nodes carry ``results``, a dict
    mapping class label -> number of training rows that reached the leaf.
    A node is treated as a leaf whenever ``results`` is not None.
    """

    def __init__(self, col=-1, value=None, trueBranch=None, falseBranch=None, results=None):
        # Split test: index of the tested column and its threshold/category.
        self.col, self.value = col, value
        # Child subtrees for rows that pass / fail the test.
        self.trueBranch, self.falseBranch = trueBranch, falseBranch
        # Label -> count mapping; set on leaves only.
        self.results = results
        
       
        
        
        


Define a `shuffle` function that shuffles the data with a fixed random seed and splits it into a 70% training set and a 30% test set.

In [87]:
def shuffle(data, train_fraction=0.7, seed=2):
    """Shuffle ``data`` in place and split it into training and test parts.

    data: a mutable sequence of rows (e.g. a list); it is reordered in place.
    train_fraction: share of the rows that go to the training part
        (default 0.7, i.e. a 70% / 30% split).
    seed: seed passed to ``np.random.seed`` so the split is reproducible
        (default 2). Note that this reseeds NumPy's global random generator
        as a side effect.

    Returns (train_data, test_data), two slices of the shuffled ``data``.
    """
    np.random.seed(seed)
    np.random.shuffle(data)
    # Compute the cut point once so both slices agree on it.
    split = int(train_fraction * len(data))
    return data[:split], data[split:]

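Define the `divideSet` function, which splits the rows on one column into two lists. The first list holds the rows whose value is `>=` the split value (numeric) or equal to it (string); the second holds all remaining rows.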

In [88]:

def divideSet(rows, column, value):
    """Split ``rows`` into two lists by testing each row's ``column`` entry.

    If ``value`` is an int or float the test is ``row[column] >= value``;
    otherwise it is ``row[column] == value``. Returns (matching, others),
    each keeping the input order.
    """
    numeric = isinstance(value, (int, float))  # threshold for numbers, equality for strings

    def passes(row):
        cell = row[column]
        return cell >= value if numeric else cell == value

    matching = [row for row in rows if passes(row)]
    others = [row for row in rows if not passes(row)]
    return (matching, others)

In [89]:
def uniqueCounts(rows):
    """Count how often each label (the last element of a row) occurs.

    Returns a dict mapping label -> number of rows; an empty ``rows``
    gives {}.
    """
    tally = {}
    for row in rows:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    return tally

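Define the entropy splitting criterion, which measures the impurity of the class labels in a set of rows. The information gain of a split is the parent's entropy minus the size-weighted entropies of the two child sets.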

In [90]:

def entropy(rows):
    """Return the Shannon entropy, in bits, of the labels in ``rows``.

    Labels are the last element of each row (tallied by uniqueCounts).
    An empty ``rows`` yields 0.0.
    """
    from math import log
    counts = uniqueCounts(rows)
    size = len(rows)
    bits = 0.0
    for label in counts:
        share = float(counts[label]) / size
        # log base 2 expressed as a ratio of natural logs
        bits -= share * (log(share) / log(2))
    return bits


Define the Gini splitting criterion, which computes the Gini impurity of the class labels in a set of rows.

In [91]:

def gini(rows):
    """Return the Gini impurity of the labels in ``rows``.

    Computed as the sum of p_i * p_j over every ordered pair of distinct
    labels, where p is a label's share of the rows (mathematically equal to
    1 - sum(p_i ** 2)). An empty ``rows`` yields 0.0.
    """
    n = len(rows)
    tally = uniqueCounts(rows)
    impurity = 0.0
    for first in tally:
        pFirst = float(tally[first]) / n
        for second in tally:
            if second != first:
                impurity += pFirst * (float(tally[second]) / n)
    return impurity

In [92]:
def variance(rows):
    """Return the population variance of the last column of ``rows``.

    Each value is converted with float(); an empty ``rows`` returns 0.
    """
    if len(rows) == 0:
        return 0
    values = [float(r[-1]) for r in rows]
    count = len(values)
    mean = sum(values) / count
    squared = [(x - mean) ** 2 for x in values]
    return sum(squared) / count

Define the `growDecisionTreeFrom` function, which grows the binary decision tree using entropy as the evaluation function.
For every column and every value in it, compute the information gain of the corresponding split. The split with the highest gain becomes the node's test, and the true and false branches are grown recursively from the two resulting subsets. When no split has a positive gain, the node becomes a leaf holding the class counts.

In [93]:
def growDecisionTreeFrom(rows, evaluationFunction=entropy):
    """Grow and return a binary decision tree.

    rows: list of rows; the last column is the class label and every other
        column is a candidate splitting attribute.
    evaluationFunction: impurity measure used to score splits (entropy or
        gini). It is passed to the recursive calls so every level of the
        tree is grown with the same criterion.

    Returns a DecisionTree: an empty node when ``rows`` is empty, an
    internal node for the split with the largest positive gain, or a leaf
    holding the label counts when no split has positive gain.
    """
    if len(rows) == 0:
        return DecisionTree()
    currentScore = evaluationFunction(rows)
    bestGain = 0.0
    bestAttribute = None
    bestSets = None
    columnCount = len(rows[0]) - 1  # last column is the result/target column
    for col in range(columnCount):
        seen = set()
        for row in rows:
            value = row[col]
            # A repeated value yields the same split and gain; evaluate each
            # distinct value once, in first-occurrence order, so ties are
            # still won by the earliest candidate.
            if value in seen:
                continue
            seen.add(value)
            (set1, set2) = divideSet(rows, col, value)
            if not set1 or not set2:
                continue  # a one-sided split can never be chosen
            p = float(len(set1)) / len(rows)
            gain = currentScore - p * evaluationFunction(set1) - (1 - p) * evaluationFunction(set2)
            if gain > bestGain:
                bestGain = gain
                bestAttribute = (col, value)
                bestSets = (set1, set2)
    if bestGain > 0:
        # Pass the criterion down so subtrees are scored like this node.
        trueBranch = growDecisionTreeFrom(bestSets[0], evaluationFunction)
        falseBranch = growDecisionTreeFrom(bestSets[1], evaluationFunction)
        return DecisionTree(col=bestAttribute[0], value=bestAttribute[1],
                            trueBranch=trueBranch, falseBranch=falseBranch)
    return DecisionTree(results=uniqueCounts(rows))

Define the `growDecisionTreeFrom_gini` function, which grows the binary decision tree using Gini impurity as the evaluation function. For every column and every value in it, compute the reduction in Gini impurity of the corresponding split. The split with the largest reduction becomes the node's test, and the true and false branches are grown recursively from the two resulting subsets.

In [94]:
def growDecisionTreeFrom_gini(rows, evaluationFunction=gini):
    """Grow and return a binary decision tree (Gini impurity by default).

    rows: list of rows; the last column is the class label and every other
        column is a candidate splitting attribute.
    evaluationFunction: impurity measure used to score splits (gini or
        entropy). It is passed to the recursive calls so every level of the
        tree is grown with the same criterion.

    Returns a DecisionTree: an empty node when ``rows`` is empty, an
    internal node for the split with the largest positive gain, or a leaf
    holding the label counts when no split has positive gain.
    """
    if len(rows) == 0:
        return DecisionTree()
    currentScore = evaluationFunction(rows)
    bestGain = 0.0
    bestAttribute = None
    bestSets = None
    columnCount = len(rows[0]) - 1  # last column is the result/target column
    for col in range(columnCount):
        seen = set()
        for row in rows:
            value = row[col]
            # A repeated value yields the same split and gain; evaluate each
            # distinct value once, in first-occurrence order, so ties are
            # still won by the earliest candidate.
            if value in seen:
                continue
            seen.add(value)
            (set1, set2) = divideSet(rows, col, value)
            if not set1 or not set2:
                continue  # a one-sided split can never be chosen
            p = float(len(set1)) / len(rows)
            gain = currentScore - p * evaluationFunction(set1) - (1 - p) * evaluationFunction(set2)
            if gain > bestGain:
                bestGain = gain
                bestAttribute = (col, value)
                bestSets = (set1, set2)
    if bestGain > 0:
        # Pass the criterion down so subtrees are scored like this node.
        trueBranch = growDecisionTreeFrom_gini(bestSets[0], evaluationFunction)
        falseBranch = growDecisionTreeFrom_gini(bestSets[1], evaluationFunction)
        return DecisionTree(col=bestAttribute[0], value=bestAttribute[1],
                            trueBranch=trueBranch, falseBranch=falseBranch)
    return DecisionTree(results=uniqueCounts(rows))

Define the `prune` function, which takes a minimum gain as a parameter and prunes the tree using entropy as the evaluation function. Two sibling leaves are merged into their parent whenever the split between them gains less than that minimum.

In [95]:
def prune(tree, minGain, evaluationFunction=entropy, notify=True):
    """Prune ``tree`` in place by merging sibling leaves with too little gain.

    The tree is processed bottom-up. Whenever both children of a node are
    leaves and the split between them gains less than ``minGain``
    (measured with ``evaluationFunction``, entropy or gini), the children
    are removed and the node becomes a leaf holding their combined counts.

    notify: print a message for every merged branch.
    A leaf, or an empty tree grown from no rows, is left unchanged.
    """
    # Nothing to prune without two branches (previously an AttributeError).
    if tree.results is not None or tree.trueBranch is None or tree.falseBranch is None:
        return
    # Recurse into internal children first so merges can cascade upwards.
    if tree.trueBranch.results is None:
        prune(tree.trueBranch, minGain, evaluationFunction, notify)
    if tree.falseBranch.results is None:
        prune(tree.falseBranch, minGain, evaluationFunction, notify)
    # Merge the two children only if both are (now) leaves.
    if tree.trueBranch.results is not None and tree.falseBranch.results is not None:
        # Rebuild one-element rows [label] from the leaf counts; the impurity
        # functions only look at the last element of each row.
        tb, fb = [], []
        for v, c in tree.trueBranch.results.items():
            tb += [[v]] * c
        for v, c in tree.falseBranch.results.items():
            fb += [[v]] * c
        merged = tb + fb
        p = float(len(tb)) / len(merged)
        delta = evaluationFunction(merged) - p * evaluationFunction(tb) - (1 - p) * evaluationFunction(fb)
        if delta < minGain:
            if notify:
                print('A branch was pruned: gain = %f' % delta)
            tree.trueBranch, tree.falseBranch = None, None
            tree.results = uniqueCounts(merged)

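Define the `prune_gini` function, which takes a minimum gain as a parameter and prunes the tree using Gini impurity as the evaluation function. Two sibling leaves are merged into their parent whenever the split between them gains less than that minimum.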

In [96]:
def prune_gini(tree, minGain, evaluationFunction=gini, notify=True):
    """Prune ``tree`` in place (Gini impurity by default).

    The tree is processed bottom-up. Whenever both children of a node are
    leaves and the split between them gains less than ``minGain``
    (measured with ``evaluationFunction``, gini or entropy), the children
    are removed and the node becomes a leaf holding their combined counts.

    notify: print a message for every merged branch.
    A leaf, or an empty tree grown from no rows, is left unchanged.
    """
    # Nothing to prune without two branches (previously an AttributeError).
    if tree.results is not None or tree.trueBranch is None or tree.falseBranch is None:
        return
    # Recurse into internal children first so merges can cascade upwards.
    if tree.trueBranch.results is None:
        prune_gini(tree.trueBranch, minGain, evaluationFunction, notify)
    if tree.falseBranch.results is None:
        prune_gini(tree.falseBranch, minGain, evaluationFunction, notify)
    # Merge the two children only if both are (now) leaves.
    if tree.trueBranch.results is not None and tree.falseBranch.results is not None:
        # Rebuild one-element rows [label] from the leaf counts; the impurity
        # functions only look at the last element of each row.
        tb, fb = [], []
        for v, c in tree.trueBranch.results.items():
            tb += [[v]] * c
        for v, c in tree.falseBranch.results.items():
            fb += [[v]] * c
        merged = tb + fb
        p = float(len(tb)) / len(merged)
        delta = evaluationFunction(merged) - p * evaluationFunction(tb) - (1 - p) * evaluationFunction(fb)
        if delta < minGain:
            if notify:
                print('A branch was pruned: gain = %f' % delta)
            tree.trueBranch, tree.falseBranch = None, None
            tree.results = uniqueCounts(merged)

The `classify` function walks the trained tree for a single test observation and returns a class label. The cells below compare these labels with the true labels of the test set.

In [97]:
def classify(observations, tree, dataMissing=False):
    """Classify one observation with ``tree`` and return the predicted label.

    observations: one row of attribute values, indexed by the tree's column
        numbers; a trailing class label, if present, is not consulted.
    tree: a trained DecisionTree.
    dataMissing: True if some attribute values may be None. A None value
        sends the observation down both branches. The two resulting count
        dicts are combined, each weighted by its share of the total count.

    Returns the label with the largest (possibly weighted) count at the
    reached leaf or leaves; ties go to the label that comes first in the
    counts dict.
    """
    def chooseBranch(node, v):
        # Numbers get a >= threshold test, anything else an equality test
        # (the same two tests divideSet uses).
        if isinstance(v, int) or isinstance(v, float):
            return node.trueBranch if v >= node.value else node.falseBranch
        return node.trueBranch if v == node.value else node.falseBranch

    def classifyWithoutMissingData(observations, tree):
        # Walk down to a leaf and return its label counts.
        while tree.results is None:
            tree = chooseBranch(tree, observations[tree.col])
        return tree.results

    def classifyWithMissingData(observations, tree):
        if tree.results is not None:  # leaf
            return tree.results
        v = observations[tree.col]
        if v is None:
            tr = classifyWithMissingData(observations, tree.trueBranch)
            fr = classifyWithMissingData(observations, tree.falseBranch)
            tcount = sum(tr.values())
            fcount = sum(fr.values())
            tw = float(tcount) / (tcount + fcount)
            fw = float(fcount) / (tcount + fcount)
            result = collections.defaultdict(int)
            for k, c in tr.items():
                result[k] += c * tw
            for k, c in fr.items():
                result[k] += c * fw
            return dict(result)
        return classifyWithMissingData(observations, chooseBranch(tree, v))

    if dataMissing:
        counts = classifyWithMissingData(observations, tree)
    else:
        counts = classifyWithoutMissingData(observations, tree)
    # Predict the majority label of the leaf counts, not the observation's
    # own label column.
    return max(counts, key=counts.get)

Define the `plot` function, which prints the trained tree as indented text. The root holds the split with the best gain. Each node shows its test, followed by its `yes` (true) and `no` (false) branches, down to the leaves with their class counts.

In [98]:
def plot(decisionTree, headings=None):
    """Print a text rendering of ``decisionTree``.

    headings: optional dict mapping 'Column <i>' to a readable column name
        (the header dict returned by loadCSV). When omitted, a module-level
        ``dcHeadings`` is used if one is defined; otherwise the raw
        'Column <i>' names are printed instead of raising NameError.
    """
    if headings is None:
        headings = globals().get('dcHeadings', {})

    def toString(node, indent=''):
        if node.results is not None:  # leaf node
            return str(node.results)
        szCol = 'Column %s' % node.col
        szCol = headings.get(szCol, szCol)
        # Numeric split values are threshold tests, others equality tests.
        if isinstance(node.value, int) or isinstance(node.value, float):
            decision = '%s >= %s?' % (szCol, node.value)
        else:
            decision = '%s == %s?' % (szCol, node.value)
        trueBranch = indent + 'yes -> ' + toString(node.trueBranch, indent + '\t\t')
        falseBranch = indent + 'no  -> ' + toString(node.falseBranch, indent + '\t\t')
        return decision + '\n' + trueBranch + '\n' + falseBranch

    print(toString(decisionTree))

In [99]:

def loadCSV(file, bHeader=None):
    """Load a CSV file, converting numeric fields to int or float.

    file: path of the CSV file.
    bHeader: True if the first row holds column names. When left as None,
        a module-level ``bHeader`` flag is used if one is defined (the
        original calling convention); otherwise a header is assumed.

    Returns (dcHeader, rows): dcHeader maps 'Column <i>' to the i-th header
    name (empty without a header); rows holds one list of converted values
    per non-blank line.
    """
    if bHeader is None:
        bHeader = globals().get('bHeader', True)

    def convertTypes(s):
        # A '.' marks a float, otherwise try int; non-numeric text stays a string.
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s

    dcHeader = {}
    # The with-block guarantees the file is closed.
    with open(file, 'rt') as csvFile:
        reader = csv.reader(csvFile)
        if bHeader:
            lsHeader = next(reader, [])  # an empty file simply has no header
            for i, szY in enumerate(lsHeader):
                dcHeader['Column %d' % i] = str(szY)
        # csv yields [] for blank lines; skip them so every row has columns.
        rows = [[convertTypes(item) for item in row] for row in reader if row]
    return dcHeader, rows

Load the iris dataset, grow a tree on the training set using entropy as the evaluation function, print it, and report the misclassification rate on the test set. The tree is then pruned and printed again.

In [100]:
if __name__ == '__main__':

    example = 1
    if example == 1:
        bHeader = True  # header flag consulted by loadCSV

        # the bigger example (raw string: the Windows path contains backslashes)
        dcHeadings, Data = loadCSV(r'C:\Users\saikiran\Desktop\iris.txt')  # demo data from matlab
        train_dataset, test_dataset = shuffle(Data)
        print(len(train_dataset))
        print(len(test_dataset))
        decisionTree = growDecisionTreeFrom(train_dataset)

        plot(decisionTree)

        # The true label is the last column of every test row.
        actual_labels = [row[-1] for row in test_dataset]
        result1 = np.asarray(actual_labels)
        print("--------------------------------------------------------------------------------------------------------")
        print("test data before classification: %s" % (result1,))
        # dataMissing=False is the default setting
        predictions = [classify(row, decisionTree) for row in test_dataset]
        classified_resultlist = np.asarray(predictions)
        print("test_data after classification: %s" % (classified_resultlist,))
        print("--------------------------------------------------------------------------------------------------------")
        # Count the predictions that disagree with the true label.
        count = sum(1 for predicted, actual in zip(predictions, actual_labels) if predicted != actual)
        rate = float(count) / len(test_dataset) if len(test_dataset) else 0.0
        print("misclassification rate %s (%d of %d test rows)" % (rate, count, len(test_dataset)))

        # notify, when a branch is pruned
        prune(decisionTree, 0.9, notify=True)
        plot(decisionTree)
        
       
       
        
      
        
        

            
       
        
        

105
45
PetalLength >= 3?
yes -> PetalLength >= 4.8?
		yes -> PetalWidth >= 1.8?
				yes -> {'virginica': 33}
				no  -> PetalWidth >= 1.7?
						yes -> {'versicolor': 1}
						no  -> {'virginica': 3}
		no  -> PetalWidth >= 1.7?
				yes -> {'virginica': 1}
				no  -> {'versicolor': 30}
no  -> {'setosa': 37}
--------------------------------------------------------------------------------------------------------
test data before classification: ['setosa' 'setosa' 'versicolor' 'virginica' 'versicolor' 'virginica'
 'versicolor' 'versicolor' 'virginica' 'versicolor' 'versicolor'
 'versicolor' 'virginica' 'versicolor' 'virginica' 'versicolor' 'setosa'
 'versicolor' 'versicolor' 'versicolor' 'versicolor' 'virginica'
 'versicolor' 'setosa' 'setosa' 'virginica' 'versicolor' 'virginica'
 'setosa' 'virginica' 'setosa' 'virginica' 'virginica' 'setosa'
 'versicolor' 'setosa' 'virginica' 'versicolor' 'setosa' 'virginica'
 'versicolor' 'setosa' 'setosa' 'versicolor' 'setosa']
test_data after classific

Load the iris dataset, grow a tree on the training set using Gini impurity as the evaluation function, print it, and report the misclassification rate on the test set. The tree is then pruned and printed again.

In [101]:
if __name__ == '__main__':

    example = 2
    if example == 2:
        bHeader = True  # header flag consulted by loadCSV

        # the bigger example (raw string: the Windows path contains backslashes)
        dcHeadings, Data = loadCSV(r'C:\Users\saikiran\Desktop\iris.txt')  # demo data from matlab
        train_dataset, test_dataset = shuffle(Data)
        print(len(train_dataset))
        print(len(test_dataset))

        # The true label is the last column of every test row.
        actual_labels = [row[-1] for row in test_dataset]
        result1 = np.asarray(actual_labels)

        decisionTree_gini = growDecisionTreeFrom_gini(train_dataset)
        plot(decisionTree_gini)
        print("--------------------------------------------------------------------------------------------------------")
        print("test data before classification: %s" % (result1,))
        # dataMissing=False is the default setting
        predictions_gini = [classify(row, decisionTree_gini) for row in test_dataset]
        classified_resultlist_gini = np.asarray(predictions_gini)
        print("test_data after classification: %s" % (classified_resultlist_gini,))
        print("--------------------------------------------------------------------------------------------------------")
        # Count the predictions that disagree with the true label.
        count = sum(1 for predicted, actual in zip(predictions_gini, actual_labels) if predicted != actual)
        rate = float(count) / len(test_dataset) if len(test_dataset) else 0.0
        print("misclassification rate %s (%d of %d test rows)" % (rate, count, len(test_dataset)))
        print("--------------------------------------------------------------------------------------------------------")
        prune_gini(decisionTree_gini, 0.4, notify=True)
        plot(decisionTree_gini)
        
      
        
        

            
       
        
        

105
45
PetalLength >= 3?
yes -> PetalLength >= 4.8?
		yes -> PetalWidth >= 1.8?
				yes -> {'virginica': 33}
				no  -> PetalWidth >= 1.7?
						yes -> {'versicolor': 1}
						no  -> {'virginica': 3}
		no  -> PetalWidth >= 1.7?
				yes -> {'virginica': 1}
				no  -> {'versicolor': 30}
no  -> {'setosa': 37}
--------------------------------------------------------------------------------------------------------
test data before classification: ['setosa' 'setosa' 'versicolor' 'virginica' 'versicolor' 'virginica'
 'versicolor' 'versicolor' 'virginica' 'versicolor' 'versicolor'
 'versicolor' 'virginica' 'versicolor' 'virginica' 'versicolor' 'setosa'
 'versicolor' 'versicolor' 'versicolor' 'versicolor' 'virginica'
 'versicolor' 'setosa' 'setosa' 'virginica' 'versicolor' 'virginica'
 'setosa' 'virginica' 'setosa' 'virginica' 'virginica' 'setosa'
 'versicolor' 'setosa' 'virginica' 'versicolor' 'setosa' 'virginica'
 'versicolor' 'setosa' 'setosa' 'versicolor' 'setosa']
test_data after classific