# Decision Trees



In [4]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.scipy import stats
from random import seed, randrange

Algorithm: Classification and Regression Tree (CART)
Input: Training dataset, Maximum tree depth (max_depth), Minimum node size (min_size)
Output: Decision Tree

Procedure:

1. Start with the entire training dataset.

2. Determine the best feature and value to split the dataset on:
   - For each feature in the dataset:
     - For each unique value of that feature:
       - Split the dataset into two groups based on whether their value for the feature is less than or equal to the value.
       - Calculate the cost (such as Gini impurity for classification or sum of squared residuals for regression) of this split.
       - Keep track of the feature and value that produces the lowest cost.

3. Create a node in the tree representing this decision (to split on the best feature and value).

4. Recursively apply this process to each of the two groups of data created by the split. Each group creates a new branch in the tree:
   - If a group is pure (all the outputs are the same) or if it is smaller than min_size or if the tree depth is equal to max_depth, create a leaf node. The prediction of the leaf node is the most common output in the group (for classification) or the mean output (for regression).
   - Otherwise, repeat from step 2 with the current group of data.

5. Return the tree. For a new instance, to make a prediction, start at the root of the tree and follow the branches based on the instance's features, until a leaf node is reached. The prediction is the value associated with the leaf node.


In [6]:
# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
    # count all samples at split point
    n_instances = sum([len(group) for group in groups])
    # sum weighted Gini index for each group
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # avoid divide by zero
        if size == 0:
            continue
        score = 0.0
        # score the group based on the score for each class
        for class_val in classes:
            p = (group[:, -1] == class_val).sum() / size
            score += p * p
        # weight the group score by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

# Select the best split point for a dataset
def get_split(dataset, n_features):
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    features = list()
    while len(features) < n_features:
        index = randrange(len(dataset[0])-1)
        if index not in features:
            features.append(index)
    for index in features:
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}

# Create a terminal node value
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, n_features, depth):
    left, right = node['groups']
    del(node['groups'])
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left, n_features)
        split(node['left'], max_depth, min_size, n_features, depth+1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right, n_features)
        split(node['right'], max_depth, min_size, n_features, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size, n_features):
    root = get_split(train, n_features)
    split(root, max_depth, min_size, n_features, 1)
    return root

# Make a prediction with a decision tree
def predict(node, row):
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']

# Create a random subsample from the dataset with replacement
def subsample(dataset, ratio):
    sample = list()
    n_sample = round(len(dataset) * ratio)
    while len(sample) < n_sample:
        index = randrange(len(dataset))
        sample.append(dataset[index])
    return sample


In [7]:
# Test CART on Banknote dataset
def test_CART():
    seed(1)
    # create a small toy dataset
    dataset = [[2.771244718,1.784783929,0],
    [1.728571309,1.169761413,0],
    [3.678319846,2.81281357,0],
    [3.961043357,2.61995032,0],
    [2.999208922,2.209014212,0],
    [7.497545867,3.162953546,1],
    [9.00220326,3.339047188,1],
    [7.444542326,0.476683375,1],
    [10.12493903,3.234550982,1],
    [6.642287351,3.319983761,1]]
    # set parameters
    max_depth = 2
    min_size = 1
    # build the tree
    tree = build_tree(dataset, max_depth, min_size, n_features=2)
    print('Tree:', tree)
    # make predictions
    for row in dataset:
        prediction = predict(tree, row)
        print('Predicted=%d, Actual=%d' % (prediction, row[-1]))

test_CART()


TypeError: list indices must be integers or slices, not tuple