In [2]:
import numpy as np
from collections import Counter
from sklearn import datasets
from sklearn.model_selection import train_test_split

In [36]:
class Node:
    """One node of a binary decision tree.

    Split (internal) nodes store ``feature``/``threshold`` plus two children;
    leaf nodes store only ``value``, the label they predict.
    """

    def __init__(self, left=None, right=None, feature=None, threshold=None, value=None) -> None:
        # Split rule: column index and cut-off (unused on leaves).
        self.feature, self.threshold = feature, threshold
        # Child subtrees; DecisionTree routes x[feature] <= threshold to ``left``.
        self.left, self.right = left, right
        # Predicted label; any non-None value marks this node as a leaf.
        self.value = value

    def is_leaf_node(self):
        """Return True when this node carries a prediction (``value`` is set)."""
        return not (self.value is None)

In [37]:
class DecisionTree:
    """Greedy binary decision-tree classifier.

    Each internal node splits on ``x[feature] <= threshold`` and is chosen to
    maximise information gain (reduction in label entropy, natural log).

    Parameters
    ----------
    n_features : int or None
        Number of columns considered at each split, drawn at random without
        replacement. ``None`` (or 0) means "use every column"; values larger
        than the number of columns are clipped.
    max_depth : int
        Maximum depth of the tree. The root is at depth 0, so a node at depth
        ``max_depth`` is always a leaf.
    min_sample_split : int
        Nodes with fewer samples than this become leaves.

    Candidate columns are shuffled with the global ``np.random`` state and ties
    in gain go to the first candidate seen, so call ``np.random.seed(...)``
    before ``fit`` for a reproducible tree.
    """

    def __init__(self, n_features=None, max_depth=10, min_sample_split=2) -> None:
        self.max_depth = max_depth
        self.min_sample_split = min_sample_split
        self.n_features = n_features
        # ``fit`` overwrites ``n_features`` with the resolved per-split count;
        # keep the caller's request so a later refit starts from it again.
        self._requested_n_features = n_features
        self.root = None

    def fit(self, X, y):
        """Build the tree from training data.

        ``X`` is a 2-D array (rows = samples, columns = features) and ``y`` a
        1-D label array of the same length; both are indexed with integer
        index arrays, so they must be NumPy arrays.
        """
        n_columns = X.shape[1]
        requested = self._requested_n_features
        # Previously the constructor's ``n_features`` was discarded here.
        self.n_features = n_columns if not requested else min(n_columns, requested)
        self.root = self._grow_tree(X, y, depth=0)

    def _grow_tree(self, X, y, depth):
        """Recursively build the subtree for samples ``(X, y)`` at ``depth``."""
        n_sample = len(y)
        n_label = len(np.unique(y))
        # Stop when the depth budget is spent (``>=``: the root is depth 0, so
        # the tree is at most ``max_depth`` deep), the node is too small, or
        # all labels agree.
        if depth >= self.max_depth or n_sample < self.min_sample_split or n_label == 1:
            return Node(value=self._most_common_label(y))

        # Random subset of column indices (all columns, shuffled, by default).
        features = np.random.choice(X.shape[1], self.n_features, replace=False)
        feature_idx, threshold = self._best_split(X, y, features)
        if feature_idx is None:
            return Node(value=self._most_common_label(y))

        left_idx, right_idx = self._split(X, threshold, feature_idx)
        # No candidate separated the samples (e.g. duplicate rows with
        # different labels). Recursing would reach an empty child and crash
        # in ``_most_common_label``, so make this node a leaf instead.
        if len(left_idx) == 0 or len(right_idx) == 0:
            return Node(value=self._most_common_label(y))

        left = self._grow_tree(X[left_idx, :], y[left_idx], depth + 1)
        right = self._grow_tree(X[right_idx, :], y[right_idx], depth + 1)
        return Node(left=left, right=right, feature=feature_idx, threshold=threshold)

    def _best_split(self, X, y, features):
        """Greedy search over ``features`` and each of their observed values.

        Returns ``(feature_idx, threshold)`` with the highest information gain,
        or ``(None, None)`` when there is no candidate at all. Ties keep the
        first candidate seen (strict ``>``).
        """
        # Identical for every candidate, so compute it once per node.
        parent_entropy = self._entropy(y)
        best_gain = -1
        best_feature_idx, best_threshold = None, None
        for feature in features:
            for threshold in np.unique(X[:, feature]):
                gain = self._information_gain(X, y, threshold, feature, parent_entropy)
                if gain > best_gain:
                    best_gain = gain
                    best_feature_idx = feature
                    best_threshold = threshold
        return best_feature_idx, best_threshold

    def _entropy(self, y):
        """Shannon entropy (natural log) of the label distribution in ``y``."""
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return -np.sum(probabilities * np.log(probabilities))

    def _information_gain(self, X, y, threshold, feature_idx, parent_entropy=None):
        """Entropy reduction from splitting on ``X[:, feature_idx] <= threshold``.

        ``parent_entropy`` may be passed in to avoid recomputing
        ``self._entropy(y)``. A split that leaves either side empty has gain 0.
        """
        if parent_entropy is None:
            parent_entropy = self._entropy(y)

        left_idx, right_idx = self._split(X, threshold, feature_idx)
        if len(left_idx) == 0 or len(right_idx) == 0:
            return 0

        n = len(y)
        n_l, n_r = len(left_idx), len(right_idx)
        e_l, e_r = self._entropy(y[left_idx]), self._entropy(y[right_idx])
        child_entropy = (n_l / n) * e_l + (n_r / n) * e_r
        return parent_entropy - child_entropy

    def _split(self, X, threshold, feature_idx):
        """Return index arrays of rows going left (``<= threshold``) and right.

        "Right" is the complement of "left", so rows that do not compare
        ``<=`` (e.g. NaN) go right. This matches ``_traverse_tree`` at predict
        time instead of silently dropping those rows.
        """
        goes_left = X[:, feature_idx] <= threshold
        return np.flatnonzero(goes_left), np.flatnonzero(~goes_left)

    def _most_common_label(self, y):
        """Majority label of ``y``; on ties, the label encountered first wins."""
        return Counter(y).most_common(1)[0][0]

    def predict(self, X):
        """Predict a label for every row of ``X``; returns a Python list."""
        return [self._traverse_tree(x, self.root) for x in X]

    def _traverse_tree(self, x, node: "Node"):
        """Follow split rules from ``node`` down to a leaf and return its label."""
        while not node.is_leaf_node():
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

In [59]:
# Load scikit-learn's bundled breast-cancer dataset: ``X`` is the feature
# matrix and ``y`` the class labels used by the classifier below.
data = datasets.load_breast_cancer()
X = data.data
y = data.target

In [60]:
# Hold out 20% of the samples for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=24,
)

In [61]:
# DecisionTree shuffles candidate features with the global NumPy RNG, and ties
# in information gain go to whichever feature comes first. Seed the RNG (same
# seed as the train/test split) so the fitted tree and accuracy are reproducible.
np.random.seed(24)
clf = DecisionTree(max_depth=50)
clf.fit(X_train, y_train)

In [62]:
# ``predict`` returns a Python list; convert it so comparisons against
# ``y_test`` are plain element-wise array operations.
y_preds = np.asarray(clf.predict(X_test))

In [63]:
# Test-set accuracy as a percentage: share of predictions matching the true labels.
accuracy = np.mean(np.asarray(y_preds) == y_test) * 100
print(accuracy)

92.98245614035088
