## Decision Tree for Classification

Algorithm: <br>
* Set up: features X, labels y<br>
* step 1: start at the root<br>
* step 2: measure impurity by using gini or entropy for classification (MSE for regression)<br>
* step 3: try all possible splits<br>
* step 4: choose the best split<br>
* step 5: recurse (max depth reached, node is pure, too few samples to split)<br>
* step 6: make predictions <br>


In [None]:
import numpy as np

def gini(y):
    """
    y: shape (n, )
    """
    classes, counts = np.unique(y, return_counts=True)
    probs = counts/counts.sum() # probability of each class 
    return 1-np.sum(probs**2)

In [45]:
def split_impurity(y_left, y_right):
    """
    find the impurity of the split
    When using this function, 
    we can find the best split by calculating the split impurity
    """
    N = len(y_left)+len(y_right)
    gini_left = gini(y_left)
    gini_right = gini(y_right)
    N_left = len(y_left)
    N_right = len(y_right)
    return N_left/N*gini_left+N_right/N*gini_right

In [None]:
def best_split(X, y):
    n_samples, n_features = X.shape
    best_feature, best_threshold = None, None
    best_impurity = float('inf')

    for feature in range(n_features):
        # get all the unique values of the features
        thresholds = np.unique(X[:, feature])
        for threshold in thresholds:
            left_idx = X[:, feature]<=threshold # boolean index
            right_idx = X[:, feature]>threshold # boolean index

            impurity = split_impurity(y[left_idx], y[right_idx])

            if impurity<best_impurity:
                best_impurity = impurity
                best_feature = feature
                best_threshold = threshold
    return best_feature, best_threshold

In [47]:
class TreeNode:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value


In [48]:
def build_tree(X, y, depth=0, max_depth=4, min_samples_split=2):

    # stopping criteria
    if len(np.unique(y))==1 or len(y)<min_samples_split or depth==max_depth:
        values, counts = np.unique(y, return_counts=True)
        return TreeNode(value=values[np.argmax(counts)])

    features, threshold = best_split(X, y)

    left_idx = X[:, features]<=threshold
    right_idx = X[:, features]>threshold

    left = build_tree(X[left_idx], y[left_idx], depth+1, max_depth, min_samples_split)
    right = build_tree(X[right_idx], y[right_idx], depth+1, max_depth, min_samples_split)

    return TreeNode(
        feature=features, 
        threshold=threshold, 
        left=left,
        right=right
    )


In [49]:
def predict_one(x, node):
    if node.value is not None:
        return node.value
    if x[node.feature]<=node.threshold:
        return predict_one(x, node.left)
    else:
        return predict_one(x, node.right)

In [50]:
def predict(X, tree):
    return np.array([predict_one(x, tree) for x in X])

In [51]:
X = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
y = np.array([0, 1, 0])
tree = build_tree(X, y)
print(tree)
print(predict(X, tree))


<__main__.TreeNode object at 0x112226d70>
[0 1 0]
