### Import Python Packages

In [1]:
import numpy as np
import matplotlib.pyplot as plt

### Import the Data

In [2]:
bank_data = np.genfromtxt('./Data/CART/data_banknote_authentication.txt', delimiter = ',')

In [3]:
bank_data.shape

(1372, 5)

### Create Required Funtions

In [4]:
def calc_gini(preds, actuals):
    total_len = len(preds)
    group_ginis = np.array([])
    group_probs = np.array([])
    for group_val in np.unique(preds):
        group_probs = np.append(group_probs, np.sum(preds == group_val) / total_len)
        cat_ginis = np.array([])
        for cat_val in np.unique(preds):
            cat_ginis = np.append(cat_ginis,
                                  np.sum((actuals == cat_val) & (preds == group_val)) / np.sum(preds == group_val))
        group_ginis = np.append(group_ginis, 1 - np.sum(cat_ginis ** 2))
        
    return np.dot(group_probs, group_ginis)

In [5]:
def split_data(feature_index, value, array_to_split):
    left_split_index = np.where(array_to_split[:, feature_index] - value < 0)
    right_split_index = np.where(array_to_split[:, feature_index] - value >= 0)
    left = array_to_split[left_split_index]
    right = array_to_split[right_split_index]
    
    return (left, right)

In [6]:
def get_split_index(feature_vals, split_val):
    return np.where(feature_vals >= split_val)[0]

In [7]:
def get_gini_from_split(data_array, feature_index, split_val, responses):
    splits = get_split_index(data_array[:, feature_index], split_val)
    preds = [1 if np.any(splits == i) else 0 for i in range(len(data_array[:, feature_index]))]
    actuals = responses
    return calc_gini(preds, actuals)

In [32]:
def get_best_split(features, responses):
    best_split = {'feature_index':-1, 'value': 0, 'gini': 1}
    for feature_index in range(features.shape[1]):
        values = features[:, feature_index]
        for value in values:
            gini = get_gini_from_split(features, feature_index, value, responses)
            if gini < best_split['gini']:
                best_split['feature_index'] = feature_index
                best_split['value'] = value
                best_split['gini'] = gini
                
    return best_split

In [76]:
def get_terminals(data_array, feature_index, value, responses):
    split_index = get_split_index(data_array[:, feature_index], value)
    return (np.rint(np.mean([responses[i] for i in range(responses.shape[0]) if not np.any(split_index == i)])),
            np.rint(np.mean(responses[split_index])))

In [82]:
features = np.array([[2.771244718,1.784783929],
                    [1.728571309,1.169761413],
                    [3.678319846,2.81281357],
                    [3.961043357,2.61995032],
                    [2.999208922,2.209014212],
                    [7.497545867,3.162953546],
                    [9.00220326,3.339047188],
                    [7.444542326,0.476683375],
                    [10.12493903,3.234550982],
                    [6.642287351,3.319983761]])

actuals = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

tree = {'indexes': np.zeros(actuals.shape), 'splits': []}

# {'feature_index': 0, 'value': 6.642287351, 'gini': 0.0}
split1 = get_best_split(features, actuals)

get_terminals(features, split1['feature_index'], split1['value'], actuals)
tree

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])