# In [7]:  (Jupyter notebook cell prompt, kept as a comment so the file parses as Python)
import pandas as pd
import numpy as np

class Node:
    """One node of a binary tree.

    A node stores an optional split description (``feature``,
    ``threshold``), optional ``left`` / ``right`` children and an optional
    ``value``.  It is treated as a leaf whenever ``value`` is not ``None``.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
        # Split description and the two children.
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right
        # ``value`` is keyword-only, so it cannot be set positionally by accident.
        self.value = value
        self.label = None  # label attribute for node labels

    def is_leaf_node(self):
        """Return True when this node stores a value (i.e. it is a leaf)."""
        has_no_value = self.value is None
        return not has_no_value
        
        
        
class DecisionTree:
    """Binary decision tree that picks its splits by gain ratio.

    Every internal node tests ``x[feature] >= threshold``: samples passing
    the test go to the left child, all others to the right child.

    The entropy helper measures the fraction of labels equal to 1 and the
    majority vote compares the number of 1s with the number of 0s, so the
    labels are expected to be 0/1.
    """

    def __init__(self, min_samples_split=2):
        # A node holding fewer samples than this is turned into a leaf.
        self.min_samples_split = min_samples_split
        self.root = None

    def fit(self, X, y):
        """Grow the tree from training data and store it in ``self.root``.

        X is a 2-D array-like (samples x features), y the matching labels.
        """
        # Convert once so that the positional indexing used while growing the
        # tree also works for lists and pandas objects.
        self.root = self._grow_tree(np.asarray(X), np.asarray(y))

    def _grow_tree(self, X, y, best=-1):  # 'best' is the parent's best gain ratio
        """Recursively build the subtree for the samples ``X`` / ``y``.

        Returns a leaf ``Node`` when a stopping criterion holds, otherwise an
        internal ``Node`` whose children are built from the two sides of the
        best split.
        """
        n_samples, _ = X.shape
        n_labels = len(np.unique(y))

        # check stopping criteria
        if n_labels == 1 or n_samples < self.min_samples_split or best == 0:
            return Node(value=self._most_common_label(y))

        # find best split
        best_feat, best_thr, best_ratio = self._best_split(X, y)

        # No candidate split carries information (or there is no feature at
        # all): stop here.  Splitting anyway would send every sample to one
        # side and leave an empty child whose prediction is arbitrary.
        if best_feat is None or best_ratio <= 0:
            return Node(value=self._most_common_label(y))

        # create child nodes
        left_idxs, right_idxs = self._split(X[:, best_feat], best_thr)
        left = self._grow_tree(X[left_idxs, :], y[left_idxs], best_ratio)
        right = self._grow_tree(X[right_idxs, :], y[right_idxs], best_ratio)

        return Node(best_feat, best_thr, left, right)

    def _best_split(self, X, y):
        """Search all features and thresholds for the highest gain ratio.

        Every distinct value of a column is tried as a threshold.  Returns
        ``(feature_index, threshold, gain_ratio)``; the first two are ``None``
        and the ratio is -1 when there is no candidate at all.
        """
        best_gain_ratio = -1
        split_idx, split_threshold = None, None

        for idx in range(X.shape[1]):
            X_col = X[:, idx]

            for thres in np.unique(X_col):
                # split information H_D(S) of the candidate
                prob_left = np.sum(X_col >= thres) / len(X_col)  # P(left)
                prob_right = 1 - prob_left  # P(right)
                if prob_left == 0 or prob_right == 0:
                    # one side is empty: split information is zero, so the
                    # ratio is undefined; treat it as 0
                    gain_ratio = 0
                else:
                    H_S = -prob_left * np.log2(prob_left) - prob_right * np.log2(prob_right)
                    gain_ratio = self._InfoGain(X_col, y, thres) / H_S

                # strict '>' keeps the first candidate on ties
                if gain_ratio > best_gain_ratio:
                    best_gain_ratio = gain_ratio
                    split_idx = idx
                    split_threshold = thres

        return split_idx, split_threshold, best_gain_ratio

    def _InfoGain(self, X_col, y, thres):
        """Return the information gain of splitting ``y`` by ``X_col >= thres``."""
        parent_entropy = self._entropy(y)

        left_idx, right_idx = self._split(X_col, thres)
        if len(left_idx) == 0 or len(right_idx) == 0:  # one child is empty
            return 0

        # weighted average of the children's entropies
        n = len(y)
        n_left, n_right = len(left_idx), len(right_idx)
        left_entropy = self._entropy(y[left_idx])
        right_entropy = self._entropy(y[right_idx])
        child_entropy = n_left / n * left_entropy + n_right / n * right_entropy

        return parent_entropy - child_entropy

    def _split(self, X_col, thres):
        """Return the indices with ``X_col >= thres`` (left) and ``< thres`` (right)."""
        left_idxs = np.argwhere(X_col >= thres).flatten()
        right_idxs = np.argwhere(X_col < thres).flatten()
        return left_idxs, right_idxs

    def _entropy(self, y):
        """Return the binary entropy (in bits) of the labels ``y``.

        Labels equal to 1 form one class, everything else the other.  An
        empty or pure node has entropy 0.
        """
        y = np.asarray(y)
        n = len(y)
        if n == 0:  # empty node: avoid 0/0, which would yield nan
            return 0

        prob1 = np.sum(y == 1) / n  # y=1
        prob0 = 1 - prob1  # y=0
        if prob1 == 0 or prob0 == 0:  # pure node
            return 0
        return -prob0 * np.log2(prob0) - prob1 * np.log2(prob1)

    def _most_common_label(self, y):
        """Return the majority label (0 or 1); ties, including empty ``y``, give 1."""
        y = np.asarray(y)
        if np.sum(y == 1) >= np.sum(y == 0):  # if no majority class, return y=1
            return 1
        return 0

    ##  prediction
    def predict(self, X):
        """Return an array with the predicted label for every row of ``X``."""
        return np.array([self._traverse_tree(x, self.root) for x in np.asarray(X)])

    def _traverse_tree(self, x, node):
        """Follow the split tests from ``node`` down to a leaf and return its value."""
        if node.is_leaf_node():
            return node.value

        if x[node.feature] >= node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)

    ## tree plot
    def visualize_tree(self):
        """Print the tree as indented text, one node per line."""
        def visualize_tree_recursive(node, depth=0):
            indent = "  " * depth
            if node.is_leaf_node():
                print(indent + f"Leaf: Class {node.value}")
            else:
                # features are printed 1-based (X1, X2, ...)
                print(indent + f"X{node.feature+1} >= {node.threshold}")
                visualize_tree_recursive(node.left, depth + 1)
                visualize_tree_recursive(node.right, depth + 1)

        visualize_tree_recursive(self.root)

    ## boundary plot
    def plot_decision_boundary(self, X, y, feature_names=('feature 1', 'feature 2'), class_names=('0', '1')):
        """Plot the predicted regions over a 2-D feature space together with the data.

        ``feature_names`` only controls whether the axis ticks are removed
        (they are when it is non-empty); ``class_names`` is currently unused.

        Raises ValueError when ``X`` does not have exactly two columns or the
        tree has not been fitted.

        NOTE(review): relies on a module-level ``plt`` (presumably
        ``matplotlib.pyplot``) that is not imported in this part of the
        file -- confirm it is imported elsewhere.
        """
        X = np.asarray(X)
        if X.shape[1] != 2:
            raise ValueError("This method only for visualizing 2-D data.")

        if self.root is None:
            raise ValueError("Decision tree needs training.")

        # Create meshgrid for the feature space (padded by 0.5 on every side)
        x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
        y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))

        # Make predictions for all points in the meshgrid
        Z = self.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)

        # Plot decision boundary
        plt.figure(figsize=(8, 6))
        plt.contourf(xx, yy, Z, cmap='magma', alpha=0.8)

        # Plot points for each class
        plt.scatter(X[:, 0], X[:, 1], c=y, cmap='rainbow', edgecolor='k', s=20)

        plt.xlabel('X1')
        plt.ylabel('X2')
        plt.title('Decision Boundary of Decision Tree')

        # remove axis ticks
        if feature_names:
            plt.xticks([])
            plt.yticks([])

        plt.show()

    ## count nodes number (we count leaf as a node ('leaf node' as defined in Wiki))
    def count_nodes(self):
        """Return the total number of nodes (internal and leaf); 0 for an unfitted tree."""
        def count_nodes_recursive(node):
            if node is None:
                return 0
            return 1 + count_nodes_recursive(node.left) + count_nodes_recursive(node.right)

        return count_nodes_recursive(self.root)

        