In [1]:
import numpy as np

In [3]:
class DecisionTree:
    """Binary classification tree grown greedily on weighted Gini impurity.

    The fitted tree is stored in ``self.tree`` as nested tuples
    ``(feature_index, threshold, left_subtree, right_subtree)``. A leaf is
    the bare class label (never a tuple). Samples whose feature value is
    ``<= threshold`` follow the left branch, all others the right branch.
    """

    def __init__(self, d=None):
        """Create an unfitted tree.

        Args:
            d: Maximum depth of the tree. ``None`` means the depth is not
                limited; ``0`` yields a single leaf holding the majority
                class of the training labels.
        """
        self.max_depth = d
        self.tree = None

    def fit(self, X, y):
        """Grow the tree from training data and store it in ``self.tree``.

        Args:
            X: Feature matrix, indexed as ``X[:, feature]``; anything
                ``np.asarray`` can convert (e.g. a list of rows).
            y: Class labels, one per row of ``X``. Any label type that
                ``np.unique`` can handle is accepted.

        Raises:
            ValueError: If ``y`` is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError("Cannot fit a DecisionTree on an empty dataset.")
        self.tree = self._build_tree(X, y)

    def predict(self, X):
        """Predict a class label for every row of ``X``.

        Args:
            X: Iterable of samples; each sample must be indexable by
                feature position.

        Returns:
            ``np.ndarray`` holding one predicted label per sample.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.tree is None:
            raise RuntimeError("DecisionTree has not been fitted; call fit() first.")
        return np.array([self.pred(inputs) for inputs in X])

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum(p_c ** 2)`` of the labels ``y``."""
        m = len(y)
        # One pass over y yields every class count at once.
        _, counts = np.unique(y, return_counts=True)
        return 1.0 - sum((count / m) ** 2 for count in counts)

    def _majority_label(self, y):
        """Return the most frequent label in ``y``.

        ``np.unique`` returns labels in sorted order and ``np.argmax`` picks
        the first maximum, so ties resolve to the smallest label. Unlike
        ``np.bincount`` this also works for negative, float and string labels.
        """
        labels, counts = np.unique(y, return_counts=True)
        return labels[np.argmax(counts)]

    def _split(self, X, y, index, value):
        """Partition the samples on ``X[:, index] <= value``.

        Returns:
            Tuple ``(X_left, X_right, y_left, y_right)``.
        """
        left_mask = X[:, index] <= value
        right_mask = X[:, index] > value
        return X[left_mask], X[right_mask], y[left_mask], y[right_mask]

    def _best_split(self, X, y):
        """Search every feature/threshold pair for the lowest weighted Gini.

        Returns:
            Tuple ``(feature_index, threshold)`` of the best split, or
            ``(None, None)`` when no candidate separates the samples into
            two non-empty groups.
        """
        # Start from the worst possible impurity so any valid split wins.
        best_gini = float('inf')
        best_index, best_value = None, None

        for index in range(X.shape[1]):
            # np.unique returns sorted values; the last one would send every
            # sample to the left branch, so it can never be a useful threshold.
            for value in np.unique(X[:, index])[:-1]:
                X_left, X_right, y_left, y_right = self._split(X, y, index, value)
                # Skip thresholds that leave one side empty (no real split).
                if len(y_left) == 0 or len(y_right) == 0:
                    continue
                # Impurity of the children, weighted by their sample counts.
                gini = (len(y_left) * self._gini(y_left) + len(y_right) * self._gini(y_right)) / len(y)
                # Strict '<' keeps the first-found split when impurities tie.
                if gini < best_gini:
                    best_gini, best_index, best_value = gini, index, value
        return best_index, best_value

    def _build_tree(self, X, y, depth=0):
        """Recursively build the subtree for the samples ``X``/``y``.

        Returns:
            A leaf label, or a tuple
            ``(feature_index, threshold, left_subtree, right_subtree)``.
        """
        # Stop when the node is pure or the depth limit is reached. The limit
        # is compared against None explicitly so that max_depth == 0 is honoured.
        depth_exhausted = self.max_depth is not None and depth >= self.max_depth
        if len(np.unique(y)) == 1 or depth_exhausted:
            return self._majority_label(y)

        # Find the feature/threshold pair that best separates the classes.
        index, value = self._best_split(X, y)
        if index is None:
            # No threshold separates the samples (e.g. identical feature rows).
            return self._majority_label(y)

        X_left, X_right, y_left, y_right = self._split(X, y, index, value)

        left_subtree = self._build_tree(X_left, y_left, depth + 1)
        right_subtree = self._build_tree(X_right, y_right, depth + 1)

        return (index, value, left_subtree, right_subtree)

    def pred(self, inputs):
        """Walk the tree for a single sample and return the leaf it reaches.

        Args:
            inputs: One sample, indexable by feature position.
        """
        node = self.tree
        # Internal nodes are tuples; anything else is a leaf label.
        while isinstance(node, tuple):
            if inputs[node[0]] <= node[1]:
                node = node[2]
            else:
                node = node[3]
        return node


In [4]:
# Toy dataset: four samples with two binary features each. The label matches
# the first feature, so a single split on column 0 separates the classes.
X = np.array([
    [1, 1],
    [1, 0],
    [0, 1],
    [0, 0],
])
y = np.array([1, 1, 0, 0])

# Fit a tree limited to depth 2, then predict on the training samples.
tree = DecisionTree(d=2)
tree.fit(X, y)
predictions = tree.predict(X)

print(f"Predictions: {predictions}")


Predictions: [1 1 0 0]
