In [1]:
import numpy as np
import pandas as pd

In [None]:
df = [[2.771244718, 1.784783929, 0],
      [1.728571309, 1.169761413, 0],
      [3.678319846, 2.81281357, 0],
      [3.961043357, 2.61995032, 0],
      [2.999208922, 2.209014212, 0],
      [7.497545867, 3.162953546, 1],
      [9.00220326 , 3.339047188, 1],
      [7.444542326, 0.476683375, 1],
      [10.12493903, 3.234550982, 1],
      [6.642287351, 3.319983761, 1]]

In [63]:
c1 = ['Y', 'Y', 'N', 'N', 'Y', 'Y', 'N']
c2 = ['Y', 'N', 'Y', 'Y', 'Y', 'N', 'N']
c3 = [7, 12, 18, 35, 38, 50, 83]
targets = [0, 0, 1, 1, 1, 0, 0]

cat_cols = {'c1': c1, 'c2': c2}
num_cols = {'c3': c3}
num_classes = np.unique(targets)
cols = {'c1': c1, 'c2': c2, 'c3': c3}

In [89]:
def gini_impurity(left_node_data, right_node_data, targets1):
    left_splits = [0]*num_classes
    right_splits = [0]*num_classes
    for cls in num_classes:
        for data_id in left_node_data:
            if targets1[data_id]==cls:
                left_splits[cls] += 1
        for data_id in right_node_data:
            if targets1[data_id]==cls:
                right_splits[cls] += 1
    left_impurity = 1
    right_impurity = 1
    for cls in num_classes:
        left_impurity -= (left_splits[cls]/sum(left_splits))**2
        right_impurity -= (right_splits[cls]/sum(right_splits))**2
    total_impurity = ((sum(left_splits)*left_impurity) + (sum(right_splits)*right_impurity))/sum(left_splits+right_splits)
    return(total_impurity, left_splits, right_splits)
    

In [125]:
def get_decision_node(cat_cols1, num_cols1, targets1):
    
    # to find best node
    node_options = dict()
    # cat_cols
    for c_name, c in cat_cols1.items():
        categories = list(np.unique(c))
        node_conditions = categories[:-1] # all but one
        for condition in node_conditions:
            left_node_data_id = [i for i in range(len(c)) if c[i]==condition]
            right_node_data_id = [i for i in range(len(c)) if c[i]!=condition]
            total_impurity, left_splits, right_splits = gini_impurity(left_node_data_id, right_node_data_id, targets1)   
            node_options[(c_name, condition)] = (total_impurity, left_node_data_id, right_node_data_id)
    # num_cols
    for c_name, c in num_cols1.items():
        sorted_c = sorted(c)
        avg_col = [(sorted_c[i]+sorted_c[i+1])/2 for i in range(len(c)-1)]
        for condition in avg_col:
            left_node_data_id = [i for i in range(len(c)) if c[i]<condition]
            right_node_data_id = [i for i in range(len(c)) if c[i]>=condition]
            total_impurity, left_splits, right_splits = gini_impurity(left_node_data_id, right_node_data_id, targets1)
            node_options[(c_name, condition)] = (total_impurity, left_node_data_id, right_node_data_id)
    if len(node_options) != 0:
        decision_node = sorted(node_options.items(), key=lambda item: item[1][0])[0]
        return(decision_node)
    
    

In [140]:
get_decision_node(cat_cols, num_cols, targets)

(('c2', 'N'), (0.21428571428571427, [1, 5, 6], [0, 2, 3, 4]))

In [137]:
tree = {}
n = 0
depth = 0
# add root node 
(c_name, cond_val), (_, left_node_data_id, right_node_data_id) = get_decision_node(cat_cols, num_cols, targets)
tree[f'node{n}'] = {'c_name':c_name, 'cond_val':cond_val, 'depth':depth, 'leftNode':False, 'rightNode':False, 'parent': None}
parent = f'node{n}'
n += 1

In [142]:
def get_data_subset(node_data_id):
    cat_cols1 = {}
    num_cols1 = {}
    for c_name, c in cat_cols.items():
        cat_cols1[c_name] = [c[i] for i in node_data_id]
    for c_name, c in num_cols.items():
        num_cols1[c_name] = [c[i] for i in node_data_id]
    targets1 = [targets[i] for i in node_data_id]
    return(cat_cols1, num_cols1, targets1)

In [None]:
def add_node(parent, left_node_data_id, right_node_data_id)

In [138]:
# while depth < 3:
l_cat_cols1 = {}
l_num_cols1 = {}
for c_name, c in cat_cols.items():
    l_cat_cols1[c_name] = [c[i] for i in left_node_data_id]
for c_name, c in num_cols.items():
    l_num_cols1[c_name] = [c[i] for i in left_node_data_id]
