In [3]:
from pyexpat import features

import numpy as np

In [13]:
class DecisionTreeNode:
    """One node of a binary decision tree.

    The node is a plain record of five attributes. As used by
    ``DecisionTreeClassifier`` in this file, an internal node carries the
    split (``features`` = column index, ``threshold``) together with its
    ``left`` and ``right`` children, while a leaf carries only ``value``
    (the predicted label).  Every attribute defaults to ``None``.
    """

    def __init__(
        self,
        features=None,
        left=None,
        right=None,
        threshold=None,
        value=None,
    ):
        # Split column and the two subtrees hanging off this node.
        self.features, self.left, self.right = features, left, right
        # Split cut-off, and the label stored when the node is a leaf.
        self.threshold, self.value = threshold, value

class DecisionTreeClassifier:
    """Binary decision tree classifier grown greedily with Gini impurity.

    Every internal node tests ``sample[feature] <= threshold``; samples that
    satisfy the test go to the left child, the others to the right child.
    Candidate thresholds are the distinct values observed in the training
    column (not midpoints between them).

    Parameters
    ----------
    max_depth : int, default 3
        A node whose depth equals ``max_depth`` becomes a leaf (the root has
        depth 0).

    Attributes
    ----------
    root : DecisionTreeNode or None
        Root of the fitted tree; ``None`` until :meth:`fit` has been called.
    """

    def __init__(self, max_depth=3):
        self.max_depth = max_depth
        self.root = None

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label in ``y``.

        Ties are broken in favour of the smallest label, which matches what
        ``np.bincount(y).argmax()`` returns for non-negative integers, but
        this version also works for negative, float and string labels.

        Raises
        ------
        ValueError
            If ``y`` is empty.
        """
        labels, counts = np.unique(y, return_counts=True)
        return labels[counts.argmax()]

    def gini_imp(self, y):
        """Return the Gini impurity of the label array ``y``.

        Gini impurity is ``1 - sum(p_i ** 2)`` where ``p_i`` is the fraction
        of samples belonging to class ``i``; it is 0 for a pure node.
        """
        _, counts = np.unique(y, return_counts=True)
        proportions = counts / counts.sum()
        return 1 - np.sum(proportions ** 2)

    def best_split(self, x, y):
        """Find the split with the lowest weighted Gini impurity.

        Parameters
        ----------
        x : numpy.ndarray
            2-D array of samples (rows) by features (columns).
        y : numpy.ndarray
            Labels aligned with the rows of ``x``.

        Returns
        -------
        tuple
            ``(threshold, feature)`` of the best split, or ``(None, None)``
            when no threshold separates the samples into two non-empty
            groups (for example when every column is constant).
        """
        best_gini = float("inf")
        best_threshold = None
        best_feature = None
        n_samples = x.shape[0]

        for feature in range(x.shape[1]):
            column = x[:, feature]
            # Each distinct observed value is a candidate threshold.
            for thr in np.unique(column):
                goes_left = column <= thr
                n_left = goes_left.sum()
                n_right = n_samples - n_left

                # A split that leaves one side empty separates nothing.
                if n_left == 0 or n_right == 0:
                    continue

                # Size-weighted average of the children's impurities.
                weighted_gini = (
                    (n_left / n_samples) * self.gini_imp(y[goes_left])
                    + (n_right / n_samples) * self.gini_imp(y[~goes_left])
                )

                # Strict '<' keeps the first split found among equals.
                if weighted_gini < best_gini:
                    best_gini = weighted_gini
                    best_threshold = thr
                    best_feature = feature

        return best_threshold, best_feature

    def build_tree(self, x, y, depth=0):
        """Recursively grow the subtree for samples ``x`` with labels ``y``.

        A leaf holding the majority label is returned when the node is pure,
        when ``depth`` equals ``max_depth``, or when no valid split exists.
        Otherwise the best split is applied and both halves are grown one
        level deeper.

        Returns
        -------
        DecisionTreeNode
            Root of the grown subtree.
        """
        if len(np.unique(y)) == 1 or depth == self.max_depth:
            return DecisionTreeNode(value=self._majority_label(y))

        threshold, feature = self.best_split(x, y)

        # No usable split (e.g. identical rows with mixed labels).
        if feature is None:
            return DecisionTreeNode(value=self._majority_label(y))

        # Boolean row mask; indexing with it keeps every column of x.
        goes_left = x[:, feature] <= threshold
        goes_right = ~goes_left

        left_child = self.build_tree(x[goes_left], y[goes_left], depth + 1)
        right_child = self.build_tree(x[goes_right], y[goes_right], depth + 1)

        return DecisionTreeNode(
            features=feature,
            left=left_child,
            right=right_child,
            threshold=threshold,
        )

    def fit(self, x, y):
        """Grow the tree from training samples ``x`` and labels ``y``.

        Parameters
        ----------
        x : array-like
            2-D collection of samples (rows) by features (columns).
        y : array-like
            One label per row of ``x``.

        Raises
        ------
        ValueError
            If the training set is empty.
        """
        # Accept plain lists as well as arrays.
        x = np.asarray(x)
        y = np.asarray(y)
        if y.size == 0:
            raise ValueError("cannot fit a decision tree on an empty training set")
        self.root = self.build_tree(x, y)

    def predict_one(self, x, node):
        """Return the label for one sample ``x``, descending from ``node``.

        A node whose ``value`` is not ``None`` is treated as a leaf; otherwise
        the sample follows the left branch when
        ``x[node.features] <= node.threshold`` and the right branch if not.
        """
        if node.value is not None:
            return node.value

        if x[node.features] <= node.threshold:
            return self.predict_one(x, node.left)
        return self.predict_one(x, node.right)

    def predict(self, X):
        """Return a list with the predicted label for every row of ``X``.

        :meth:`fit` must have been called first so that ``self.root`` holds
        a tree.
        """
        return [self.predict_one(x, self.root) for x in X]




In [17]:
# Toy training set: six 2-D points, one per row.  The first three rows are
# labelled class 0 and the last three class 1 (see y below).
X = np.array([[2, 3], [1, 1], [3, 2], [6, 5], [7, 8], [8, 6]])
y = np.array([0, 0, 0, 1, 1, 1])

# Query points: the first should lean toward class 0, the second toward class 1.
X_test = np.array([[2, 2], [7, 7]])


In [18]:
# Train a tree with the default depth limit on the toy data, then classify
# the two query points.
tree = DecisionTreeClassifier()
tree.fit(X, y)
pred = tree.predict(X_test)

In [19]:
# Display the predicted labels for X_test as the cell output.
pred

[np.int64(0), np.int64(1)]