In [345]:
import pandas as pd
import numpy as np
import math

In [375]:
class Node:
    """A single node of a binary decision tree.

    Attributes:
        feature: name of the feature column this node splits on.
        value: threshold at which the feature is split.
        terminal: True when the node is a leaf.
        left, right: child nodes, or None when not attached.
        categorization: label stored on the node (empty string by default).
    """

    def __init__(self, value = None, feature = None, terminal = False):
        self.feature = feature  # feature column
        self.value = value  # threshold to split at
        self.terminal = terminal
        self.left = None
        self.right = None
        self.categorization = ""

    def __str__(self):
        """Return a multi-line summary of this node and the state of its children."""
        lines = [
            "Feature: {!s}".format(self.feature),
            "Value: {!s}".format(self.value),
            "Category: {!s}".format(self.categorization),
            "Terminal: {!s}".format(self.terminal),
        ]
        # A missing child is reported the same way as a terminal one.
        for side, child in (("Left", self.left), ("Right", self.right)):
            if child is None or child.terminal:
                lines.append(side + " :Terminal")
            else:
                lines.append(side + ": Exists")
        return "\n".join(lines)



In [376]:
def entropy(target_col):
    """Return the base-2 Shannon entropy of the values in ``target_col``.

    Each distinct value contributes ``-p * log2(p)``, where ``p`` is its
    relative frequency in the column.
    """
    _, counts = np.unique(target_col, return_counts=True)
    probabilities = counts / np.sum(counts)
    return np.sum(-probabilities * np.log2(probabilities))

In [377]:
def find_split(data,features,target):
    """Choose the feature/threshold pair to split ``data`` on.

    Every feature is scored with ``split_for_max_gain`` and the pair with the
    highest positive information gain wins.  If no feature yields a positive
    gain (for example XOR-like data, where no single split is informative),
    the midpoint between the two smallest distinct values of the first
    feature that has more than one distinct value is used, so the caller can
    keep partitioning instead of failing.

    Args:
        data: table indexable by column name (e.g. a pandas DataFrame).
        features: iterable of feature column names to consider.
        target: name of the target column.

    Returns:
        A non-terminal ``Node(value, feature)`` describing the chosen split.

    Raises:
        ValueError: if no feature has more than one distinct value, i.e.
            there is nothing to split on.
    """
    split = None
    best_gain = 0
    for x in features:
        (max_gain, best_split) = split_for_max_gain(data, x, target)
        # When max_gain is 0 the reported point is only the placeholder
        # returned by split_for_max_gain, not a real candidate.
        print("Max gain for ", x, " : ", max_gain, "at point: ", best_split)
        if max_gain > best_gain:
            best_gain = max_gain
            split = (x,best_split)

    if split is None:
        # No split improves purity on its own; fall back to the first split
        # that at least separates the rows into two non-empty groups.
        for x in features:
            values = np.unique(data[x])
            if len(values) > 1:
                split = (x, (values[0] + values[1]) / 2)
                break

    if split is None:
        raise ValueError("find_split: no feature has more than one distinct value to split on")

    return Node(split[1], split[0])
    
        

In [378]:
def split_for_max_gain(data, feature, target_name):
    """Find the threshold on ``feature`` with the highest information gain.

    Candidate thresholds are the midpoints between consecutive distinct
    values of the feature.

    Args:
        data: table indexable by column name (e.g. a pandas DataFrame).
        feature: name of the feature column to split on.
        target_name: name of the target column.

    Returns:
        ``(max_gain, split)``.  When no candidate has a gain above zero
        (including when the feature has fewer than two distinct values)
        the result is ``(0, 0)``; the ``0`` threshold is then only a
        placeholder, not a real split point.
    """
    # np.unique returns the distinct values already sorted, so the frame
    # itself does not need to be sorted first.
    values = np.unique(data[feature])
    candidates = (values[:-1] + values[1:]) / 2

    max_gain = 0
    split = 0
    for candidate in candidates:
        gain = info_gain_split(data, candidate, feature, target_name)
        if gain > max_gain:
            max_gain = gain
            split = candidate
    return (max_gain,split)

