In [110]:
import csv
import random
import numpy as np
import math

In [20]:
def pluralityValue(data, col=-1):
    """Return the most frequent value found in column ``col`` of ``data``.

    Parameters:
        data: sequence of indexable rows.
        col:  index of the column to tally; defaults to the last column.

    Returns:
        The value with the highest number of occurrences. When several
        values share the highest count, the one the tally dict yields first
        wins.

    Raises:
        ValueError: if ``data`` is empty, since there is nothing to vote on.
    """
    if len(data) == 0:
        raise ValueError("pluralityValue() needs at least one row of data")

    counters = {}
    for row in data:
        counters[row[col]] = counters.get(row[col], 0) + 1

    # max() keeps the first maximal key it meets, matching a strict '>' scan.
    return max(counters, key=counters.get)

In [23]:
def read_datafile(name, delimiter=','):
    """Read a delimited text file into a list of rows.

    Parameters:
        name:      path of the file to read.
        delimiter: field separator handed to ``csv.reader``.

    Returns:
        A list with one list of strings per non-blank line of the file.
        Blank lines are skipped: ``csv.reader`` yields ``[]`` for them, and
        an empty row would break code that indexes the class label with
        ``row[-1]``.
    """
    import sys

    # The csv module wants a binary handle on Python 2 but a text handle
    # opened with newline='' on Python 3.
    if sys.version_info[0] < 3:
        handle = open(name, 'rb')
    else:
        handle = open(name, 'r', newline='')

    with handle as csvfile:
        reader = csv.reader(csvfile, delimiter=delimiter)
        return [row for row in reader if row]

In [167]:
class Variable:
    """An attribute of the data set.

    Attributes (set from the constructor arguments of the same name):
        name:   label of the attribute.
        domain: collection of values the attribute can take.
        idx:    position of the attribute inside a data row.
    """

    def __init__(self, name, domain, idx):
        self.name, self.domain, self.idx = name, domain, idx
        
class Problem:
    """Simple container that accumulates variables in insertion order."""

    def __init__(self):
        # Starts out with no variables registered.
        self.variables = list()

    def add_variable(self, variable):
        """Append ``variable`` to the end of ``self.variables``."""
        self.variables.extend([variable])
        
class DecisionTree:
    """A node of a decision tree.

    A node whose ``value`` is not None is a leaf and ``value`` is the class
    it predicts. Otherwise the node tests ``variable`` and ``children`` maps
    each value of that variable to the subtree for that branch.
    """

    def __init__(self, variable=None, value=None):
        self.value = value
        self.variable = variable
        self.children = {}

    def dump(self, indent=0):
        """Print the subtree rooted here as an indented outline.

        A leaf prints its value. An internal node prints the variable name,
        then each domain value followed by the matching subtree indented by
        five more spaces.
        """
        pad = ' ' * indent
        # '%s' formatting (instead of '+') also copes with labels and domain
        # values that are not strings, and the single-argument print(...)
        # form runs on both Python 2 and Python 3.
        if self.value is not None:
            print('%s%s' % (pad, self.value))
        else:
            print('%s%s' % (pad, self.variable.name))
            for val in self.variable.domain:
                print('%s%s' % (pad, val))
                self.children[val].dump(indent=indent + 5)

    def classify(self, data):
        """Return the list of predictions for every row in ``data``."""
        return [self.classifyOne(d) for d in data]

    def classifyOne(self, data):
        """Return the prediction for a single row.

        Follows the branch selected by the row's value for this node's
        variable until a leaf is reached. Raises KeyError when the row holds
        a value for which this node has no child branch.
        """
        if self.value is not None:
            return self.value
        return self.children[data[self.variable.idx]].classifyOne(data)
            

def listWithout(lst, element):
    """Return a new list equal to ``lst`` minus the first ``element``.

    ``lst`` itself is never modified. Only the first occurrence is dropped.
    If ``element`` does not occur in ``lst``, a plain copy is returned.
    """
    remaining = list(lst)
    if element in remaining:
        remaining.remove(element)
    return remaining
        
