# In [1]:
import numpy as np

class DecisionTree:
    """Binary classification tree grown greedily by minimizing weighted Gini impurity.

    Each internal node tests ``x[feature] <= threshold``; samples that pass go
    to the left child, the rest go right. Leaves predict the majority class of
    the training samples that reached them (ties go to the smallest label, in
    ``np.unique`` order).

    Parameters
    ----------
    max_depth : int or None, default 3
        Maximum depth of the tree (the root is depth 0). ``None`` keeps
        splitting until nodes are pure, too small, or cannot be separated.
    min_samples_split : int, default 2
        A node with fewer samples than this becomes a leaf.
    """

    def __init__(self, max_depth=3, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None  # root Node; populated by fit()

    class Node:
        """Tree node: an internal split (feature/threshold/children) or a leaf (value)."""

        def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
            self.feature = feature      # column index tested here; None for leaves
            self.threshold = threshold  # samples with x[feature] <= threshold go left
            self.left = left            # subtree for x[feature] <= threshold
            self.right = right          # subtree for x[feature] > threshold
            self.value = value          # predicted class label (leaves only)

    def fit(self, X, y):
        """Build the tree from training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training features; converted with ``np.asarray``.
        y : array-like of shape (n_samples,)
            Class labels. Any labels ``np.unique`` can sort are accepted
            (integers of any sign, floats, strings, ...).

        Raises
        ------
        ValueError
            If X is not 2-D, y is not 1-D, their sample counts differ, or the
            dataset is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got {X.ndim} dimension(s)")
        if y.ndim != 1:
            raise ValueError(f"y must be 1-D, got {y.ndim} dimension(s)")
        if X.shape[0] != y.shape[0]:
            raise ValueError(f"X has {X.shape[0]} samples but y has {y.shape[0]}")
        if y.size == 0:
            raise ValueError("cannot fit a DecisionTree on an empty dataset")
        self.tree = self._grow_tree(X, y)

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum_c p_c**2`` of the labels in ``y``."""
        _, counts = np.unique(y, return_counts=True)
        n = len(y)
        impurity = 1.0
        # Subtract class by class (sorted label order) so rounding matches the
        # straightforward formula and split tie-breaking stays deterministic.
        for count in counts:
            impurity -= (count / n) ** 2
        return impurity

    def _best_split(self, X, y):
        """Find the (feature, threshold) pair minimizing weighted child Gini impurity.

        Candidate thresholds are the distinct values of each feature column; a
        split sends ``X[:, feature] <= threshold`` left. Candidates that leave
        either side empty are skipped. On ties the first candidate found wins
        (lowest feature index, then lowest threshold).

        Returns
        -------
        tuple
            ``(feature, threshold)``, or ``(None, None)`` when no candidate
            separates the samples (e.g. every feature is constant).
        """
        n_samples = len(y)
        best_gini = float('inf')
        best_feature, best_threshold = None, None

        for feature in range(X.shape[1]):
            column = X[:, feature]
            for threshold in np.unique(column):
                left_mask = column <= threshold
                y_left, y_right = y[left_mask], y[~left_mask]

                # A split that keeps every sample on one side separates nothing.
                if len(y_left) == 0 or len(y_right) == 0:
                    continue

                # Child impurities weighted by the fraction of samples they receive.
                gini = (len(y_left) * self._gini(y_left) +
                        len(y_right) * self._gini(y_right)) / n_samples

                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature
                    best_threshold = threshold

        return best_feature, best_threshold

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label in ``y`` (the smallest such label on ties)."""
        labels, counts = np.unique(y, return_counts=True)
        # np.unique sorts labels and argmax returns the first maximum.
        return labels[np.argmax(counts)]

    def _grow_tree(self, X, y, depth=0):
        """Recursively grow the subtree for samples ``(X, y)`` located at ``depth``."""
        n_samples = X.shape[0]
        reached_max_depth = self.max_depth is not None and depth >= self.max_depth

        # Stop when the depth limit is reached, too few samples remain to split,
        # or all samples already share one class (pure node).
        if (reached_max_depth or
                n_samples < self.min_samples_split or
                len(np.unique(y)) == 1):
            return self.Node(value=self._majority_label(y))

        feature, threshold = self._best_split(X, y)

        # No candidate threshold separates the samples: this node must be a leaf.
        if feature is None:
            return self.Node(value=self._majority_label(y))

        left_mask = X[:, feature] <= threshold
        right_mask = ~left_mask
        left = self._grow_tree(X[left_mask], y[left_mask], depth + 1)
        right = self._grow_tree(X[right_mask], y[right_mask], depth + 1)
        return self.Node(feature=feature, threshold=threshold, left=left, right=right)

    def predict(self, X):
        """Predict a class label for each row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        np.ndarray of shape (n_samples,)
            Predicted labels, taken from the labels seen during ``fit``.
        """
        X = np.asarray(X)
        return np.array([self._predict(x, self.tree) for x in X])

    def _predict(self, x, node):
        """Walk from ``node`` down to a leaf for sample ``x`` and return the leaf's label."""
        # Internal nodes always carry a split feature; leaves never do.
        while node.feature is not None:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value