In [50]:
# here we will be building a decision tree to classify whether a student will pass an exam based on attendance and study hours per week
import numpy as np

#here is the synthetic dataset that we will be using for this

# Features: [Study Hours per Week, Attendance %]
# Labels: 0 = Fail, 1 = Pass

X = np.array([
    [2, 60],   # studied little, low attendance → Fail
    [3, 65],   # little study, low attendance → Fail
    [5, 70],   # average study, average attendance → Fail
    [7, 80],   # good study, good attendance → Pass
    [8, 85],   # high study, good attendance → Pass
    [10, 90],  # very high study, excellent attendance → Pass
])
y = np.array([0, 0, 0, 1, 1, 1])

In [51]:
# here we can see we can split the example as something like, if study hours<=6, then fail
#but we have to do it mathematically as our model should learn to split the features here
#now lets implement it

# lets define gini impurity

def gini_from_counts(counts):    #this is when we have the count, say out of 10 samples we have[6, 4] for pass and fail

    total=counts.sum()
    prob=counts/total
    if total==0:
        return 0.0

    return 1.0 - np.sum(prob**2)

#this is when we do not have the count

def gini(y):
    counts=np.bincount(y)
    counts=np.array(counts)
    return gini_from_counts(counts)

In [92]:
# now lets implement the class Node

class Node:
    def __init__(self, *, feature=None, threshold=None, gini=0.0, left=None, right=None, value=None, prob=None, n_samples=None):
        self.feature=feature
        self.threshold=threshold
        self.gini=gini
        self.right=right
        self.left= left
        self.value=value
        self.prob=prob
        self.n_samples=n_samples

    def is_leaf(self):
            return self.value is not None

In [93]:
# now lets write the main algo for finding the best threshold, best split and returning the gain( gain has to be max for the best split)

def best_split( X, y, n_classes):

   n_samples, n_features=X.shape
   parent_counts=np.bincount(y)
   parent_gini=gini_from_counts(parent_counts)
   best_gain=0
   best_feature=None
   best_threshold=None

   # for each feature calculate the cumulative class count

   for feat in range(n_features):
    col=X[:, feat]
    order=np.argsort(col)
    col_sorted=col[order]
    y_sorted=y[order]
        
    counts_per_class=np.zeros((n_classes, n_samples), dtype=int)
    for cls in range(n_classes):
        counts_per_class[cls]=np.cumsum(y_sorted==cls)

    total_count=parent_counts

    for i in range(n_samples-1):
        if(col_sorted[i]==col_sorted[i+1]):
            continue
        left_counts=counts_per_class[:, i]
        right_counts=total_count-left_counts

        left_n=i+1
        right_n=n_samples-left_n

        left_gini=gini_from_counts(left_counts)
        right_gini=gini_from_counts(right_counts)

        weighted_gini=(left_n/n_samples)*left_gini  + (right_n/n_samples)*right_gini
        gain=parent_gini-weighted_gini

        if gain>best_gain:
            best_gain=gain
            best_feature=feat
            best_threshold=( col_sorted[i]+col_sorted[i+1])/2.0

    return best_gain, best_feature, best_threshold

In [94]:
# now that we have the algo for getting the best split, we now need to use that to build the tree

def build_tree( X, y, * , max_depth=10, min_samples_split=2, min_samples_leaf=1, min_impurity_decrease=1e-7, depth=0, n_classes=None ):
    n_samples, n_features=X.shape
    if n_classes is None:
        n_classes=len(np.unique(y))
    node_counts=np.bincount(y)
    node_gini=gini_from_counts(node_counts)
    majority_class=int(np.argmax(node_counts))
    if (len(np.unique(y)) == 1) or (depth >= max_depth) or (n_samples < min_samples_split):
        prob = node_counts / node_counts.sum()
        return Node(value=majority_class, prob=prob, n_samples=n_samples, gini=node_gini)
    
    gain, feat, thresh = best_split(X, y, n_classes)

    if feat is None or gain<=min_impurity_decrease:
        prob=node_counts/node_counts.sum()
        return Node(value=majority_class, prob=prob, n_samples=n_samples, gini=node_gini)
    
    left_mask=X[:, feat]<=thresh
    right_mask=~left_mask

    if left_mask.sum()<min_samples_leaf or right_mask.sum()<min_samples_leaf:
        prob=node_counts/node_counts.sum()
        return Node(value=majority_class, prob=prob, n_samples=n_samples, gini=node_gini)
    
    left_node=build_tree(X[left_mask], y[left_mask], max_depth=max_depth, min_samples_split=min_samples_split,
                           min_samples_leaf=min_samples_leaf, min_impurity_decrease=min_impurity_decrease,
                           depth=depth + 1, n_classes=n_classes)
    right_node=build_tree(X[right_mask], y[right_mask], max_depth=max_depth, min_samples_split=min_samples_split,
                           min_samples_leaf=min_samples_leaf, min_impurity_decrease=min_impurity_decrease,
                           depth=depth + 1, n_classes=n_classes)
    
    return Node(feature=feat, threshold=thresh, left=left_node, right=right_node, value=None, prob=None, n_samples=n_samples, gini=node_gini)



In [95]:
# prediction

def predict_one(x, node):
    while not node.is_leaf():
        if x[node.feature] <= node.threshold:
            node = node.left
        else:
            node = node.right
    return node.value

def predict(X, node):
    return np.array([predict_one(x, node) for x in X])

In [98]:
def print_tree(node, depth=0):
    prefix = "  " * depth
    if node.is_leaf():
        print(f"{prefix}Leaf: predict={node.value}, prob={np.round(node.prob,3)}, n={node.n_samples}, gini={node.gini:.3f}")
    else:
        print(f"{prefix}Node: feat={node.feature}, thresh={node.threshold:.4f}, n={node.n_samples}, gini={node.gini:.3f}")
        print_tree(node.left, depth + 1)
        print_tree(node.right, depth + 1)

In [99]:
tree = build_tree(X, y, max_depth=3, min_samples_split=2)
print_tree(tree)

preds = predict(X, tree)
print("preds:", preds)
print("accuracy:", (preds == y).mean())

Node: feat=0, thresh=6.0000, n=6, gini=0.500
  Leaf: predict=0, prob=[1.], n=3, gini=0.000
  Leaf: predict=1, prob=[0. 1.], n=3, gini=0.000
preds: [0 0 0 1 1 1]
accuracy: 1.0
