In [None]:
from sklearn.datasets import load_iris
# from sklearn.tree import export_graphviz
import numpy as np

# Redefine the DecisionTreeNode class
class DecisionTreeNode:
    """A single node of a binary decision tree.

    Internal nodes send a sample left when ``x[feature] <= threshold`` and
    right when ``x[feature] > threshold``; leaves carry a predicted class in
    ``value``. A node is treated as a leaf exactly when ``value is not None``.

    Args:
        feature: Column index tested by an internal node (None for leaves).
        threshold: Split value compared against ``x[feature]`` (None for leaves).
        left: Subtree for samples with ``x[feature] <= threshold``.
        right: Subtree for samples with ``x[feature] > threshold``.
        value: Predicted class label of a leaf; None for internal nodes.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value


# Redefine the DecisionTreeScratch class
class DecisionTreeScratch:
    """Greedy binary classification tree that splits on Gini impurity decrease.

    Args:
        max_depth: Nodes at this depth are turned into leaves.
        min_samples_split: Nodes with fewer samples than this become leaves.
    """

    def __init__(self, max_depth=3, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None

    def fit(self, X, y):
        """Grow the tree on training data and store it in ``self.root``.

        Args:
            X: 2-D array-like, one row per sample and one column per feature.
            y: 1-D array-like with one class label per row of X. Labels may be
                any dtype ``np.unique`` can sort (ints, negative ints, strings).

        Raises:
            ValueError: If X is not 2-D, there are no samples, or the number of
                rows in X differs from the number of labels.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}")
        if len(y) == 0:
            raise ValueError("cannot fit a tree on zero samples")
        if X.shape[0] != len(y):
            raise ValueError(f"X has {X.shape[0]} rows but y has {len(y)} labels")
        self.root = self._grow_tree(X, y, depth=0)

    def predict(self, X):
        """Return a list with the predicted class label for each row of X."""
        return [self._traverse_tree(x, self.root) for x in np.asarray(X)]

    def _gini(self, y):
        """Return the Gini impurity ``1 - sum_k p_k**2`` of a non-empty label array."""
        _, counts = np.unique(y, return_counts=True)
        impurity = 1.0
        for count in counts:
            p = count / len(y)
            impurity -= p**2
        return impurity

    def _information_gain(self, y, X_col, threshold):
        """Return the Gini impurity decrease of splitting on ``X_col <= threshold``.

        Returns 0 when either side of the split would be empty.
        """
        left_idxs = X_col <= threshold
        right_idxs = X_col > threshold
        n_left = np.count_nonzero(left_idxs)
        n_right = np.count_nonzero(right_idxs)
        if n_left == 0 or n_right == 0:
            return 0
        n = len(y)
        child_impurity = (n_left / n) * self._gini(y[left_idxs]) + (n_right / n) * self._gini(y[right_idxs])
        return self._gini(y) - child_impurity

    def _best_split(self, X, y):
        """Find the (feature index, threshold) pair with the highest Gini gain.

        Candidate thresholds are the distinct values of each column. A
        threshold is considered only if it leaves at least one sample on each
        side. Ties keep the first candidate in (feature, ascending threshold)
        order. Returns ``(None, None)`` when no feature admits such a
        threshold, e.g. when every column is constant.
        """
        best_gain = -np.inf
        split_idx, split_thr = None, None
        for feature in range(X.shape[1]):
            column = X[:, feature]
            for threshold in np.unique(column):
                # A threshold that puts every sample on one side (e.g. the
                # column maximum) is not a split; choosing it would make
                # _grow_tree recurse on an empty subset.
                if not np.any(column > threshold) or not np.any(column <= threshold):
                    continue
                gain = self._information_gain(y, column, threshold)
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature
                    split_thr = threshold
        return split_idx, split_thr

    def _majority_class(self, y):
        """Return the most frequent label in y; ties go to the smallest label."""
        labels, counts = np.unique(y, return_counts=True)
        return labels[np.argmax(counts)]

    def _grow_tree(self, X, y, depth):
        """Recursively build the subtree for the samples (X, y) at ``depth``."""
        num_samples = X.shape[0]
        num_classes = len(np.unique(y))

        if num_classes == 1 or num_samples < self.min_samples_split or depth >= self.max_depth:
            return DecisionTreeNode(value=self._majority_class(y))

        split_idx, split_thr = self._best_split(X, y)
        if split_idx is None:
            # Nothing separates these samples (e.g. identical rows with
            # different labels), so this node has to be a leaf.
            return DecisionTreeNode(value=self._majority_class(y))

        left_idxs = X[:, split_idx] <= split_thr
        right_idxs = X[:, split_idx] > split_thr
        left = self._grow_tree(X[left_idxs], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs], y[right_idxs], depth + 1)
        return DecisionTreeNode(feature=split_idx, threshold=split_thr, left=left, right=right)

    def _traverse_tree(self, x, node):
        """Follow the splits from ``node`` down to a leaf and return its label."""
        while node.value is None:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

