Build a simple Decision Tree Classifier from scratch using information gain (entropy) for the Iris dataset (load_iris).

   * Use only two features (e.g., petal length, petal width) for simplicity.
   * Show how the tree splits data.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

In [2]:
# Load the Iris dataset bundled with scikit-learn.
iris = load_iris()
# Keep only columns 2 and 3 of the feature matrix (the task description
# names petal length and petal width as the two features to use).
X = iris.data[:, [2,3]]
# Class labels for every sample, used as the classification target.
y = iris.target

In [18]:
def entropy(y):
    """Return the Shannon entropy (in bits) of the class labels in ``y``.

    Parameters
    ----------
    y : array-like
        Class labels of the samples in a node.

    Returns
    -------
    numpy.float64
        ``-sum(p * log2(p))`` over the observed class frequencies ``p``.
        Exactly ``0.0`` when ``y`` is empty or contains a single class.
    """
    _, counts = np.unique(y, return_counts=True)
    # With zero or one distinct class the formula evaluates to -0.0;
    # short-circuit so pure and empty nodes report a clean 0.0.
    if counts.size <= 1:
        return np.float64(0.0)
    probs = counts / counts.sum()
    return -np.sum(probs * np.log2(probs))

In [19]:
def info_gain(x, y, threshold):
    """Information gain from splitting ``y`` on ``x <= threshold``.

    Parameters
    ----------
    x : numpy.ndarray
        Values of a single feature, one per sample.
    y : numpy.ndarray
        Class labels aligned element-wise with ``x``.
    threshold : float
        Samples with ``x <= threshold`` go left, the rest go right.

    Returns
    -------
    float or int
        Entropy of ``y`` minus the size-weighted entropy of the two
        children. ``0`` when the split leaves one side empty.
    """
    entropy_before = entropy(y)

    # Partition the labels once and reuse the two halves below.
    labels_left = y[x <= threshold]
    labels_right = y[x > threshold]
    size_left = len(labels_left)
    size_right = len(labels_right)

    # A split that sends every sample to one side separates nothing.
    if size_left == 0 or size_right == 0:
        return 0

    total = len(y)
    entropy_after = (size_left/total) * entropy(labels_left) + (size_right/total) * entropy(labels_right)
    return entropy_before - entropy_after

In [20]:
def best_split(X, y):
    """Search every feature and observed value for the highest-gain split.

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix (samples by features).
    y : numpy.ndarray
        Class labels, one per row of ``X``.

    Returns
    -------
    tuple
        ``(gain, feature_index, threshold)`` of the best split found.
        ``(-1, None, None)`` when there is no candidate threshold at all
        (no columns or no rows).
    """
    # Track the current winner as one tuple; -1 is below any real gain.
    winner = (-1, None, None)

    for col in range(X.shape[1]):
        column = X[:, col]
        # Every distinct observed value is tried as a "<=" threshold.
        for candidate in np.unique(column):
            score = info_gain(column, y, candidate)
            # Strict comparison: ties keep the earliest candidate.
            if score > winner[0]:
                winner = (score, col, candidate)

    return winner

In [21]:
def build_tree(X, y, depth=0, max_depth=2):
    """Recursively grow a decision tree using information-gain splits.

    Parameters
    ----------
    X : numpy.ndarray
        2-D feature matrix (samples by features).
    y : numpy.ndarray
        Non-negative integer class labels, one per row of ``X``
        (required by ``np.bincount``).
    depth : int, optional
        Depth of the node being built; callers normally leave this at 0.
    max_depth : int, optional
        Nodes at this depth (or deeper) become leaves.

    Returns
    -------
    dict or numpy integer
        A leaf is the majority class label. An internal node is a dict
        with keys ``'feature'``, ``'threshold'``, ``'left'`` (samples with
        ``feature <= threshold``) and ``'right'`` (the remaining samples).

    Raises
    ------
    ValueError
        If ``y`` is empty, since no majority class exists.
    """
    if len(y) == 0:
        raise ValueError("cannot build a tree from an empty label array")

    # Stop on a pure node or once the depth budget is used up. ">=" (not
    # "==") so a call that starts beyond max_depth still terminates here.
    if len(np.unique(y)) == 1 or depth >= max_depth:
        return np.bincount(y).argmax()

    gain, feature, threshold = best_split(X, y)

    # No usable split: either best_split found no candidate at all
    # (feature is None, gain is its -1 sentinel) or nothing improves purity.
    if feature is None or gain <= 0:
        return np.bincount(y).argmax()

    left_idx = X[:, feature] <= threshold
    right_idx = X[:, feature] > threshold

    left_branch = build_tree(X[left_idx], y[left_idx], depth + 1, max_depth)
    right_branch = build_tree(X[right_idx], y[right_idx], depth + 1, max_depth)

    return {
        'feature': feature,
        'threshold': threshold,
        'left': left_branch,
        'right': right_branch
    }

In [22]:
# Grow the tree on the full dataset using build_tree's defaults
# (depth=0, max_depth=2).
tree = build_tree(X, y)
# Show the raw nested-dict representation returned by build_tree.
print(tree)

{'feature': 0, 'threshold': np.float64(1.9), 'left': np.int64(0), 'right': {'feature': 1, 'threshold': np.float64(1.7), 'left': np.int64(1), 'right': np.int64(2)}}


In [26]:
def print_tree(tree, depth=0):
    """Print a nested-dict decision tree as an indented yes/no outline.

    Parameters
    ----------
    tree : dict or leaf value
        A node dict with keys ``'feature'``, ``'threshold'``, ``'left'``
        and ``'right'``, or any non-dict value, which is printed as the
        predicted class.
    depth : int, optional
        Current nesting level; each level adds two spaces of indentation.
    """
    pad = "  " * depth  # two spaces per nesting level

    # Anything that is not a dict is a leaf: print the prediction and stop.
    if not isinstance(tree, dict):
        print(f"{pad}→ Predict class: {tree}")
        return

    print(f"{pad}Feature {tree['feature']} <= {tree['threshold']:.2f}?")
    # The "yes" branch holds samples satisfying the test, "no" the rest.
    for heading, side in (("├─ Yes:", 'left'), ("└─ No:", 'right')):
        print(f"{pad}{heading}")
        print_tree(tree[side], depth + 1)
        
# Render the fitted tree as an indented outline of its splits.
print_tree(tree)


Feature 0 <= 1.90?
├─ Yes:
  → Predict class: 0
└─ No:
  Feature 1 <= 1.70?
  ├─ Yes:
    → Predict class: 1
  └─ No:
    → Predict class: 2
