In [12]:
import numpy as np
import pandas as pd
eps = np.finfo(float).eps
from numpy import log2 as log

In [13]:
# Read the first worksheet (sheet_name=0) of the workbook into `train_data`.
# NOTE(review): the tree printed further down has keys such as 'price   ' and
# 'airbag ', so some header cells appear to carry trailing spaces - confirm, and
# if the headers get normalized, do it for both sheets so lookups still match.
train_data = pd.read_excel('dataset for part 1.xlsx',sheet_name=0)

In [14]:
# Bare expression: the notebook renders the frame as this cell's output.
train_data

Unnamed: 0,price,maintenance,capacity,airbag,profitable
0,low,low,2,no,yes
1,low,med,4,yes,no
2,low,high,4,no,no
3,med,med,4,no,no
4,med,med,4,yes,yes
5,med,high,2,yes,no
6,high,med,4,yes,yes
7,high,high,2,yes,no
8,high,high,5,yes,yes


In [15]:
# Read the second worksheet (sheet_name=1) of the same workbook into `test_data`.
test_data = pd.read_excel('dataset for part 1.xlsx',sheet_name=1)

In [16]:
# Bare expression: the notebook renders the frame as this cell's output.
test_data

Unnamed: 0,price,maintenance,capacity,airbag,profitable
0,med,high,5,no,yes
1,low,low,4,no,yes


In [17]:
def entropy_parent(df, target='profitable'):
    """Return the Shannon entropy (in bits) of the class column of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame that contains the class column.
    target : str, default 'profitable'
        Name of the class column.

    Returns
    -------
    float
        ``-sum(p * log2(p))`` over the classes that occur in ``df[target]``;
        ``0.0`` when the frame has no rows.
    """
    labels = df[target]
    if len(labels) == 0:
        return 0.0
    # Count every class once instead of re-running value_counts() per class.
    counts = labels.value_counts().to_numpy()
    # Classes with a zero count carry no information and would yield log2(0).
    probs = counts[counts > 0] / len(labels)
    # "0.0 - x" (rather than "-x") keeps a pure node at +0.0 instead of -0.0.
    return float(0.0 - np.sum(probs * np.log2(probs)))

In [18]:
def entropy_split(df, attribute, target='profitable'):
    """Return the weighted class entropy (in bits) after splitting on ``attribute``.

    For every distinct value of ``df[attribute]`` the entropy of the class
    column is computed over the matching rows, and the per-value entropies are
    averaged with weights equal to each value's share of the rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame that contains ``attribute`` and the class column.
    attribute : str
        Column to split on.
    target : str, default 'profitable'
        Name of the class column.

    Returns
    -------
    float
        The conditional entropy; ``0.0`` when the frame has no rows.
    """
    total = len(df)
    if total == 0:
        return 0.0

    weighted_entropy = 0.0
    for value in df[attribute].unique():
        # One aligned mask via .loc instead of chaining two boolean indexers
        # whose indexes do not match.
        labels = df.loc[df[attribute] == value, target]
        size = len(labels)
        if size == 0:
            # unique() can yield NaN, which never compares equal to itself.
            continue
        counts = labels.value_counts().to_numpy()
        # Dropping zero counts avoids log2(0), so no epsilon fudge is needed
        # and the result is exact (e.g. exactly 0.0 for a perfect split).
        probs = counts[counts > 0] / size
        branch_entropy = 0.0 - np.sum(probs * np.log2(probs))
        weighted_entropy += (size / total) * branch_entropy
    return float(weighted_entropy)

In [19]:
def find_best_split(df):
    """Return the label of the column with the highest information gain.

    Every column except the last one is treated as a candidate attribute, so
    the class column is expected to be the last column of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Attribute columns followed by the class column.

    Returns
    -------
    The column label whose split yields the largest
    ``entropy_parent(df) - entropy_split(df, column)``; on ties the first such
    column wins (``np.argmax`` semantics).

    Raises
    ------
    ValueError
        If ``df`` has no candidate attribute column.
    """
    candidates = df.keys()[:-1]
    if len(candidates) == 0:
        raise ValueError(
            'find_best_split needs at least one attribute column before the class column'
        )
    # The parent entropy does not depend on the attribute: compute it once.
    parent_entropy = entropy_parent(df)
    info_gains = [parent_entropy - entropy_split(df, key) for key in candidates]
    return candidates[np.argmax(info_gains)]

In [20]:
def get_subtable(df,node,value):
    """Return the rows of ``df`` whose ``node`` column equals ``value``.

    The result is a new frame whose index is renumbered from 0; ``df`` itself
    is left untouched.
    """
    matches = df[node] == value
    selected = df.loc[matches]
    return selected.reset_index(drop=True)

