In [2]:
import numpy as np

In [13]:
class Node:
    """A single node of a binary decision tree.

    An internal node stores the column index to test (``feature``), the
    ``threshold`` to compare against, and its ``left`` / ``right`` children.
    A leaf node stores only ``value``.
    """

    def __init__(
        self,
        features=None,
        threshold=None,
        left=None,
        right=None,
        *,
        value=None,
    ):
        """Store the split description (internal node) or the leaf value.

        Args:
            features: Index of the feature tested at this node; kept on the
                instance as ``self.feature``.
            threshold: Split threshold associated with ``features``.
            left: Child node for the "left" branch.
            right: Child node for the "right" branch.
            value: Keyword-only. Value carried by a leaf node; ``None``
                marks the node as internal.
        """
        # Split description (meaningful for internal nodes).
        self.feature, self.threshold = features, threshold
        # Sub-trees.
        self.left, self.right = left, right
        # Leaf payload.
        self.value = value

    def _is_leaf_node(self):
        """Return True when this node carries a value, i.e. it is a leaf."""
        # Identity check on purpose: a value of 0 is still a valid leaf.
        is_internal = self.value is None
        return not is_internal

In [14]:
class DecisionTree:
    """Binary classification tree grown by greedy threshold splits.

    Each internal node tests ``x[feature] <= threshold``; samples satisfying
    the test go left, the others go right. Class frequencies are computed
    with ``np.bincount``, so labels must be non-negative integers.

    Attributes:
        max_depth: Depth at which growth stops and a leaf is created.
        min_sample_split: Nodes with fewer samples than this become leaves.
        criterion: Impurity measure, ``"gini"`` or ``"entropy"``.
        root: Root node of the fitted tree, or ``None`` before ``fit``.
    """

    def __init__(self, max_depth=100, min_sample_split=2, criterion="gini"):
        self.max_depth = max_depth
        self.min_sample_split = min_sample_split
        self.criterion = criterion
        self.root = None

    def gini(self, y):
        """Return the Gini impurity of label array ``y`` (0 for empty input)."""
        m = len(y)
        if m == 0:
            return 0
        prob = np.bincount(y) / m
        return 1 - np.sum(prob ** 2)

    def entropy(self, y):
        """Return the Shannon entropy in bits of ``y`` (0 for empty input)."""
        m = len(y)
        if m == 0:
            return 0
        prob = np.bincount(y) / m
        # Drop absent classes: log2(0) is undefined and they contribute 0.
        prob = prob[prob > 0]
        return -np.sum(prob * np.log2(prob))

    def impurity(self, y):
        """Return the impurity of ``y`` under the configured criterion.

        Raises:
            ValueError: If ``self.criterion`` is neither ``"gini"`` nor
                ``"entropy"``.
        """
        if self.criterion == "gini":
            return self.gini(y)
        if self.criterion == "entropy":
            return self.entropy(y)
        # The exception must actually be raised; merely constructing it
        # would make this method return None and fail later in info_gain.
        raise ValueError("Such criterion doesnt exist")

    def info_gain(self, y, y_left, y_right):
        """Return the impurity decrease obtained by splitting ``y``.

        The gain is the parent impurity minus the size-weighted average of
        the impurities of ``y_left`` and ``y_right``. ``y`` must be
        non-empty.
        """
        m = len(y)
        weighted_children = (
            len(y_left) / m * self.impurity(y_left)
            + len(y_right) / m * self.impurity(y_right)
        )
        return self.impurity(y) - weighted_children

    def best_split(self, X, y):
        """Search every feature and observed value for the best split.

        Args:
            X: 2-D array of shape ``(n_samples, n_features)``.
            y: Label array aligned with the rows of ``X``.

        Returns:
            ``(feature_index, threshold)`` of the split with the highest
            information gain, or ``(None, None)`` when no threshold puts
            at least one sample on each side.
        """
        best_gain = -1
        split_idx, split_thresh = None, None

        _, n_features = X.shape
        for feature in range(n_features):
            # Slice the column once instead of once per candidate threshold.
            column = X[:, feature]
            for t in np.unique(column):
                left_mask = column <= t
                right_mask = column > t
                # A usable split needs samples on both sides.
                if not left_mask.any() or not right_mask.any():
                    continue
                gain = self.info_gain(y, y[left_mask], y[right_mask])
                # Strict comparison: the first best candidate wins ties.
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature
                    split_thresh = t

        return split_idx, split_thresh

    def _most_common_label(self, y):
        """Return the most frequent label in ``y`` (smallest label on ties)."""
        return np.argmax(np.bincount(y))

    def build_tree(self, X, y, depth=0):
        """Recursively grow the subtree for samples ``X`` / labels ``y``.

        A leaf holding the majority label is produced when the depth limit
        is reached, the node is pure, the node has fewer than
        ``min_sample_split`` samples, or no valid split exists.

        Returns:
            The root ``Node`` of the grown subtree.
        """
        n_samples = X.shape[0]
        n_classes = len(np.unique(y))
        if (
            depth >= self.max_depth
            or n_classes == 1
            or n_samples < self.min_sample_split
        ):
            return Node(value=self._most_common_label(y))

        feature, thresh = self.best_split(X, y)
        if feature is None:
            # Every feature is constant here, so nothing can be separated.
            return Node(value=self._most_common_label(y))

        column = X[:, feature]
        left_mask = column <= thresh
        right_mask = column > thresh

        left_child = self.build_tree(X[left_mask], y[left_mask], depth + 1)
        right_child = self.build_tree(X[right_mask], y[right_mask], depth + 1)
        return Node(feature, thresh, left_child, right_child)

    def fit(self, X, y):
        """Grow the tree from training data and store it in ``self.root``.

        Args:
            X: Array-like of shape ``(n_samples, n_features)``.
            y: Array-like of non-negative integer labels, one per row.
        """
        # Coerce so that plain lists work with the NumPy indexing below.
        X = np.asarray(X)
        y = np.asarray(y)
        self.root = self.build_tree(X, y)

    def predict_one(self, X, node):
        """Return the label for the single sample ``X`` from ``node`` down."""
        if node._is_leaf_node():
            return node.value
        if X[node.feature] <= node.threshold:
            return self.predict_one(X, node.left)
        return self.predict_one(X, node.right)

    def predict(self, X):
        """Return an array with the predicted label for each row of ``X``."""
        return np.array([self.predict_one(x, self.root) for x in X])

In [15]:
from numpy.random import random_sample
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Load scikit-learn's bundled breast-cancer dataset as features and labels.
data = load_breast_cancer()
X = data.data
y = data.target

# Hold out 20% of the samples for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Fit a depth-limited tree using the Gini criterion.
tree = DecisionTree(criterion="gini", max_depth=5)
tree.fit(X_train, y_train)

# Score the held-out samples: accuracy is the percentage of exact matches.
y_pred = tree.predict(X_test)
acc = 100 * np.mean(y_pred == y_test)
print(f"Test Accuracy: {acc:.2f}%")

Test Accuracy: 92.98%
