In [1]:
import math
import copy
import numpy as np 
import statistics
import matplotlib.pyplot as plt

In [2]:
class Node:
    """A single node of a decision tree.

    Attributes are set in ``__init__``:
      * ``label``    -- the value passed by the caller, stored unchanged.
      * ``children`` -- a dict, initially empty, that callers populate with
        child nodes.
    """

    def __init__(self, label):
        """Store ``label`` and start with no children."""
        self.label = label
        self.children = {}

    def isLeaf(self):
        """Return True when no child has been attached to this node."""
        return not self.children

In [3]:
def entropy(data):
    """Weighted Shannon entropy (base 2) of the ``'y'`` labels in ``data``.

    Each row contributes ``row['weights']`` to the mass of its label
    ``row['y']``; the entropy is computed over the resulting label
    proportions.

    Args:
        data: sized iterable of rows, each indexable by ``'y'`` and
            ``'weights'``.

    Returns:
        ``0`` when ``data`` is empty or its total weight is zero, otherwise
        ``-sum(p * log2(p))`` over the labels with non-zero weight.
    """
    if len(data) == 0:
        return 0

    counts = dict()
    norm = 0.0

    # Accumulate the per-label mass and the total mass in a single pass.
    for row in data:
        label = row['y']
        weight = row['weights']
        counts[label] = counts.get(label, 0.0) + weight
        norm += weight

    # No mass at all: there is no distribution to measure.
    if norm == 0:
        return 0

    h = 0.0
    for count in counts.values():
        # A zero-weight label contributes nothing (0 * log 0 is taken as 0);
        # skipping it also avoids math.log2(0) raising ValueError.
        if count == 0:
            continue
        ratio = count / norm
        h -= math.log2(ratio) * ratio

    return h

def gini_index(data):
    """Weighted Gini impurity of the ``'y'`` labels in ``data``.

    Each row contributes ``row['weights']`` to the mass of its label
    ``row['y']``. The impurity is ``1 - sum(p ** 2)`` over the label
    proportions, so a pure set scores 0 and mixed sets score higher.

    Args:
        data: sized iterable of rows, each indexable by ``'y'`` and
            ``'weights'``.

    Returns:
        ``0`` when ``data`` is empty or its total weight is zero, otherwise
        the Gini impurity as a float.
    """
    if len(data) == 0:
        return 0

    counts = dict()
    norm = 0.0

    # Accumulate the per-label mass and the total mass in a single pass.
    for row in data:
        label = row['y']
        weight = row['weights']
        counts[label] = counts.get(label, 0.0) + weight
        norm += weight

    # No mass at all: avoid dividing by zero below.
    if norm == 0:
        return 0

    squares = 0.0
    for count in counts.values():
        ratio = count / float(norm)
        squares += ratio ** 2

    # Impurity is the complement of the summed squared proportions.
    return 1 - squares

def ME(data):
    """Weighted majority error of the ``'y'`` labels in ``data``.

    Sums ``row['weights']`` per label, divides each sum by the total weight
    reported by ``_weighted(data)``, and returns one minus the largest
    resulting share. Empty ``data`` yields ``0``.
    """
    if not len(data):
        return 0

    weight_by_label = {}
    for row in data:
        key = row['y']
        weight_by_label.setdefault(key, 0.0)
        weight_by_label[key] += row['weights']

    total = _weighted(data)
    top_share = 0.0
    for mass in weight_by_label.values():
        share = mass / float(total)
        top_share = max(top_share, share)

    return 1 - top_share



def info_gain(data,gain_type,attribute, vals):
    """Weighted impurity reduction obtained by splitting ``data`` on ``attribute``.

    The gain is the impurity of ``data`` minus the weight-averaged impurity
    of the subsets ``set_subdata(data, attribute, val)`` for each ``val``.

    Args:
        data: rows accepted by the impurity functions, ``set_subdata`` and
            ``_weighted``.
        gain_type: impurity measure selector -- ``0`` for ``entropy``,
            ``1`` for ``gini_index``, ``2`` for ``ME``.
        attribute: key used to partition the rows.
        vals: iterable of values of ``attribute`` to split on.

    Returns:
        The gain as a number; ``0.0`` when the total weight of ``data`` is
        zero.

    Raises:
        ValueError: if ``gain_type`` is not 0, 1 or 2.
    """
    if gain_type == 0:
        impurity = entropy
    elif gain_type == 1:
        impurity = gini_index
    elif gain_type == 2:
        impurity = ME
    else:
        raise ValueError(
            "gain_type must be 0 (entropy), 1 (gini index) or 2 (majority error)"
        )

    # Loop-invariant: the total weight of the parent set.
    total = _weighted(data)
    if total == 0:
        # Nothing to split; also avoids dividing by zero below.
        return 0.0

    h = impurity(data)
    new_h = 0.0
    for val in vals:
        sub_data = set_subdata(data, attribute, val)
        ratio = _weighted(sub_data) / float(total)
        # Weight the impurity of the SUBSET, not the parent's impurity;
        # using the parent's value makes every attribute's gain ~0.
        new_h += ratio * impurity(sub_data)

    return h - new_h

