talk about robustness and variance of trees

talk about pruning and cut off criteria

structure of model depends on data, unlike other models

cost complexity pruning

look at decision boundary

use mermaid diagram for decision trees

use house votes for tutorial

# Decision Tree Module

You may have noticed that while $k$-NN can work really well as a classifier, it doesn't provide us with any insights about the data itself. For example, $k$-NN doesn't tell us which features are most relevant in determining the class label of an observation. In this module, we introduce the decision tree model: a greedy, divide-and-conquer algorithm that partitions our feature space to create interpretable predictions.

The decision tree model allows us to extract a set of _classification rules_ to classify a given instance. This _rule-based_ ML approach is similar to writing `IF-ELSE` statements that test different features, resulting in a model that is highly interpretable by humans. Unlike $k$-NN, which acts as a "black box" that simply outputs predictions, decision trees show us exactly how they arrive at their conclusions through a series of logical decisions.

Let's make these claims more concrete with a familiar example. A common parlor game is $20$ Questions, where one person chooses something for the other players to guess. The guessers are allowed to ask $20$ yes-or-no questions to identify what was chosen. If you've ever played this game, you know that the best strategy is to ask questions that give you the most information possible. For instance, if you're trying to guess an animal, asking "Is it a mammal?" is far more informative than asking "Is it a cat?" The first question immediately divides all animals into two groups, eliminating roughly half of the possibilities with a single question. The second question only helps if the answer happens to be "yes." 

So we should ask whether it's a mammal before asking whether it's a cat. Some questions are clearly more useful than others, and the order in which we ask them matters. This raises two questions: which question should we ask first, and how do we quantify how useful each question is so that we can choose the best one? For this, we borrow two ideas from information theory: entropy and information gain.

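Entropy measures how mixed a set of labels is. For a set $S$ in which a fraction $p_c$ of the observations belongs to class $c$,

$$H(S) = -\sum_{c} p_c \log_2 p_c.$$

We define information gain as the reduction in entropy obtained by splitting $S$ on a feature $A$:

$$IG(S, A) = H(S) - \sum_{v \in \text{values}(A)} \frac{|S_v|}{|S|} H(S_v),$$

where $S_v$ is the subset of $S$ for which $A$ takes the value $v$.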

Decision trees work in exactly this way: at each step, they choose the feature and threshold that best splits the data into distinct groups, progressively narrowing down the possibilities until an accurate prediction can be made.
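
To make the $20$ Questions intuition concrete, here is a minimal sketch that scores the two questions from the example. It assumes the usual definitions: entropy is $-\sum_c p_c \log_2 p_c$ over the class proportions, and information gain is the entropy before a split minus the size-weighted entropy after it. The eight animals and the helper names `entropy` and `information_gain` are made up for illustration.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    counts = Counter(labels)
    return -sum((c / n) * math.log2(c / n) for c in counts.values())


def information_gain(labels, answers):
    """Entropy reduction from grouping `labels` by the matching `answers`."""
    n = len(labels)
    groups = {}
    for label, answer in zip(labels, answers):
        groups.setdefault(answer, []).append(label)
    # Size-weighted average entropy of the groups produced by the question
    remainder = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - remainder


# Hypothetical toy data: eight candidate animals, each one a separate "class".
animals = ["cat", "dog", "cow", "bat", "hawk", "frog", "trout", "ant"]
is_mammal = [True, True, True, True, False, False, False, False]
is_cat = [True, False, False, False, False, False, False, False]

gain_mammal = information_gain(animals, is_mammal)
gain_cat = information_gain(animals, is_cat)
print(f"Is it a mammal? gain = {gain_mammal:.3f} bits")
print(f"Is it a cat?    gain = {gain_cat:.3f} bits")
```

On this toy set, the mammal question splits the candidates evenly and gains $1$ bit, while the cat question gains only about $0.54$ bits, which matches the intuition that it helps mostly when the answer is "yes."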

This process of sequential splitting creates a tree-like structure where each internal node represents a test on a feature, each branch represents the outcome of that test, and each leaf node represents a class label. By following the path from the root to a leaf, we can see exactly which conditions led to a particular classification. This makes decision trees one of the most transparent ML models.
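
The node, branch, and leaf structure can be sketched with plain Python dictionaries. The tree below is written by hand rather than learned from data, and it reuses the toy `Age` / `Student` / `Credit_Score` example from the mermaid diagram at the end of this notebook. The dictionary layout and the `classify` helper are assumptions made for illustration, not how scikit-learn stores trees.

```python
# A hand-written toy tree (hypothetical; not learned from data).
# Internal nodes are dicts holding the feature to test and one branch per
# outcome; leaves are plain class labels.
toy_tree = {
    "feature": "age",
    "branches": {
        "young": {"feature": "student", "branches": {"no": "No", "yes": "Yes"}},
        "middle_age": "Yes",
        "older_adult": {
            "feature": "credit_score",
            "branches": {"regular": "No", "excellent": "Yes"},
        },
    },
}


def classify(tree, instance):
    """Follow the root-to-leaf path for `instance`; return (label, path)."""
    path = []
    node = tree
    while isinstance(node, dict):  # stop once we reach a leaf label
        feature = node["feature"]
        value = instance[feature]
        path.append(f"{feature} = {value}")
        node = node["branches"][value]
    return node, path


label, path = classify(
    toy_tree, {"age": "young", "student": "yes", "credit_score": "regular"}
)
print(" AND ".join(path), "=>", label)
```

The recorded path is exactly the classification rule that fired for this instance, which is what makes the prediction easy to explain.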

finish introducing trees

In [None]:
# ! Note: run this cell first to import necessary packages
import sys
from pathlib import Path

project_root = (
    str(Path.cwd().parent) if "notebooks" in str(Path.cwd()) else str(Path.cwd())
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import seaborn.objects as so
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.tree import DecisionTreeClassifier

from utils import plot_boundary, plot_decision_tree

In scikit-learn, a decision tree classifier can be fit using the [`DecisionTreeClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html) class. The code for fitting a decision tree is nearly the same as for the $k$-NN model.
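
As a rough illustration of how similar the two workflows are, the sketch below fits a $k$-NN classifier and a decision tree on the same small synthetic dataset. The dataset, the hyperparameter values (`n_neighbors=5`, `max_depth=5`), and the variable names are arbitrary choices for this example.

```python
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

# Small synthetic dataset, used here only for illustration.
X_demo, y_demo = make_circles(n_samples=500, noise=0.1, factor=0.5, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=0
)

# Both estimators expose the same fit / score interface.
demo_models = {
    "k-NN": KNeighborsClassifier(n_neighbors=5),
    "decision tree": DecisionTreeClassifier(max_depth=5, random_state=0),
}
demo_scores = {}
for name, model in demo_models.items():
    model.fit(X_tr, y_tr)
    demo_scores[name] = model.score(X_te, y_te)
    print(f"{name}: test accuracy = {demo_scores[name]:.3f}")
```

Only the estimator class and its hyperparameters change; the `fit` and `score` calls are identical.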

# House Votes 84

For this exercise, we'll be working with the [House Votes 84 dataset]()

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.ravel()

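depths = [3, 5, 10, 20]

# Build a fine grid over the plotting region; each grid point is later
# classified by the tree to visualize its decision boundary.
h = 0.02  # grid step size
x_min, x_max = -0.1, 4.1
y_min, y_max = -0.1, 4.1

xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# Two noisy concentric circles, shifted and scaled from [-1, 1] to roughly [0, 4]
X, y = make_circles(n_samples=100000, noise=0.1, factor=0.5, random_state=42)
X = (X + 1) * 2

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

for ax, depth in zip(axes, depths):
    # Fit a tree limited to the given maximum depth
    dt = DecisionTreeClassifier(max_depth=depth, random_state=42)
    dt.fit(X_train, y_train)

    # Predict the class of every grid point
    Z = dt.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # Color each grid point by its predicted class
    ax.scatter(xx, yy, c=Z, cmap="RdYlGn", alpha=0.8, s=1, marker=".")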

    ax.set_xlim(0, 4)
    ax.set_ylim(0, 4)
    ax.set_xlabel(r"$x_1$")
    ax.set_ylabel(r"$x_2$")
    ax.set_title(f"Decision Tree: depth = {depth}")

plt.tight_layout()
plt.show()
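
The plotting cell above shows how the boundary changes with `max_depth`. One rough way to put numbers next to those pictures is to compare training and test accuracy at each depth. This is a sketch under assumptions: it uses a smaller sample (`n_samples=2000`) than the plotting cell to keep it quick, and the variable names are chosen for this example only.

```python
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Smaller sample than the plotting cell, to keep the sketch quick.
X_small, y_small = make_circles(
    n_samples=2000, noise=0.1, factor=0.5, random_state=42
)
Xs_train, Xs_test, ys_train, ys_test = train_test_split(
    X_small, y_small, test_size=0.3, random_state=42
)

depth_results = []
for depth in [3, 5, 10, 20]:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    clf.fit(Xs_train, ys_train)
    depth_results.append(
        {
            "max_depth": depth,
            "actual_depth": clf.get_depth(),  # may be smaller than max_depth
            "n_leaves": clf.get_n_leaves(),
            "train_acc": clf.score(Xs_train, ys_train),
            "test_acc": clf.score(Xs_test, ys_test),
        }
    )

for row in depth_results:
    print(row)
```

A training accuracy that keeps rising while the test accuracy stalls or drops is the usual sign that the deeper trees are fitting noise rather than structure.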

Based on our formulation of the decision tree algorithm, you may have noticed that the only stopping criteria for


# Computational Considerations

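Searching the entire space of decision trees for the best tree on a given dataset is NP-hard.

When the features are Boolean, the hypothesis space of decision trees is the space of all Boolean functions of those features, since any such function can be written as a tree.

The decision tree algorithm is therefore a form of hill climbing: at each node it greedily picks the locally best split and never revisits earlier choices.

We have to rely on this greedy strategy because finding the tree with the best accuracy on the entire dataset is computationally infeasible.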
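
To get a feel for how large this hypothesis space is, the sketch below counts the distinct Boolean functions of $n$ binary features. A truth table over $n$ binary inputs has $2^n$ rows, and each row can map to either output, giving $2^{2^n}$ functions. The helper name `num_boolean_functions` is made up for this example.

```python
def num_boolean_functions(n_features):
    """Number of distinct Boolean functions of `n_features` binary inputs."""
    n_rows = 2**n_features  # rows in the truth table
    return 2**n_rows  # each row can independently map to 0 or 1


for n in range(1, 7):
    print(f"{n} binary features -> {num_boolean_functions(n):,} Boolean functions")
```

The count grows doubly exponentially, which gives some intuition for why exhaustive search is off the table and a greedy search is used instead.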


# Titanic

Here, we'll be using the Titanic dataset.


In [None]:
titanic = pl.read_csv("./data/titanic.csv")

X, y = titanic[:, 1:], titanic[:, 0]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
X_train

In [None]:
tree = DecisionTreeClassifier()

tree.fit(X_train, y_train)
tree.score(X_test, y_test)

In [None]:
mnist = pl.read_parquet("./data/train-00000-of-00001.parquet")

In [None]:
from PIL import Image
import io

# Access the bytes field from the struct
image_bytes = mnist[1, "image"]["bytes"]

# Decode the image bytes using PIL
image = Image.open(io.BytesIO(image_bytes))

plt.imshow(image, cmap="gray")
plt.axis("off")
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import torch
import torch.nn as nn
import torch.optim as optim

np.random.seed(42)
X = np.linspace(-3, 3, 100).reshape(-1, 1)
y = np.sin(X) + 0.1 * np.random.randn(100, 1)

X_tensor = torch.FloatTensor(X)
y_tensor = torch.FloatTensor(y)


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 20), nn.Tanh(), nn.Linear(20, 1)
        )

    def forward(self, x):
        return self.layers(x)


model = NeuralNetwork()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.scatter(X, y, alpha=0.5, s=10, label="True data")
(line,) = ax1.plot([], [], "r-", linewidth=2, label="NN prediction")
ax1.set_xlim(-3, 3)
ax1.set_ylim(-1.5, 1.5)
ax1.set_xlabel("x")
ax1.set_ylabel("y")
ax1.set_title("Neural Network Learning")
ax1.legend()
ax1.grid(True, alpha=0.3)

loss_history = []
ax2.set_xlim(0, 200)
ax2.set_ylim(0, 1)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Loss")
ax2.set_title("Training Loss")
ax2.grid(True, alpha=0.3)
(loss_line,) = ax2.plot([], [], "b-", linewidth=2)

epoch_text = fig.text(0.5, 0.95, "", ha="center", fontsize=12, weight="bold")


def init():
    line.set_data([], [])
    loss_line.set_data([], [])
    return line, loss_line, epoch_text


def update(frame):
    optimizer.zero_grad()
    predictions = model(X_tensor)
    loss = criterion(predictions, y_tensor)
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())

    with torch.no_grad():
        y_pred = model(X_tensor).numpy()
    line.set_data(X, y_pred)

    loss_line.set_data(range(len(loss_history)), loss_history)
    if len(loss_history) > 50:
        ax2.set_ylim(0, max(loss_history[:50]))

    epoch_text.set_text(f"Epoch: {frame} | Loss: {loss.item():.4f}")

    return line, loss_line, epoch_text


anim = FuncAnimation(
    fig, update, frames=200, init_func=init, blit=True, interval=50, repeat=False
)

plt.tight_layout()
plt.show()

anim.save("neural_network_learning.gif", writer="pillow", fps=20)

![Animation of a neural network fitting noisy sine data](../assets/neural_network_learning.gif)


<div style="background-color: white; padding: 5px; display: block; margin-left:auto; margin-right:auto; width: fit-content; zoom: 2;">

```mermaid
%%{init: {'theme':'base', 'themeVariables': {'primaryColor':'#e3f2fd','primaryTextColor':'#000','primaryBorderColor':'#1976d2','lineColor':'#000','secondaryColor':'#fff3e0','tertiaryColor':'#c8e6c9','edgeLabelBackground':'#ffffff', 'fontSize':'20px'}, 'flowchart': {'nodeSpacing': 100, 'rankSpacing': 100}}}%%
flowchart TD
    A[Age] --> |Young| B[Student]
    A --> |Middle Age| C[Yes]
    A --> |Older Adult| D[Credit_Score]
    
    B -->|No| E[No]
    B -->|Yes| F[Yes]
    
    D -->|Regular| G[No]
    D -->|Excellent| H[Yes]
    
    style A fill:#e3f2fd,stroke:#1976d2,stroke-width:3px,color:#000
    style B fill:#fff3e0,stroke:#f57c00,stroke-width:2px,color:#000
    style D fill:#fff3e0,stroke:#f57c00,stroke-width:2px,color:#000
    style C fill:#c8e6c9,stroke:#388e3c,stroke-width:2px,color:#000
    style E fill:#c8e6c9,stroke:#388e3c,stroke-width:2px,color:#000
    style F fill:#c8e6c9,stroke:#388e3c,stroke-width:2px,color:#000
    style G fill:#c8e6c9,stroke:#388e3c,stroke-width:2px,color:#000
    style H fill:#c8e6c9,stroke:#388e3c,stroke-width:2px,color:#000
    
    linkStyle 0,1,2,3,4,5,6 stroke:#000,color:#000
```

</div>