In [77]:
import pandas as pd
import numpy as np
import math
import copy

In [78]:
class Node():
    """A single vertex of a binary decision tree.

    A node stores the column it tests (``feature``), the threshold it compares
    against (``value``), a ``terminal`` flag marking leaves, its two children
    (``left`` / ``right``, ``None`` until attached) and a ``categorization``
    label that starts out as the empty string.
    """

    def __init__(self, value = None, feature = None, terminal = False):
        self.feature = feature  # column label tested at this node
        self.value = value  # threshold the feature is compared against
        self.terminal = terminal  # True when the node is a leaf
        self.left = None
        self.right = None
        self.categorization = ""  # class label; filled in by the tree builder

    def __str__(self):
        """Render the node's fields plus a one-line status for each child.

        A child is reported as "Exists" only when it is present and not
        terminal; a missing child is reported the same way as a terminal one.
        """
        fields = (
            ("Feature: ", self.feature),
            ("Value: ", self.value),
            ("Category: ", self.categorization),
            ("Terminal: ", self.terminal),
        )
        text = "\n".join(label + str(content) for label, content in fields)

        children = (
            (self.left, "\nLeft: Exists", "\nLeft :Terminal"),
            (self.right, "\nRight: Exists", "\nRight :Terminal"),
        )
        for child, internal_note, terminal_note in children:
            if child is None or child.terminal:
                text += terminal_note
            else:
                text += internal_note
        return text

    def copy(self):
        """Return a deep copy of this node, including everything below it."""
        return copy.deepcopy(self)



In [79]:
def entropy(target_col):
    """Return the base-2 Shannon entropy of the values in ``target_col``.

    Each distinct value contributes ``-p * log2(p)`` where ``p`` is its share
    of all entries. An empty input yields 0.0 because there are no terms to
    add up.
    """
    _, frequencies = np.unique(target_col, return_counts=True)
    total = np.sum(frequencies)
    shares = frequencies / total
    return np.sum(-shares * np.log2(shares))

In [80]:
def find_split(data,features,target):
    """Pick the (feature, threshold) pair to split ``data`` on.

    Every feature is scored with ``split_for_max_gain`` and the one with the
    largest strictly positive gain wins. When no feature reaches a positive
    gain (for example when the classes only separate on a combination of
    features), the first feature that has at least two distinct values is
    used instead, split halfway between its two smallest values, so that the
    caller can still make progress.

    Parameters:
        data: table holding the feature columns and the target column.
        features: iterable of column labels to consider.
        target: label (or labels) of the target column, forwarded unchanged
            to ``split_for_max_gain``.

    Returns:
        A non-terminal ``Node`` whose ``feature`` and ``value`` describe the
        chosen split.

    Raises:
        ValueError: if no feature has two distinct values, i.e. there is
            nothing to split on.
    """
    best = None  # (feature, threshold) with the highest positive gain so far
    best_gain = 0
    fallback = None  # first usable split, consulted only if `best` stays None

    for feature in features:
        gain, threshold = split_for_max_gain(data, feature, target)
        if gain > best_gain:
            best_gain = gain
            best = (feature, threshold)
        elif best is None and fallback is None:
            distinct = np.unique(data[feature])  # sorted ascending
            if len(distinct) > 1:
                fallback = (feature, (distinct[0] + distinct[1]) / 2)

    chosen = best if best is not None else fallback
    if chosen is None:
        raise ValueError("find_split: no feature has two distinct values to split between")

    feature, threshold = chosen
    return Node(threshold, feature)

In [81]:
def split_for_max_gain(data, feature, target_name):
    """Find the threshold on ``feature`` with the highest information gain.

    Candidate thresholds are the midpoints between consecutive distinct
    values of the feature column. ``np.unique`` already returns those values
    in ascending order, so the table itself does not need to be sorted.

    Parameters:
        data: table holding the feature column and the target column.
        feature: label of the column to split on.
        target_name: label (or labels) of the target column, forwarded
            unchanged to ``info_gain_split``.

    Returns:
        A ``(max_gain, split)`` tuple. When no candidate has a gain strictly
        greater than zero (including when the column has fewer than two
        distinct values) the result is ``(0, 0)``; the ``0`` threshold is a
        placeholder in that case, not a real candidate.
    """
    values = np.unique(data[feature])
    candidates = (values[:-1] + values[1:]) / 2

    max_gain = 0
    split = 0
    for candidate in candidates:
        gain = info_gain_split(data, candidate, feature, target_name)
        if gain > max_gain:
            max_gain = gain
            split = candidate
    return (max_gain, split)