In [4]:
def majority_label(data):
    """Return the ``'y'`` label carrying the largest summed ``'weights'``.

    When several labels share the largest sum, the one encountered first in
    ``data`` is returned. ``max`` raises ``ValueError`` if ``data`` has no
    rows.
    """
    totals = {}
    for row in data:
        key = row['y']
        totals.setdefault(key, 0.0)
        totals[key] += row['weights']

    return max(totals, key=totals.__getitem__)

In [5]:
def WID3(data,gain_type, attributes, labels, max_depth,depth):
    """Recursively build a weighted decision tree and return its root ``Node``.

    Args:
        data: rows indexable by attribute name, ``'y'`` and ``'weights'``.
        gain_type: impurity selector forwarded to ``select_feature``.
        attributes: dict mapping each still-available attribute name to the
            iterable of values to branch on.
        labels: collection of labels for ``data``; only its length is
            inspected (a length of 1 stops the recursion).
        max_depth: recursion stops when ``depth == max_depth``.
            NOTE(review): the check is an equality, so a ``depth`` that
            starts above ``max_depth`` never triggers it.
        depth: depth of the node being built.

    Returns:
        A leaf ``Node`` labelled with ``majority_label(data)`` when a stop
        condition holds; otherwise a ``Node`` labelled with the selected
        attribute whose ``children`` map each attribute value to a subtree.
    """
    if (len(labels) == 1) or (len(attributes) == 0) or depth==max_depth:
        return Node(majority_label(data))

    # Pick the attribute to split on at this node.
    max_attr = select_feature(data,gain_type,attributes)
    root = Node(max_attr)

    # Attributes left for the subtrees. Nothing below mutates this dict or
    # its values, so one shallow filtered copy can be shared by every child
    # instead of deep-copying the whole mapping per branch.
    sub_attributes = {
        name: values for name, values in attributes.items() if name != max_attr
    }

    # One branch per value of the chosen attribute.
    for v in attributes[max_attr]:
        sub_data = set_subdata(data, max_attr, v)

        if len(sub_data) == 0:
            # No rows take this value: fall back to the parent's majority.
            root.children[v] = Node(majority_label(data))
        else:
            # Distinct labels present in the subset (drives the stop test).
            sub_labels = {row['y'] for row in sub_data}
            root.children[v] = WID3(
                sub_data, gain_type, sub_attributes, sub_labels, max_depth, depth+1
            )

    return root

In [6]:
def select_feature(data,gain_type, attributes):
    """Return the attribute name with the largest ``info_gain`` on ``data``.

    Args:
        data: rows forwarded to ``info_gain``.
        gain_type: impurity selector forwarded to ``info_gain``.
        attributes: dict mapping attribute name to its iterable of values.

    Returns:
        The key of ``attributes`` with the highest gain; on ties, the key
        that comes first in ``attributes``' iteration order.

    Raises:
        ValueError: if ``attributes`` is empty.
    """
    if len(attributes) == 0:
        raise ValueError("select_feature requires at least one attribute")

    gain_x = {
        ln: info_gain(data, gain_type, ln, lv) for ln, lv in attributes.items()
    }

    # Compute the arg-max once, after every gain is known.
    return max(gain_x, key=gain_x.get)

In [7]:
def set_subdata(data, attribute, val):
    """Return a new list of the rows in ``data`` whose ``row[attribute] == val``.

    Rows are kept in their original order and are not copied.
    """
    return [row for row in data if row[attribute] == val]

In [8]:
def _weighted(data):
    
    length = 0.0
    for row in data:
        length += row['weights']
    
    return length

def get_label(row, root):
    """Walk the tree from ``root`` following ``row`` and return the leaf label.

    At each non-leaf node, the node's ``label`` is used as a key into ``row``
    and the resulting value selects the next node from ``children``. The
    lookup raises ``KeyError`` if that value has no child.
    """
    node = root
    while True:
        if node.isLeaf():
            return node.label
        branch = row[node.label]
        node = node.children[branch]

def weighted_error(data, root):
    """Sum ``row['weights']`` over the rows that the tree at ``root`` mislabels.

    A row counts as mislabelled when ``get_label(row, root)`` differs from
    ``row['y']``. Returns 0.0 when nothing is mislabelled or ``data`` is
    empty.
    """
    missed = 0.0
    for row in data:
        predicted = get_label(row, root)
        if predicted != row['y']:
            missed += row['weights']

    return missed