l_targets1 = [targets[i] for i in left_node_data_id]
(c_name, cond_val), (_, left_node_data_id, right_node_data_id) = get_decision_node(l_cat_cols1, l_num_cols1, l_targets1)
# node_n = f'node{n}' if f'node{n}' not in tree else f'node{n}'
tree[f'node{n}'] = {'c_name':c_name, 'cond_val':cond_val, 'depth':depth, 'leftNode':True, 'rightNode':False, 'parent': parent}
n += 1
# right_node_data_id = node[1][2]
# r_cat_cols1 = {}
# r_num_cols1 = {}
# for c_name, c in cat_cols.items():
#     r_cat_cols1[c_name] = [c[i] for i in right_node_data_id]
# fo c_name, c in num_cols.items():
#     r_num_cols1[c_name] = [c[i] for i in right_node_data_id]
# r_targets1 = [targets[i] for i in right_node_data_id]

In [139]:
tree

{'node0': {'c_name': 'c2',
  'cond_val': 'N',
  'depth': 0,
  'leftNode': False,
  'rightNode': False,
  'parent': None},
 'node1': {'c_name': 'c1',
  'cond_val': 'N',
  'depth': 0,
  'leftNode': True,
  'rightNode': False,
  'parent': 'node0'}}

In [122]:
def add_node(tree, depth, max_depth, n, cat_cols1, num_cols1, targets1):
    
    node = get_decision_node(cat_cols1, num_cols1, targets1)
    if node:
    
    #     tree = {} 
    #     depth = 0
    #     n = 1
        if depth >= max_depth:
            
            tree[f'n{n}'] = {'c_name': node[0][0],
                            'condition': node[0][1],
                            'depth': depth,
                            'leftChild': None,
                            'rightChild': None
                            }
            return(tree)
        else:
            left_node_data_id = node[1][1]
            l_cat_cols1 = {}
            l_num_cols1 = {}
            for c_name, c in cat_cols.items():
                l_cat_cols1[c_name] = [c[i] for i in left_node_data_id]
            for c_name, c in num_cols.items():
                l_num_cols1[c_name] = [c[i] for i in left_node_data_id]
            l_targets1 = [targets[i] for i in left_node_data_id]
            
            right_node_data_id = node[1][2]
            r_cat_cols1 = {}
            r_num_cols1 = {}
            for c_name, c in cat_cols.items():
                r_cat_cols1[c_name] = [c[i] for i in right_node_data_id]
            for c_name, c in num_cols.items():
                r_num_cols1[c_name] = [c[i] for i in right_node_data_id]
            r_targets1 = [targets[i] for i in right_node_data_id]
            
            tree[f'n{n}'] = {'c_name': node[0][0],
                            'condition': node[0][1],
                            'depth': depth,
                            'leftChild': add_node(tree, depth+1, max_depth, n+1, l_cat_cols1, l_num_cols1, l_targets1),
                            'rightChild': add_node(tree, depth+1, max_depth, n+2, r_cat_cols1, r_num_cols1, r_targets1)
                            }
    else:
        return(tree)

In [123]:
tree = add_node({}, 0, 2, 0, cat_cols, num_cols, targets)

In [124]:
tree

In [102]:
get_decision_node(cat_cols, num_cols, targets)

(('c2', 'N'), (0.21428571428571427, [1, 5, 6], [0, 2, 3, 4]))

In [103]:
left_node_data_id = [1, 5, 6]
cat_cols1 = {}
num_cols1 = {}
for c_name, c in cat_cols.items():
    cat_cols1[c_name] = [c[i] for i in left_node_data_id]
for c_name, c in num_cols.items():
    num_cols1[c_name] = [c[i] for i in left_node_data_id]
targets1 = [targets[i] for i in left_node_data_id]
get_decision_node(cat_cols1, num_cols1, targets1)

(('c1', 'N'), (0.0, [2], [0, 1]))

In [100]:
right_node_data_id = [0, 2, 3, 4]
cat_cols1 = {}
num_cols1 = {}
for c_name, c in cat_cols.items():
    cat_cols1[c_name] = [c[i] for i in right_node_data_id]
for c_name, c in num_cols.items():
    num_cols1[c_name] = [c[i] for i in right_node_data_id]
targets1 = [targets[i] for i in right_node_data_id]
get_decision_node(cat_cols1, num_cols1, targets1)

(('c3', 12.5), (0.0, [0], [1, 2, 3], array([1, 0]), array([0, 3])))

In [45]:
node_options

{('c1', 'N'): 0.4047619047619047,
 ('c2', 9.5): 0.42857142857142855,
 ('c2', 15.0): 0.34285714285714286,
 ('c2', 26.5): 0.4761904761904762,
 ('c2', 36.5): 0.4761904761904762,
 ('c2', 44.0): 0.34285714285714286,
 ('c2', 66.5): 0.42857142857142855}