# Classification decision tree implemented in Python

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

## Decision tree algorithm

In [2]:
class DecisionTree():
    
    def __init__(self):
        self.root_dict = None
        self.tree_dict = None
        
        
    def split_dataset(self, X, y, feature_idx, threshold):
        left_idx = np.where(X[:, feature_idx] < threshold)
        right_idx = np.where(X[:, feature_idx] >= threshold)
        
        left_subset = X[left_idx]
        y_left = y[left_idx]
        
        right_subset = X[right_idx]
        y_right = y[right_idx]
        
        split = {
            'left': left_subset,
            'y_left': y_left,
            'right': right_subset,
            'y_right': y_right,
        }
        
        return split
    
    
    def gini(self, y_left, y_right, n_left, n_right):
        gini_left, nigi_right = 0, 0
        n_total = n_left + n_right
        score_left, score_right = 0, 0
        
        if len(y_left) != 0:
            for c in range(self.n_class):
                p_left = len(np.where(y_left == c)[0]) / n_left
                score_left += p_left * p_left
            gini_left = 1 - score_left
        
        if len(y_right) != 0:
            for c in range(self.n_class):
                p_right = len(np.where(y_right == c)[0]) / n_right
                score_right += p_right * p_right
            gini_right = 1 - score_right
            
        return gini_left, gini_right
    
    
    def cost(self, split):
        y_left = split['y_left']
        y_right = split['y_right']
        
        n_left = len(y_left)
        n_right = len(y_right)
        n_total = n_left + n_right
        
        gini_left, gini_right = self.gini(y_left, y_right, n_left, n_right)
        cost = (n_left/n_total)*gini_left + (n_right/n_total)*gini_right
        
        return cost
    
    
    def find_best_split(self, X, y):
        n_samples, n_features = X.shape
        best_feature_idx, best_threshold, best_cost, best_split = np.inf, np.inf, np.inf, None
        
        for feature_idx in range(n_features):
            for i in range(n_samples):
                current = X[i]
                threshold = current[feature_idx]
                split = self.split_dataset(X, y, feature_idx, threshold)
                cost = self.cost(split)
                
                if cost < best_cost:
                    best_feature_idx = feature_idx
                    best_threshold = threshold
                    bext_cost = cost
                    best_split = split
                    
        best_split_param = {
            'feature_idx': best_feature_idx,
            'threshold': best_threshold,
            'cost': best_cost,
            'left': best_split['left'],
            'y_left': best_split['y_left'],
            'right': best_split['right'],
            'y_right': best_split['y_right'],
        }
        
        return best_split_param
    
    
    def build_tree(self, node_dict, depth, max_depth, min_samples):
        left_samples = node_dict['left']
        right_samples = node_dict['right']
        y_left_samples = node_dict['y_left']
        y_right_samples = node_dict['y_right']
        
        if len(y_left_samples) == 0 or len(y_right_samples) == 0:
            node_dict['left_child'] = node_dict['right_child'] = self.create_terminal_node(np.append(y_left_samples, y_right_samples))
            return None
        
        if depth >= max_depth:
            node_dict['left_child'] = self.create_terminal_node(y_left_samples)
            node_dict['right_child'] = self.create_terminal_node(y_right_samples)
            return None
        
        if len(right_samples) < min_samples:
            node_dict['right_child'] = self.create_terminal_node(y_right_samples)
        else:
            node_dict['right_child'] = self.find_best_split(right_samples, y_right_samples)
            self.build_tree(node_dict['right_child'], depth + 1, max_depth, min_samples)
            
        if len(left_samples) < min_samples:
            node_dict['left_child'] = self.create_terminal_node(y_left_samples)
        else:
            node_dict['left_child'] = self.find_best_split(left_samples, y_left_samples)
            self.build_tree(node_dict['left_child'], depth + 1, max_depth, min_samples)
            
        return node_dict
    
    
    def create_terminal_node(self, y):
        classification = max(set(y), key = list(y).count)
        return classification
    
    
    def train(self, X, y, max_depth, min_samples):
        self.n_class = len(set(y))
        self.root_dict = self.find_best_split(X, y)
        self.tree_dict = self.build_tree(self.root_dict, 1, max_depth, min_samples)
        
        
    def predict(self, X, node):
        feature_idx = node['feature_idx']
        threshold = node['threshold']
        
        if X[feature_idx] < threshold:
            if isinstance(node['left_child'], (int, np.integer)):
                return node['left_child']
            else:
                prediction = self.predict(X, node['left_child'])
        elif X[feature_idx] >= threshold:
            if isinstance(node['right_child'], (int, np.integer)):
                return node['right_child']
            else:
                prediction = self.predict(X, node['right_child'])
                
        return prediction

## Dataset

In [3]:
iris = load_iris()
X, y = iris.data, iris.target

In [4]:
X_train, X_test, y_train, y_test = train_test_split(X, y)

print(f'Shape X_train: {X_train.shape}')
print(f'Shape y_train: {y_train.shape}')
print(f'Shape X_test: {X_test.shape}')
print(f'Shape y_test: {y_test.shape}')

Shape X_train: (112, 4)
Shape y_train: (112,)
Shape X_test: (38, 4)
Shape y_test: (38,)


## Training and testing

In [5]:
tree = DecisionTree()
tree.train(X_train, y_train, max_depth = 3, min_samples = 1)

In [6]:
def print_tree(node, depth = 0):
    if isinstance(node, (int, np.integer)):
        print(f"{depth*'   '} predicted_class: {iris.target_names[node]}")
    else:
        print(f"{depth*'   '} {iris.feature_names[node['feature_idx']]} < {round(node['threshold'], 3)}, "
             f"cost of split: {round(node['cost'], 3)}")
        print_tree(node['left_child'], depth + 1)
        print_tree(node['right_child'], depth + 1)

In [7]:
print_tree(tree.tree_dict)

 petal width (cm) < 1.8, cost of split: inf
    petal width (cm) < 1.4, cost of split: inf
       petal width (cm) < 1.0, cost of split: inf
          predicted_class: setosa
          predicted_class: versicolor
       petal width (cm) < 1.4, cost of split: inf
          predicted_class: versicolor
          predicted_class: versicolor
    petal width (cm) < 1.8, cost of split: inf
       predicted_class: virginica
       predicted_class: virginica


In [8]:
predictions = []

for i in range(X_test.shape[0]):
    result = tree.predict(X_test[i], tree.tree_dict)
    predictions.append(y_test[i] == result)
    
print(f"Accuracy on test set: {sum(predictions) / len(predictions)}")

Accuracy on test set: 0.9736842105263158