def learnTree(data, variables, parent_data):
    """Recursively learn a decision tree from labelled rows.

    Parameters:
        data:        rows to learn from; the class label is the last field.
        variables:   attributes still available for splitting.
        parent_data: rows of the parent node, used to vote when ``data`` is
                     empty.

    Returns:
        The DecisionTree for ``data``.
    """
    # No examples reached this branch: fall back on the parent's majority.
    if len(data) == 0:
        return DecisionTree(value=pluralityValue(parent_data))

    # Every example agrees on the class: make a leaf with that class.
    labels = set(row[-1] for row in data)
    if len(labels) == 1:
        return DecisionTree(value=data[0][-1])

    # Nothing left to split on: take the majority class of what remains.
    if len(variables) == 0:
        return DecisionTree(value=pluralityValue(data))

    # Split on the attribute chosen by mostImportantFeature.
    feature = mostImportantFeature(data, variables)
    remaining = listWithout(variables, feature)

    node = DecisionTree(variable=feature)
    for val in feature.domain:
        subset = [row for row in data if row[feature.idx] == val]
        node.children[val] = learnTree(subset, remaining, data)
    return node

def mostImportantFeature(data, variables):
    """Pick the variable whose split yields the highest information gain.

    The class label (last field of each row) may take more than two values,
    so the gain is computed one-vs-rest: for every (variable, outcome) pair
    the outcome is treated as the positive class and all others as negative.

    Parameters:
        data:      rows to evaluate; the class label is the last field.
        variables: candidate attributes, each with ``idx`` and ``domain``.

    Returns:
        One of the variables that reaches the maximum gain, drawn at random
        among the tied (variable, outcome) pairs. A variable that ties for
        several outcomes appears once per outcome in that draw.

    Raises:
        ValueError: if ``data`` or ``variables`` is empty.
    """
    outcomes = set([x[-1] for x in data])

    # The sample count and entropy before splitting depend only on the
    # outcome, so compute them once instead of once per variable.
    before_split = {}
    for out in outcomes:
        p, n = pnSamples(data, -1, out)
        before_split[out] = (p + n, B(float(p) / (p + n)))

    splits_gain = {}
    for var in variables:
        # Partition the rows by this variable's values once per variable.
        partitions = [[x for x in data if x[var.idx] == val]
                      for val in var.domain]
        for out in outcomes:
            total, entropy = before_split[out]
            remainder = 0
            for part in partitions:
                pv, nv = pnSamples(part, -1, out)
                if pv + nv > 0:
                    remainder += float(pv + nv) / total * B(float(pv) / (pv + nv))
            splits_gain[(var, out)] = entropy - remainder

    if not splits_gain:
        raise ValueError("mostImportantFeature() needs non-empty data and variables")

    # Find the best gain once rather than re-scanning all gains per key.
    best_gain = max(splits_gain.values())
    maxkeys = [key[0] for key in splits_gain if splits_gain[key] == best_gain]
    return maxkeys[random.randint(0, len(maxkeys) - 1)]
            
def B(x):
    """Entropy, in bits, of a boolean variable that is true with probability x.

    Returns 0 when x is exactly 0 or 1, where the logarithm is undefined.
    """
    if x == 0 or x == 1:
        return 0
    complement = 1 - x
    return -(x * math.log(x, 2) + complement * math.log(complement, 2))

def pnSamples(data, idx, pos_val):
    """Count positive and negative rows of ``data``.

    A row is positive when its field at ``idx`` equals ``pos_val``; every
    other row is negative. Returns the pair ``(positives, negatives)``.
    """
    positives = sum(1 for row in data if row[idx] == pos_val)
    return positives, len(data) - positives

In [170]:
# Load the restaurant examples; columns 0-9 are attributes, the last is the label.
data = read_datafile('WillWait-data.txt')

# One Variable per attribute column; its domain is the set of values observed
# in that column of the loaded data.
alt, bar, fri, hun, pat, pri, rai, res, typ, wai = (
    Variable(label, list(set(row[col] for row in data)), col)
    for col, label in enumerate((
        "Alternate", "Bar", "Fri/Sat", "Hungry", "Patrons",
        "Price", "Raining", "Reservation", "Type", "WaitEstimate",
    ))
)