In [21]:
def build_tree(df,tree=None):
    """Recursively build a decision tree as nested dicts.

    The result has the shape ``{attribute: {value: subtree_or_label, ...}}``,
    where each subtree has the same shape and each leaf is a class label taken
    from the 'profitable' column.

    Parameters
    ----------
    df : pandas.DataFrame
        Attribute columns followed by the 'profitable' class column (the
        column order is what ``find_best_split`` relies on).
    tree : dict, optional
        Existing dict to fill in; a new one is created when omitted.

    Returns
    -------
    dict
        The (possibly caller-supplied) tree dict.
    """
    node = find_best_split(df)
    attValues = np.unique(df[node])
    if tree is None:
        tree = {}
    # setdefault also covers a caller-supplied dict that lacks the node key.
    tree.setdefault(node, {})
    for value in attValues:
        # The split attribute is constant inside its own subtable, so it is
        # removed before recursing. Without this, rows that agree on every
        # attribute but disagree on the label recurse forever.
        subtable = get_subtable(df,node,value).drop(columns=[node])
        clValue,counts = np.unique(subtable['profitable'],return_counts=True)
        if len(counts)==1:
            # Pure subset: store the single class label as a leaf.
            tree[node][value] = clValue[0]
        elif subtable.shape[1] <= 1:
            # Only the class column is left but the labels still disagree:
            # fall back to the majority label (first one on ties).
            tree[node][value] = clValue[np.argmax(counts)]
        else:
            tree[node][value] = build_tree(subtable)
    return tree

In [22]:
# Build the nested-dict decision tree from the training frame.
tree = build_tree(train_data)

In [23]:
# Pretty-print the nested dict. pprint sorts dict keys by default, so the
# branch order shown here can differ from a plain print() of the same dict.
# NOTE(review): this import sits mid-notebook; consider moving it to the
# imports cell at the top.
import pprint
pprint.pprint(tree)

{'maintenance': {'high': {'capacity': {2: 'no', 4: 'no', 5: 'yes'}},
                 'low': 'yes',
                 'med': {'price   ': {'high': 'yes',
                                      'low': 'no',
                                      'med': {'airbag ': {'no': 'no',
                                                          'yes': 'yes'}}}}}}


In [24]:
# Plain repr of the same dict, with keys in insertion order.
print(tree)

{'maintenance': {'med': {'price   ': {'med': {'airbag ': {'yes': 'yes', 'no': 'no'}}, 'low': 'no', 'high': 'yes'}}, 'low': 'yes', 'high': {'capacity': {2: 'no', 4: 'no', 5: 'yes'}}}}


In [25]:
def print_tree(tree,level):
    """Print a nested-dict decision tree as indented text.

    Parameters
    ----------
    tree : dict or leaf value
        Either ``{attribute: {value: subtree_or_leaf, ...}}`` or a leaf. Only
        the first key of a dict is used as the node's attribute.
    level : int
        Current depth; pass ``0`` for the root. Depth ``n > 0`` is drawn as
        ``n - 1`` tab characters followed by ``'| '``.

    Output format: a leaf prints ``': <label>'`` and ends the line; a dict
    node first ends the current line and then prints one
    ``'<attribute> = <value> '`` entry per branch.
    """
    if not isinstance(tree,dict):
        print(': ' + str(tree))
        return

    # Terminates the "attr = value" text the caller left open on this line.
    print('')
    if not tree:
        # An empty node has no attribute to show.
        return

    attr = list(tree.keys())[0]
    prefix = '\t' * (level - 1) + '| ' if level > 0 else ''
    for branch_value, subtree in tree[attr].items():
        # str(attr) keeps non-string column labels (e.g. integers) printable.
        print(prefix + str(attr) + ' = ' + str(branch_value),end=' ')
        print_tree(subtree,level+1)
        
    

In [26]:
# Render the tree as indented text, starting at depth 0 (the root).
print_tree(tree,0)


maintenance = med 
| price    = med 
	| airbag  = yes : yes
	| airbag  = no : no
| price    = low : no
| price    = high : yes
maintenance = low : yes
maintenance = high 
| capacity = 2 : no
| capacity = 4 : no
| capacity = 5 : yes


In [27]:
def predict(inst,tree):
    """Walk the nested-dict tree for one instance and return the leaf label.

    Parameters
    ----------
    inst : mapping-like
        Anything that supports ``inst[attribute]`` (a dict, a pandas Series
        row, ...).
    tree : dict or leaf value
        ``{attribute: {value: subtree_or_leaf, ...}}``; the first key of each
        dict is the attribute tested at that node. A non-dict value is
        treated as a leaf and returned as-is.

    Returns
    -------
    The label stored at the leaf that the instance reaches.

    Raises
    ------
    KeyError
        If ``inst`` lacks a tested attribute, or its value has no branch.
    ValueError
        If an empty dict is encountered.
    """
    # Descend with a separate cursor so the dict being inspected is never
    # rebound while its keys are in use.
    node = tree
    while isinstance(node,dict):
        if not node:
            raise ValueError('cannot predict from an empty tree node')
        attribute = next(iter(node))
        value = inst[attribute]
        branches = node[attribute]
        if value not in branches:
            raise KeyError(
                f'no branch for {attribute!r} = {value!r}; known values: {list(branches)}'
            )
        node = branches[value]
    return node

In [30]:
# Ground-truth labels taken from the 'profitable' column of the test frame.
# NOTE(review): no visible cell reads `test_y`, and this cell's execution count
# (30) is higher than that of the prediction cell that follows it (28), so the
# saved notebook was not produced by a top-to-bottom run.
# TODO: compare the predictions against `test_y` and re-run the cells in order.
test_y = test_data['profitable']

In [28]:
# Predict each test row: the 'profitable' column is dropped first so only the
# attribute columns are handed to predict(), and each predicted label is printed.
# `index` (the row label from iterrows) is not used inside the loop.
for index,row in test_data.drop('profitable',axis=1).iterrows():
    print(predict(row,tree))

yes
yes
