#Binary Decision Tree

In [0]:
import numpy as np
import csv
from math import log

## Tree Class and Methods

In [0]:
class BDTree:
  def __init__(self, sample_filename, lbl_list):
    self.data = self.parse_data(sample_filename, lbl_list)
    self.split_var = lbl_list[-1]
    self.data_attrs = lbl_list[:-1]
    self.attr_vars = self.get_key_vars()
    self.top_node = Node()
    self.depth = None
    
  def parse_data(self, fname, lbls):
    '''Takes in the filename of the csv and the ordered labels list and gives
        back a list of dictionaries where each dictionary is a datapoint'''
    data = open(fname)
    data_reader = csv.reader(data)
    data_matrix = []
    for row in data_reader:
      new_dict = {}
      for i in range(len(row)):
        row[i] = row[i].strip()
        if row[i] == 'Yes':
          row[i] = True
        elif row[i] == 'No':
          row[i] = False
        new_dict[lbls[i]] = row[i]
      data_matrix.append(new_dict)
    return data_matrix
  
  def get_key_vars(self):
    '''This function creates a dictionary that holds values for all the possible
        labels that are present in this dataset.'''
    new_dict = {}
    for i in range(len(self.data_attrs)):
      k = self.data_attrs[i]
      new_dict[k] = []
      for j in range(len(self.data)):
        new_dict[k].append(self.data[j][k])
      new_dict[k] = list(set(new_dict[k]))
    return new_dict
      
  
  def node_entropy(self, samples):
    '''Calculates the impurity at a node given the samples coming in aftera 
        split.'''
    T, F = 0, 0
    n = len(samples)
    
    if n == 0:
      return 0
    
    for i in range(n):
      if (samples[i][self.split_var] == True):
        T += 1
      else:
        F += 1
    P_T, P_F = T/n, F/n
    
    # cast these to 1 if 0 to avoid NaN error in log function.
    if P_T == 0:
      P_T = 1
    if P_F == 0:
      P_F = 1
    
    return -sum([P_T*log(P_T,2), P_F*log(P_F,2)]) # entropy
  
  def split_entropy(self, samples, split_attr):
    '''Calculates the entropy of a given split with the samples at a certain
         node.'''
    attr_vals = self.attr_vars[split_attr] # list values that this attribute can take on.
    n = len(samples)
    a = len(attr_vals)
    
    # split samples into separate nodes:
    sample_splits = self.split_data(samples, split_attr)
  
    # find entropy for each new node:
    entropy_terms = [] # will be same size and correspond to attr_vals
    for i in range(a):
      node_sample_num = len(sample_splits[attr_vals[i]])
      e = self.node_entropy(sample_splits[attr_vals[i]])      
      entropy_terms.append((node_sample_num/n) * e)
    return sum(entropy_terms)
  
  def min_split_var(self, attrs, samples):
    '''Finds the variable to split the data on that yields the lowest resulting 
        entropy.'''
    vals = np.zeros(len(attrs))
    for i in range(len(attrs)):
      vals[i] = self.split_entropy(samples, attrs[i])
    idx = np.argmin(vals)
    return attrs[idx]

  def build(self, data, prev_node, attr_list, prev_branch_val= None):
    '''Builds the Tree structure given a predefined maximum depth in a recursive
        fashion using node objects.'''
    if self.depth == None: 
      print("please set depth first")
      return
    
    # hit point where we have used all attributes: (LEAF)
    if ((len(attr_list) == 0) or (self.depth <= self.node_depth(prev_node))): # then this is a leaf node
      # now figure out if it is a True or False leaf node or Neither.
      if(len(data) == 0): # then just initialize leaf node as None
        new_node = Node(parent = prev_node, branch_val = prev_branch_val)
        prev_node.children.append(new_node)
        return
      elif(data[0][self.split_var] == True):
        new_node = Node(parent = prev_node, branch_val = prev_branch_val, leaf = True)
        prev_node.children.append(new_node)
        return
      elif(data[0][self.split_var] == False):
        new_node = Node(parent = prev_node, branch_val = prev_branch_val, leaf = False)
        prev_node.children.append(new_node)
        return
    
    # hit point where we ran out of samples: (LEAF)
    elif(len(data) == 0): # then this is also a leaf node
      new_node = Node(parent = prev_node, branch_val = prev_branch_val)
      prev_node.children.append(new_node)
      return
    
    # otherwise we continue with the recursion:
    
    # find minimum split given current data and attr_list
  
    attr = self.min_split_var(attr_list, data)
    # remove it from attribute list
    attr_list.remove(attr)
    # create new node for this split
    new_node = Node(parent = prev_node, branch_val = prev_branch_val, var = attr)
    prev_node.children.append(new_node)
    # split the data
    data_splits = self.split_data(data, attr)
    # recurse for each branch
    branches = self.attr_vars[attr]
    for i in range(len(branches)):
      self.build(data_splits[branches[i]], new_node, attr_list, branches[i])
    
  def split_data(self, data, attr):
    '''Returns the incoming data split into two lists given a certain attribute.'''
    attr_vals = self.attr_vars[attr] # list values that this attribute can take on.
    n = len(data)
    a = len(attr_vals)
    data_splits = {}
    for i in range(a):
      data_splits[attr_vals[i]] = [] # create empty lists for each branch
    for i in range(n):
      data_splits[data[i][attr]].append(data[i]) # put each sample into correct list.
        
    return data_splits
    
  def node_depth(self, node):
    '''Returns depth of a node.'''
    d = 0
    while(node.parent != None):
      d += 1
      node = node.parent
      
    return d
  
  def visualize(self, node=None, dictionary = {}):
    '''Returns a dictionary of node descriptions where each key is the depth in
        the tree.'''
    if node == None:
      node = self.top_node
      p_var = None
    else: 
      p_var = node.parent.var
      
    d = self.node_depth(node)
    
    if d in dictionary.keys():
      dictionary[d].append((node.var, 'From Node: ', p_var, 'Through Branch: ', node.branch_val, 'Leaf Value: ', node.leaf)) 
    else:
      dictionary[d] = [(node.var, 'From Node: ', p_var, 'Through Branch: ', node.branch_val, 'Leaf Value: ', node.leaf)]
    
    for i in range(len(node.children)):
      self.visualize(node.children[i], dictionary)
      
    return dictionary
  
  def evaluate(self, sample):
    '''Uses the generated tree structure to classify a sample datapoint.'''
    current = self.top_node.children[0]
    decision = None
    not_leaf = True
    while not_leaf:
      # check if current node is a leaf:
      if not current.children: # will be true when list is empty (leaf node)
        decision = current.leaf
        not_leaf = False
        
      # otherwise figure out which branch to go down:
      else:
        for i in range(len(current.children)):
          if current.children[i].branch_val == sample[current.var]:
            current = current.children[i]
            break
    return decision
             
    