# The root has no parent examples, hence the empty parent_data list.
T = learnTree(data, [alt, bar, fri, hun, pat, pri, rai, res, typ, wai], [])

In [171]:
# Print the tree learned above as an indented outline.
T.dump()

Patrons
None
     No
Full
     Hungry
     Yes
          Type
          Burger
               Yes
          Thai
               Fri/Sat
               Yes
                    Yes
               No
                    No
          French
               Yes
          Italian
               No
     No
          No
Some
     Yes


In [136]:
# Load the iris examples; columns 0-3 are measurements, the last is the label.
data = read_datafile('iris.data.txt')

# The measurements stay as the strings read from the file, so every distinct
# reading becomes its own categorical value in the variable's domain.
a, b, c, d = (
    Variable(label, list(set(row[col] for row in data)), col)
    for col, label in enumerate(("a", "b", "c", "d"))
)

# The root has no parent examples, hence the empty parent_data list.
T = learnTree(data, [a, b, c, d], [])

In [104]:
# Display the rows currently bound to `data` (lists of strings, class label last).
data

[['5.1', '3.5', '1.4', '0.2', 'Iris-setosa'],
 ['4.9', '3.0', '1.4', '0.2', 'Iris-setosa'],
 ['4.7', '3.2', '1.3', '0.2', 'Iris-setosa'],
 ['4.6', '3.1', '1.5', '0.2', 'Iris-setosa'],
 ['5.0', '3.6', '1.4', '0.2', 'Iris-setosa'],
 ['5.4', '3.9', '1.7', '0.4', 'Iris-setosa'],
 ['4.6', '3.4', '1.4', '0.3', 'Iris-setosa'],
 ['5.0', '3.4', '1.5', '0.2', 'Iris-setosa'],
 ['4.4', '2.9', '1.4', '0.2', 'Iris-setosa'],
 ['4.9', '3.1', '1.5', '0.1', 'Iris-setosa'],
 ['5.4', '3.7', '1.5', '0.2', 'Iris-setosa'],
 ['4.8', '3.4', '1.6', '0.2', 'Iris-setosa'],
 ['4.8', '3.0', '1.4', '0.1', 'Iris-setosa'],
 ['4.3', '3.0', '1.1', '0.1', 'Iris-setosa'],
 ['5.8', '4.0', '1.2', '0.2', 'Iris-setosa'],
 ['5.7', '4.4', '1.5', '0.4', 'Iris-setosa'],
 ['5.4', '3.9', '1.3', '0.4', 'Iris-setosa'],
 ['5.1', '3.5', '1.4', '0.3', 'Iris-setosa'],
 ['5.7', '3.8', '1.7', '0.3', 'Iris-setosa'],
 ['5.1', '3.8', '1.5', '0.3', 'Iris-setosa'],
 ['5.4', '3.4', '1.7', '0.2', 'Iris-setosa'],
 ['5.1', '3.7', '1.5', '0.4', 'Iri

In [109]:
# Classify the rows at indices 1 and 2 of `data` with the current tree.
T.classify(data[1:3])

['Iris-setosa', 'Iris-setosa']

In [114]:
# Scratch check of math.log with an explicit base: log base 2 of 4.
math.log(4,2)

2.0

In [118]:
# Scratch check (presumably) that float(2)/3 gives a fractional result that can be added to a float.
0.0 + float(2)/3

0.6666666666666666

In [119]:
# Scratch dict for the tuple-key experiment in the next cell.
t = {}

In [120]:
# Scratch check that a tuple can be a dict key, as mostImportantFeature does with (variable, outcome).
t[(1,2)] = 3

In [144]:
# Show which attribute is picked when all rows of `data` and all four variables are considered.
# The single-argument print(...) form runs on both Python 2 and Python 3.
print(mostImportantFeature(data, [a, b, c, d]).name)

d


In [132]:
# Demonstrates that math.log rejects 0 (ValueError below), the case B() guards against for x == 0 and x == 1.
math.log(0,2)

ValueError: math domain error