In [82]:
def info_gain_split(data, split_value, split_feature, target_feature):
    """Information gain from splitting ``data`` at ``split_value``.

    Rows whose ``split_feature`` is ``>= split_value`` form the "above"
    group; every other row (including rows where the comparison is False
    because the value is missing) forms the "below" group. The same
    partition is used both for the group weights and for the group
    entropies, so the two always agree.

    Parameters:
        data: table holding the feature column and the target column.
        split_value: threshold to compare the feature against.
        split_feature: label of the column being split on.
        target_feature: label (or labels) of the target column.

    Returns:
        Entropy of the whole target column minus the size-weighted entropy
        of the two groups. Returns 0.0 for an empty table, where there is
        nothing to gain and the weights would otherwise divide by zero.
    """
    total = len(data)
    if total == 0:
        return 0.0

    total_entropy = entropy(data[target_feature])

    is_above = data[split_feature] >= split_value
    labels_above = data.loc[is_above, target_feature]
    labels_below = data.loc[~is_above, target_feature]

    weighted_entropy = (
        (len(labels_above) / total) * entropy(labels_above)
        + (len(labels_below) / total) * entropy(labels_below)
    )
    return total_entropy - weighted_entropy


In [83]:
def no_possible_splits(data, features):
    """Tell whether none of ``features`` can be split any further.

    Returns True when every listed column of ``data`` holds at most one
    distinct value (which is also the case for an empty feature list), and
    False as soon as one column has two or more distinct values.
    """
    return all(len(np.unique(data[name])) <= 1 for name in features)

In [84]:
def _majority_leaf(rows, target):
    """Build a terminal Node labelled with the most frequent target value in ``rows``."""
    labels, counts = np.unique(rows[target], return_counts = True)
    if len(labels) == 0:
        raise ValueError("ID3: cannot pick a class label from empty data")
    leaf = Node(terminal=True)
    # np.unique sorts the labels and np.argmax returns the first maximum,
    # so ties go to the smallest label.
    leaf.categorization = labels[np.argmax(counts)]
    return leaf


def ID3(data, originaldata, features, target):
    """Recursively build a binary decision tree over numeric features.

    Parameters:
        data: rows that reached the current position in the tree.
        originaldata: rows used to pick a label when ``data`` is empty; it is
            passed down to the recursive calls unchanged.
        features: column labels that may be split on.
        target: label (or labels) of the target column.

    Returns:
        The root ``Node`` of the (sub)tree. Leaves are terminal nodes whose
        ``categorization`` is the majority target value of their rows.

    Raises:
        ValueError: if both ``data`` and ``originaldata`` contain no rows.
    """
    # No rows at all: borrow the majority label from the reference data.
    if len(data) == 0:
        return _majority_leaf(originaldata, target)

    # Stop when nothing can be split any more or the rows are already pure;
    # for pure rows the majority label is simply the only label.
    if no_possible_splits(data, features) or len(np.unique(data[target])) == 1:
        return _majority_leaf(data, target)

    root = find_split(data, features, target)

    # One mask defines both children, so every row lands in exactly one of
    # them: ">= threshold" goes right, everything else (including rows where
    # the comparison is False because the value is missing) goes left.
    goes_right = data[root.feature] >= root.value
    left_data = data[~goes_right]
    right_data = data[goes_right]

    # A split that leaves one side empty makes no progress; close the branch
    # with a leaf instead of recursing on the same rows again.
    if len(left_data) == 0 or len(right_data) == 0:
        return _majority_leaf(data, target)

    root.left = ID3(left_data, originaldata, features, target)
    root.right = ID3(right_data, originaldata, features, target)

    return root
      

In [85]:
# Load the training table: space-separated values with no header row, so the
# columns get integer labels.
data = pd.read_csv('/Users/elidangerfield/Documents/school/CSCI4350/OLA3/iris-data.txt', sep=" ", header=None)
# Every column except the last one is treated as a feature.
features = data.columns[:-1]
# NOTE: the [-1:] slice makes `target` a one-element Index, not a scalar label.
target = data.columns[-1:]

In [86]:
# Build the tree from the full training table; the same frame is passed as
# both `data` and `originaldata`.
node = ID3(data, data, features, target)

In [87]:
# Inspect the node reached by taking the right branch three times from the root.
print(node.right.right.right)

Feature: None
Value: None
Category: 2.0
Terminal: True
Left :Terminal
Right :Terminal


In [88]:
# Load the evaluation rows (space-separated, no header row) and display them.
test_data = pd.read_csv('/Users/elidangerfield/Documents/school/CSCI4350/OLA3/head_random_iris.txt', sep=" ", header=None)
test_data


Unnamed: 0,0,1,2,3,4
0,5.4,3.4,1.5,0.4,0
1,6.4,3.1,5.5,1.8,2
2,4.8,3.4,1.6,0.2,0
3,6.7,2.5,5.8,1.8,2
4,4.6,3.6,1.0,0.2,0
5,5.8,2.8,5.1,2.4,2
6,5.4,3.9,1.7,0.4,0
7,6.0,2.2,5.0,1.5,2
8,6.0,2.9,4.5,1.5,1
9,6.1,3.0,4.6,1.4,1


In [90]:
def test(test_data, target, tree):
    num_correct = 0
    for i in range(len(test_data)):
        current = tree.copy()
        while not current.terminal:
            if test_data.iloc[i][current.feature] >= current.value:
                # go right
                current = current.right
            else:
                # go left
                current = current.left
        if test_data.iloc[i].values[target] == current.categorization:
            num_correct +=1
            
    return num_correct