# Print tree helper
def print_tree(node, feature_names, depth=0):
    """Print an indented, pre-order text rendering of a decision (sub)tree.

    Each level of nesting adds two spaces. Internal nodes print as
    ``[<feature name> <= <threshold>]`` followed by their left then right
    subtrees; leaves print as ``Leaf -> class <value>``.

    Args:
        node: Root of the subtree; a leaf when ``node.value is not None``.
        feature_names: Sequence mapping feature indices to display names.
        depth: Indentation level of ``node``.
    """
    pending = [(node, depth)]
    while pending:
        current, level = pending.pop()
        pad = "  " * level
        if current.value is None:
            name = feature_names[current.feature]
            print(f"{pad}[{name} <= {current.threshold}]")
            # Stack is LIFO: push right first so the left subtree prints first.
            pending.append((current.right, level + 1))
            pending.append((current.left, level + 1))
        else:
            print(f"{pad}Leaf -> class {current.value}")

# Load the Iris dataset: feature matrix X and class labels y
iris = load_iris()
X = iris.data
y = iris.target

# Fit the from-scratch tree, limited to depth 3
clf = DecisionTreeScratch(max_depth=3)
clf.fit(X, y)

# Show the learned splits as indented text
print_tree(clf.root, feature_names=iris.feature_names)


[petal length (cm) <= 1.9]
  Leaf -> class 0
  [petal width (cm) <= 1.7]
    [petal length (cm) <= 4.9]
      Leaf -> class 1
      Leaf -> class 2
    [petal length (cm) <= 4.8]
      Leaf -> class 2
      Leaf -> class 2


In [None]:
def to_dot(node, feature_names, node_id=0):
    """Recursively generate DOT statements for the subtree rooted at ``node``.

    Nodes are numbered in pre-order starting at ``node_id``: the node itself,
    then its whole left subtree, then its whole right subtree. Each internal
    node emits one edge to each of its children, using the ids those
    children were actually assigned.

    Args:
        node: Root of the (sub)tree; a leaf when ``node.value is not None``.
        feature_names: Sequence mapping feature indices to display names.
        node_id: Id to assign to ``node``.

    Returns:
        Tuple ``(dot, next_id)``: the DOT node/edge statements, one per line
        and without the enclosing ``digraph { ... }``, and the first id not
        used by this subtree.
    """
    if node.value is not None:
        # Escape double quotes so they cannot terminate the DOT quoted string.
        label = f"Leaf: class {node.value}".replace('"', '\\"')
        return f'{node_id} [label="{label}", shape=box];\n', node_id + 1

    label = f"{feature_names[node.feature]} <= {node.threshold}".replace('"', '\\"')
    left_id = node_id + 1
    left_dot, right_id = to_dot(node.left, feature_names, left_id)
    # The right child is numbered immediately after the entire left subtree,
    # so its id is exactly the next free id returned by the left recursion.
    right_dot, next_id = to_dot(node.right, feature_names, right_id)
    dot = (
        f'{node_id} [label="{label}"];\n'
        + left_dot
        + right_dot
        + f"{node_id} -> {left_id};\n"
        + f"{node_id} -> {right_id};\n"
    )
    return dot, next_id


# import graphviz
# dot_body, _ = to_dot(clf.root, iris.feature_names)
# dot_data = f"digraph Tree {{\n{dot_body}\n}}"
# graphviz.Source(dot_data)


ExecutableNotFound: failed to execute WindowsPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.sources.Source at 0x165382f8ef0>