In [1]:
import math
import copy
import numpy as np 
import statistics
import matplotlib.pyplot as plt

In [2]:
class Node:
    """A decision-tree node.

    ``label`` holds whatever the creator passes in; ``children`` is a dict
    that starts empty and is filled in by the code that builds the tree.
    """

    def __init__(self, label):
        self.children = {}
        self.label = label

    def isLeaf(self):
        """Return True when this node has no children."""
        return not self.children

In [4]:
def set_subset(data, attribute, val):
    """Return the rows of ``data`` whose ``attribute`` entry equals ``val``.

    Rows are read with ``row[attribute]``. Matching rows are returned in their
    original order in a new list; the row objects themselves are not copied.
    """
    return [record for record in data if record[attribute] == val]

In [5]:
def set_label(row, dt):
    """Walk the tree ``dt`` using the values in ``row`` and return the leaf label.

    At every non-leaf node the node's ``label`` is used as a key into ``row``,
    and the value found there selects the child to descend into. A value with
    no matching child raises ``KeyError``.
    """
    node = dt
    while True:
        if node.isLeaf():
            return node.label
        # Internal node: its label names the attribute to branch on.
        node = node.children[row[node.label]]

In [6]:
def entropy(data):
    """Weighted base-2 Shannon entropy of the ``'y'`` labels in ``data``.

    Every row adds ``row['weights']`` to the mass of its ``row['y']`` label;
    each mass is then divided by ``get_total(data)`` to obtain a proportion.

    Returns 0 for empty ``data``. A label whose proportion is exactly zero
    contributes nothing (the usual ``0 * log2(0) = 0`` convention) instead of
    triggering a math domain error.
    """
    if len(data) == 0:
        return 0

    counts = dict()
    for row in data:
        label = row['y']
        counts[label] = counts.get(label, 0.0) + row['weights']

    # NOTE(review): get_total is defined outside this cell; presumably it sums
    # row['weights'] over data -- confirm against its definition.
    total = get_total(data)

    result = 0.0
    for count in counts.values():
        p = count / total
        if p == 0:
            # math.log2(0) raises ValueError; a zero-mass label adds no entropy.
            continue
        result += -p * math.log2(p)

    return result

def gini_index(data):
    """Weighted Gini impurity of the ``'y'`` labels in ``data``.

    Label masses are accumulated from ``row['weights']`` and normalised by
    ``get_total(data)``; the result is one minus the sum of the squared
    proportions. Returns 0 for empty ``data``.
    """
    if not len(data):
        return 0

    weight_by_label = {}
    for record in data:
        key = record['y']
        weight_by_label[key] = weight_by_label.get(key, 0.0) + record['weights']

    denominator = get_total(data)
    squared_mass = 0.0
    for mass in weight_by_label.values():
        squared_mass += (mass / denominator) ** 2

    return 1 - squared_mass

def ME(data):
    """Weighted majority error of the ``'y'`` labels in ``data``.

    Label masses are accumulated from ``row['weights']`` and normalised by
    ``get_total(data)``; the result is one minus the largest proportion
    (never less than ``1 - 0.0`` is subtracted, since the running maximum
    starts at 0.0). Returns 0 for empty ``data``.
    """
    if not len(data):
        return 0

    label_mass = {}
    for record in data:
        key = record['y']
        label_mass[key] = label_mass.get(key, 0.0) + record['weights']

    denominator = get_total(data)
    best_share = 0.0
    for mass in label_mass.values():
        share = mass / denominator
        if share > best_share:
            best_share = share

    return 1 - best_share