In [379]:
def info_gain_split(data, split_value, split_feature, target_feature):
    """Information gain from splitting ``data`` at ``split_value``.

    Rows whose ``split_feature`` value is ``>= split_value`` form the
    "above" group; every other row forms the "below" group.  The gain is
    the entropy of the target column minus the size-weighted entropies of
    the two groups.

    Args:
        data: pandas DataFrame holding the feature and target columns.
        split_value: threshold to split the feature at.
        split_feature: name of the feature column.
        target_feature: name of the target column.

    Returns:
        The information gain as a float; ``0.0`` for an empty frame.
    """
    total = len(data)
    if total == 0:
        # Nothing to split, hence nothing to gain (and no division by zero).
        return 0.0

    target_col = data[target_feature]

    # A single mask defines both groups, so the weights and the entropies
    # are always computed over exactly the same rows.
    is_above = data[split_feature] >= split_value
    above = target_col[is_above]
    below = target_col[~is_above]

    weighted_entropy = (
        (len(above) / total) * entropy(above)
        + (len(below) / total) * entropy(below)
    )
    return entropy(target_col) - weighted_entropy


In [514]:
def no_possible_splits(data, features):
    """Return True when none of ``features`` has more than one distinct value in ``data``."""
    return not any(len(np.unique(data[name])) > 1 for name in features)

In [529]:
def _majority_leaf(labels):
    """Build a terminal Node labelled with the most frequent value in ``labels``."""
    vals, counts = np.unique(labels, return_counts=True)
    if len(vals) == 0:
        raise ValueError("ID3: cannot build a leaf from an empty target column")
    leaf = Node(terminal=True)
    # argmax returns the first maximum, so ties go to the smallest label.
    leaf.categorization = vals[np.argmax(counts)]
    return leaf


def ID3(data, originaldata, features, target):
    """Recursively build a binary decision tree.

    Args:
        data: pandas DataFrame with the rows that reach this node.
        originaldata: DataFrame used as a fallback: when ``data`` is empty
            the leaf is labelled with the majority class of this frame.
        features: list of feature column names that may be split on.
        target: name of the target column.

    Returns:
        The root ``Node`` of the (sub)tree.  Leaves have ``terminal=True``
        and carry the predicted label in ``categorization``.

    Raises:
        ValueError: if both ``data`` and ``originaldata`` are empty.
    """
    # No rows reached this node: fall back to the overall majority class.
    if len(data) == 0:
        return _majority_leaf(originaldata[target])

    # Stop when the rows are already pure, or when no feature has more than
    # one distinct value left to split on.
    if len(np.unique(data[target])) == 1 or no_possible_splits(data, features):
        return _majority_leaf(data[target])

    root = find_split(data, features, target)

    # Rows strictly below the threshold go left and every other row goes
    # right, so each row lands in exactly one child.
    goes_left = data[root.feature] < root.value
    left_data = data[goes_left]
    right_data = data[~goes_left]

    # A split that leaves one side empty makes no progress; stop here
    # rather than recursing on the same rows again.
    if len(left_data) == 0 or len(right_data) == 0:
        return _majority_leaf(data[target])

    root.left = ID3(left_data, originaldata, features, target)
    root.right = ID3(right_data, originaldata, features, target)

    return root
      

In [530]:
# Load the iris measurements: space-separated values with no header row.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
data = pd.read_csv(
    '/Users/elidangerfield/Documents/school/CSCI4350/OLA3/iris-data.txt',
    sep=" ",
    header=None,
)
target = "class"
# NOTE(review): "pedal" is presumably a misspelling of "petal"; left unchanged
# because these strings are the column keys used at runtime.
features = ["sepal_length", "sepal_width", "pedal_length", "pedal_width"]
# Name the columns: the four feature columns followed by the target column.
data.columns = features + [target]

In [531]:
# Grow the tree on the full data set; the same frame is passed as both `data` and `originaldata`.
node = ID3(data, data, features, target)

In [532]:
# Inspect the node reached by following the right child five times from the root.
print(node.right.right.right.right.right)

Feature: pedal_length
Value: 3.75
Category: 
Terminal: False
Left :Terminal
Right: Exists
