In [1]:
from __future__ import annotations
import numpy as np

In [2]:
class Node:
    """A single node of a binary decision tree.

    An internal node stores the index of the feature it tests
    (``feature_idx``) together with its two children; a leaf node stores only
    the predicted ``value`` and leaves every other attribute as ``None``.

    Args:
        left: Child followed when the tested feature is truthy.
        right: Child followed when the tested feature is falsy.
        feature_idx: Column index of the feature tested at this node.
        value: Predicted label; set only on leaf nodes.
    """

    def __init__(
        self,
        left: "Node | None" = None,
        right: "Node | None" = None,
        feature_idx: "int | None" = None,
        value=None,
    ):
        self.left = left
        self.right = right
        self.value = value
        self.feature_idx = feature_idx

    def is_leaf_node(self) -> bool:
        """Return ``True`` when this node carries a prediction (is a leaf)."""
        # Identity check on purpose: ``value != None`` is evaluated
        # element-wise by NumPy, so an array-valued leaf would not yield a
        # plain bool. ``is not None`` always does, and still treats the
        # falsy label 0 as a valid leaf value.
        return self.value is not None

In [3]:
class DecisionTree:
    """Decision-tree classifier for binary (boolean / 0-1) features.

    Every internal node tests a single feature: samples whose feature value is
    truthy go to the left child, all others go to the right child. The feature
    tested at a node is the one with the lowest Gini impurity (see
    ``impurity``).

    Labels must be accepted by ``np.bincount`` (booleans or non-negative
    integers). Predictions are the integer class index returned by
    ``argmax``, so boolean labels come back as ``0`` / ``1``.

    Args:
        max_depth: Depth at which growth stops and a leaf is created.
        min_samples_split: Nodes with fewer samples than this become leaves.
    """

    def __init__(self, max_depth: int = 5, min_samples_split: int = 2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None

    def _make_leaf(self, y):
        """Build a leaf node predicting the most frequent label in ``y``."""
        return Node(value=np.bincount(y).argmax())

    def _grow_tree(self, X, y, depth=0):
        """Recursively build the subtree for samples ``X`` with labels ``y``."""
        n_samples = X.shape[0]
        n_labels = len(np.unique(y))

        # Stopping criteria: depth limit reached, node already pure, or too
        # few samples left to split.
        if (
            depth >= self.max_depth
            or n_labels == 1
            or n_samples < self.min_samples_split
        ):
            return self._make_leaf(y)

        best_feature_idx = self._find_best_split(X, y)
        if best_feature_idx is None:
            # No usable feature (e.g. X has zero columns).
            return self._make_leaf(y)

        # Partition by truthiness so training routes samples exactly the way
        # ``traverse`` does at prediction time.
        goes_left = X[:, best_feature_idx].astype(bool)
        if goes_left.all() or not goes_left.any():
            # The best feature does not separate the samples. Recursing would
            # hand an empty label array to one child, so stop here.
            return self._make_leaf(y)

        left_node = self._grow_tree(X[goes_left], y[goes_left], depth=depth + 1)
        right_node = self._grow_tree(X[~goes_left], y[~goes_left], depth=depth + 1)

        return Node(left=left_node, right=right_node, feature_idx=best_feature_idx)

    def _find_best_split(self, X, y):
        """Return the index of the feature with the lowest impurity.

        Ties are resolved in favour of the lowest index. Returns ``None`` when
        ``X`` has no columns.
        """
        min_impurity = float("inf")
        best_feature_idx = None

        for feature_idx in range(X.shape[1]):
            impurity = self.impurity(X, y, feature_idx)

            if impurity < min_impurity:
                min_impurity = impurity
                best_feature_idx = feature_idx

        return best_feature_idx

    def fit(self, X: np.ndarray, Y: np.ndarray):
        """Grow the tree from the training data.

        Args:
            X: 2-D array-like of shape ``(n_samples, n_features)`` holding
                binary features.
            Y: 1-D array-like of ``n_samples`` labels.

        Raises:
            ValueError: If ``Y`` is empty.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        if Y.size == 0:
            raise ValueError("cannot fit a decision tree on an empty data set")

        self.labels = np.unique(Y)

        # Train on the argument ``Y``, not on a same-named notebook global.
        self.root = self._grow_tree(X, Y)

    def predict(self, x: np.ndarray):
        """Predict the label of a single sample ``x`` (1-D feature vector)."""
        return self.traverse(self.root, x)

    def traverse(self, node: "Node", x: np.ndarray):
        """Walk down from ``node`` following the features of ``x``.

        Returns the value stored in the leaf that is reached.
        """
        if node.is_leaf_node():
            return node.value

        if x[node.feature_idx]:
            return self.traverse(node.left, x)
        else:
            return self.traverse(node.right, x)

    def impurity(self, X: np.ndarray, y: np.ndarray, feature_idx: int):
        """Gini impurity of splitting on column ``feature_idx``.

        The samples are grouped by the distinct values found in that column.
        The result is the sum over groups of ``group_size * gini(group)``,
        i.e. the weighted impurity without the division by the total sample
        count (which does not affect the ranking of features).

        Args:
            X: 2-D feature array of shape ``(n_samples, n_features)``.
            y: Labels for the rows of ``X``.
            feature_idx: Column of ``X`` to evaluate.

        Returns:
            The size-weighted Gini impurity; ``0`` means every group is pure.
        """
        # Use the feature *column*; ``X[feature_idx]`` would select a row.
        X_column = X[:, feature_idx]

        g_i = 0
        for unique_feat in np.unique(X_column):
            class_counts = np.bincount(y[X_column == unique_feat])
            group_size = class_counts.sum()

            # Gini impurity of this group: 1 - sum_k p_k^2
            group_gini = 1 - np.sum((class_counts / group_size) ** 2)

            g_i += group_size * group_gini

        return g_i

In [4]:
# Classifier with the hyper-parameters spelled out (these are the defaults).
clf = DecisionTree(max_depth=5, min_samples_split=2)

In [5]:
# Toy data set: one sample per row. The first four columns are boolean
# features and the last column is the boolean class label.
data = np.array(
    [
        [True, False, True, False, False],
        [True, True, True, False, True],
        [True, False, True, False, False],
        [True, True, True, False, True],
        [False, False, False, False, False],
        [False, True, False, False, True],
        [True, True, True, False, True],
        [True, True, True, True, True],
    ]
)

# Separate the feature matrix from the label column.
X, y = data[:, :-1], data[:, -1]
X

array([[ True, False,  True, False],
       [ True,  True,  True, False],
       [ True, False,  True, False],
       [ True,  True,  True, False],
       [False, False, False, False],
       [False,  True, False, False],
       [ True,  True,  True, False],
       [ True,  True,  True,  True]])

In [6]:
# Show the label vector (last column of `data`).
y

array([False,  True, False,  True, False,  True,  True,  True])

In [7]:
# Train the tree on the toy data set.
clf.fit(X, y)

In [8]:
# Leaf nodes have no children, so `None` here means the root's right child
# is already a leaf.
print(clf.root.right.right)

None


In [14]:
# Sample whose feature at index 1 is True.
clf.predict(np.array([False, True, False, False]))

1

In [15]:
# Sample whose feature at index 1 is False.
clf.predict(np.array([False, False, False, False]))

0

In this data set the label is determined by whether the value at index 1 (the second feature, counting from 0) is True or False, and the tree correctly classifies samples based on that feature.

In [16]:
# Index of the feature tested at the root split.
clf.root.feature_idx

1