class Node:
  '''A class representing the nodes of a decision tree.'''
  def __init__(self, parent = None, branch_val = None, var = None, leaf = None):
    self.parent = parent
    self.branch_val = branch_val
    self.var = var
    self.children = []
    self.leaf = leaf

## Instantiate and Build Tree

In [0]:
Tree = BDTree('restaurant.csv', ['alt', 'bar', 'fri', 'hun', 'pat', 'price', 'rain', 'res', 'type', 'est', 'wait'])

In [0]:
Tree.depth = 3 # This isn't counting the leaf node depth, so this will give a total depth of 4.
Tree.build(Tree.data, Tree.top_node, Tree.data_attrs)

## Visualize Tree

In [0]:
dictionary = Tree.visualize()
for i in range(1, len(list(dictionary.keys()))): 
  # not showing the placeholder top node. That one is just a pointer.
  print(dictionary[i])

[('pat', 'From Node: ', None, 'Through Branch: ', None, 'Leaf Value: ', None)]
[('alt', 'From Node: ', 'pat', 'Through Branch: ', 'None', 'Leaf Value: ', None), ('hun', 'From Node: ', 'pat', 'Through Branch: ', 'Full', 'Leaf Value: ', None), ('price', 'From Node: ', 'pat', 'Through Branch: ', 'Some', 'Leaf Value: ', None)]
[('bar', 'From Node: ', 'alt', 'Through Branch: ', False, 'Leaf Value: ', None), (None, 'From Node: ', 'alt', 'Through Branch: ', True, 'Leaf Value: ', None), ('fri', 'From Node: ', 'hun', 'Through Branch: ', False, 'Leaf Value: ', None), ('type', 'From Node: ', 'hun', 'Through Branch: ', True, 'Leaf Value: ', None), ('rain', 'From Node: ', 'price', 'Through Branch: ', '$', 'Leaf Value: ', None), ('res', 'From Node: ', 'price', 'Through Branch: ', '$$$', 'Leaf Value: ', None), ('est', 'From Node: ', 'price', 'Through Branch: ', '$$', 'Leaf Value: ', None)]
[(None, 'From Node: ', 'bar', 'Through Branch: ', False, 'Leaf Value: ', False), (None, 'From Node: ', 'bar', 'T

## Test Tree on Training Set

In [0]:
count = 0
for i in range(len(Tree.data)):
  actual = Tree.data[i]['wait']
  pred = Tree.evaluate(Tree.data[i])
  if actual == pred:
    count += 1
print("Accuracy of Tree on Training Data: ", count/len(Tree.data))

Accuracy of Tree on Training Data:  0.9166666666666666