In [7]:
def info_gain(data, gain_type, attribute, vals):
    """Impurity reduction obtained by splitting ``data`` on ``attribute``.

    Computes ``measure(data) - sum_v (W_v / W) * measure(data_v)`` where
    ``data_v`` is the subset of rows with ``row[attribute] == v`` for each
    ``v`` in ``vals``, and ``W`` / ``W_v`` come from ``get_total``.

    Parameters
    ----------
    data : rows accepted by ``entropy`` / ``gini_index`` / ``ME``.
    gain_type : 0 selects ``entropy``, 1 ``gini_index``, 2 ``ME``.
    attribute : key used to index each row when partitioning.
    vals : iterable of candidate values of ``attribute``.

    Returns
    -------
    The gain as a number; ``0.0`` when ``get_total(data)`` is zero. With empty
    ``vals`` nothing is subtracted, so the result is ``measure(data)``.

    Raises
    ------
    ValueError
        If ``gain_type`` is not 0, 1 or 2.
    """
    measures = {0: entropy, 1: gini_index, 2: ME}
    if gain_type not in measures:
        raise ValueError(
            "gain_type must be 0 (entropy), 1 (gini index) or 2 (majority error)"
        )
    measure = measures[gain_type]

    # Loop-invariant, so computed once.
    # NOTE(review): get_total is defined outside this cell; presumably the
    # summed row weights -- confirm.
    total = get_total(data)
    if total == 0:
        # Nothing to split; also avoids dividing by zero below.
        return 0.0

    remainder = 0.0
    for val in vals:
        sub_set = set_subset(data, attribute, val)
        if len(sub_set) == 0:
            # The measures return 0 for an empty subset, so it adds nothing.
            continue
        # Weight each subset's OWN impurity by its share of the total mass.
        remainder += (get_total(sub_set) / total) * measure(sub_set)

    return measure(data) - remainder

In [8]:
def select_feature(data, gain_type, attributes):
    """Return the attribute name with the highest ``info_gain`` on ``data``.

    Parameters
    ----------
    data : rows forwarded to ``info_gain``.
    gain_type : impurity selector forwarded to ``info_gain``.
    attributes : dict mapping attribute name -> its candidate values.

    Raises
    ------
    ValueError
        If ``attributes`` is empty (there is nothing to choose from).
    """
    if not attributes:
        raise ValueError("select_feature requires at least one attribute")

    gains = {
        name: info_gain(data, gain_type, name, values)
        for name, values in attributes.items()
    }

    # Evaluated once, after all gains are known. max() keeps the first
    # maximal key, so ties go to the attribute that appears first.
    return max(gains, key=gains.get)

In [9]:
def majority_label(data):
    """Return the ``'y'`` label with the largest summed ``'weights'`` in ``data``.

    When several labels share the largest total, the one encountered first in
    ``data`` is returned. Empty ``data`` makes ``max`` raise ``ValueError``.
    """
    weight_by_label = {}
    for record in data:
        key = record['y']
        weight_by_label[key] = weight_by_label.get(key, 0.0) + record['weights']

    return max(weight_by_label, key=weight_by_label.get)

In [34]:
def ID3(data, gain_type, attributes, labels, max_depth, depth):
    """Recursively build a decision tree for ``data`` and return its root ``Node``.

    Parameters
    ----------
    data : rows indexable by attribute name, each with ``'y'`` and
        ``'weights'`` entries.
    gain_type : impurity selector forwarded to ``select_feature``.
    attributes : dict mapping attribute name -> iterable of its possible values.
    labels : collection of the distinct labels present in ``data``. It is only
        read, never modified.
    max_depth : depth at which splitting stops.
    depth : depth of the node being built; recursion passes ``depth + 1``.

    Returns
    -------
    Node
        A leaf labelled with a class label, or an internal node labelled with
        the chosen attribute and one child per value in
        ``attributes[chosen]``.
    """
    # Out of attributes or depth budget: predict the weighted majority label.
    if len(attributes) == 0 or depth == max_depth:
        return Node(majority_label(data))

    # Pure node. Read the single label without pop() so the caller's
    # collection is left intact.
    if len(labels) == 1:
        return Node(next(iter(labels)))

    max_attr = select_feature(data, gain_type, attributes)
    root = Node(max_attr)

    # Every child sees the same remaining attributes, so build the reduced
    # mapping once. Nothing in the recursion mutates it.
    sub_attributes = {
        name: values for name, values in attributes.items() if name != max_attr
    }

    for v in attributes[max_attr]:
        sub_set = set_subset(data, max_attr, v)

        if len(sub_set) == 0:
            # No rows take this value: fall back to the parent's majority label.
            root.children[v] = Node(majority_label(data))
            continue

        # Distinct labels that actually occur in this branch.
        sub_labels = {row['y'] for row in sub_set}
        root.children[v] = ID3(
            sub_set, gain_type, sub_attributes, sub_labels, max_depth, depth + 1
        )

    return root