## What is a Decision Tree?

A **decision tree** is a supervised machine learning model that makes predictions by asking a **series of yes/no questions** about the input features.
Each question splits the data into smaller and smaller groups, until you reach a group that gives the answer (a prediction).

* For **regression**, the answer is a **number** (like a price, temperature, or score).
* For **classification**, the answer is a **category** (like "yes/no", "apple/orange").

It’s called a “tree” because the splits and branches look like an upside-down tree!
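
As a rough picture, a small tree is nothing more than nested if/else statements. The sketch below is hand-written, not learned from data; the function name `predict_price`, the features, the thresholds, and the prices are all made up for illustration:

```python
def predict_price(size_m2, has_garden):
    """Hand-written regression tree; every number here is illustrative."""
    if size_m2 <= 80:          # question 1: is the home small?
        if has_garden:         # question 2: asked only for small homes
            return 210_000.0
        return 180_000.0
    return 320_000.0           # large homes: no further question


print(predict_price(70, True))    # 210000.0
print(predict_price(100, False))  # 320000.0
```

Each `return` is a leaf of the tree, and each `if` is one of the yes/no questions.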

---

## Why Do We Use Decision Trees?

1. **Easy to Understand**

   * The “if this, then that” logic is simple and visual, so people can read and follow it.
2. **Works for Numbers and Categories**

   * Can handle both regression (numbers) and classification (labels).
3. **No Need for Complex Data Preparation**

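   * Needs little preprocessing: features do not have to be scaled or normalized. Support for missing values and categorical features depends on the implementation.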
4. **Flexible and Powerful**

   * Can model complex, non-linear relationships.

---

## In short:

* **A decision tree splits your data using simple questions until it’s easy to make a prediction.**
* **We use it because it’s simple, clear, and works for many different types of problems.**

---


## Step-by-Step: Decision Tree Regression (With Tiny Data)

Suppose you have:

| x | y |
| - | - |
| 1 | 2 |
| 3 | 4 |
| 5 | 6 |

---

### **List all possible splits**

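A split only changes the groups if it falls between two neighboring values of x, so the usual candidates are the midpoints.
Possible split points are **x = 2** (midway between 1 and 3) and **x = 4** (midway between 3 and 5).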

---

### **Try the first split: x = 2**

* **Left group:** x ≤ 2 → (1, 2)
* **Right group:** x > 2 → (3, 4), (5, 6)

**Group Averages:**

* Left: (2) / 1 = **2** (only one point)
* Right: (4 + 6) / 2 = **5**

**Prediction for each group:**

* Left: predict 2
* Right: predict 5

**Calculate errors (use squared error):**

| x | y | Prediction | Error (y - prediction) | Error² |
| - | - | ---------- | ---------------------- | ------ |
| 1 | 2 | 2          | 0                      | 0      |
| 3 | 4 | 5          | -1                     | 1      |
| 5 | 6 | 5          | 1                      | 1      |

**Total error for this split:** 0 + 1 + 1 = **2**
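
Written as a formula, the total computed here is the sum of squared errors (SSE) of the two groups. The symbols are notation introduced for illustration, not taken from the text above: s is the split point, and ȳ_L and ȳ_R are the left and right group averages.

```latex
\mathrm{SSE}(s) = \sum_{i:\, x_i \le s} (y_i - \bar{y}_L)^2 \;+\; \sum_{i:\, x_i > s} (y_i - \bar{y}_R)^2

\mathrm{SSE}(2) = (2 - 2)^2 + (4 - 5)^2 + (6 - 5)^2 = 0 + 1 + 1 = 2
```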

---

### **Try the second split: x = 4**

* **Left group:** x ≤ 4 → (1, 2), (3, 4)
* **Right group:** x > 4 → (5, 6)

**Group Averages:**

* Left: (2 + 4) / 2 = **3**
* Right: (6) / 1 = **6** (only one point)

**Prediction for each group:**

* Left: predict 3
* Right: predict 6

**Calculate errors (use squared error):**

| x | y | Prediction | Error (y - prediction) | Error² |
| - | - | ---------- | ---------------------- | ------ |
| 1 | 2 | 3          | -1                     | 1      |
| 3 | 4 | 3          | 1                      | 1      |
| 5 | 6 | 6          | 0                      | 0      |

**Total error for this split:** 1 + 1 + 0 = **2**

---

### **Choose the best split**

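Both splits (x = 2 and x = 4) have **the same total error (2)**, because the three points are evenly spaced along a straight line.
Either split is equally good here, so a tie-breaking rule decides (for example, always keep the first one, or look at which choice allows better further splits).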
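
A minimal sketch of how a tie-breaking rule plays out, using the two totals from the tables above. The variable names are illustrative, and "keep the last" is shown only as a contrasting assumption, not as a recommended rule:

```python
# (split point, total squared error) pairs from the two tables above
candidates = [(2.0, 2.0), (4.0, 2.0)]

# Rule A: keep the first split that reaches the lowest error.
# Python's min() returns the first minimal item when several tie.
first_best = min(candidates, key=lambda c: c[1])

# Rule B: keep the last split that reaches the lowest error.
lowest = min(err for _, err in candidates)
last_best = [c for c in candidates if c[1] == lowest][-1]

print(first_best[0])  # 2.0
print(last_best[0])   # 4.0
```

A loop that updates the best split only on a strictly lower error (`<`) behaves like Rule A.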

---

### **Assign predictions for new data**

Suppose you pick the **split at x = 2**. Every input, whether it was in the training data or not, is routed by the same question:

* For x = 1: Is 1 ≤ 2? Yes → **predict 2**
* For x = 3: Is 3 ≤ 2? No → **predict 5**
* For x = 5: Is 5 ≤ 2? No → **predict 5**

---

## **Summary Table**

| Step            | What happens?                                                      |
| --------------- | ------------------------------------------------------------------ |
| List splits     | Try all splits between values (here, x = 2 and x = 4)              |
| Make groups     | For each split, separate data into left/right groups               |
| Group average   | For each group, calculate the average y (prediction)               |
| Calculate error | For each point, squared difference from group prediction           |
| Choose split    | Pick the split with the lowest total error                         |
| Predict         | Use the group average for new x values (follow the split question) |

---

## Example

In [1]:
import numpy as np

In [2]:
# data
x = np.array([1, 3, 5])
y = np.array([2, 4, 6])

In [3]:
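# step 1: candidate splits are the midpoints between consecutive x values
# (assumes x is sorted in ascending order with no duplicates)
split_points = [(x[i] + x[i+1])/2 for i in range(len(x)-1)]
# best split found so far; the means stay None if no valid split exists
best_split = None
best_error = float('inf')
best_left_mean = best_right_mean = None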


In [4]:
split_points

[np.float64(2.0), np.float64(4.0)]

In [5]:
for split in split_points:
    left_mask = x <= split
    right_mask = x > split
    
    if np.sum(left_mask) == 0 or np.sum(right_mask) == 0:
        continue  # skip if a group would be empty
    
    left_mean = y[left_mask].mean()
    right_mean = y[right_mask].mean()
    
    # assign predictions based on which group each point falls into
    preds = np.where(left_mask, left_mean, right_mean)
    errors = (y - preds) ** 2
    total_error = errors.sum()
    
    print(f"Split at {split:.1f}: left mean={left_mean:.1f}, right mean={right_mean:.1f}, total error={total_error:.2f}")
    
    if total_error < best_error:
        best_error = total_error
        best_split = split
        best_left_mean = left_mean
        best_right_mean = right_mean

print("\nBest split at:", best_split)
print("Left mean prediction:", best_left_mean)
print("Right mean prediction:", best_right_mean)


Split at 2.0: left mean=2.0, right mean=5.0, total error=2.00
Split at 4.0: left mean=3.0, right mean=6.0, total error=2.00

Best split at: 2.0
Left mean prediction: 2.0
Right mean prediction: 5.0


In [6]:
# prediction function
def predict(val):
    if val <= best_split:
        return best_left_mean
    else:
        return best_right_mean

# test predictions
for test_x in [1, 2, 3, 4, 5]:
    print(f"x={test_x} => predicted y={predict(test_x):.1f}")

x=1 => predicted y=2.0
x=2 => predicted y=2.0
x=3 => predicted y=5.0
x=4 => predicted y=5.0
x=5 => predicted y=5.0




## Decision Tree Regression (CART, Splitting, Pruning)

---

### CART Algorithm

* **What is it?**
  CART (Classification and Regression Trees) is an algorithm for building binary decision trees that can predict numbers (regression) or categories (classification).

* **How does it work for regression?**

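  1. Try every candidate split: for each feature, every threshold between two neighboring values.
  2. For each split, measure how good it is (total squared error of the two groups).
  3. Pick the split that gives the lowest error.
  4. Repeat the same procedure inside each group (branch).
  5. Stop when a stopping rule is met, for example when the tree reaches a maximum depth or the groups become too small.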
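
The steps above can be sketched as a short recursive function. This is a minimal pure-Python illustration for a single feature, not a full CART implementation. The helper names (`sse`, `best_split`, `build_tree`, `predict`), the nested-dict tree layout, and the `max_depth` stopping rule are all assumptions made for the example. It reuses the tiny dataset from the step-by-step section (x = 1, 3, 5 and y = 2, 4, 6):

```python
def sse(ys):
    """Sum of squared errors around the mean of ys."""
    mean = sum(ys) / len(ys)
    return sum((y - mean) ** 2 for y in ys)


def best_split(xs, ys):
    """Return (threshold, total_error) of the best midpoint split, or None."""
    values = sorted(set(xs))
    best = None
    for lo, hi in zip(values, values[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        err = sse(left) + sse(right)
        if best is None or err < best[1]:  # strict "<": first split wins ties
            best = (t, err)
    return best


def build_tree(xs, ys, depth=0, max_depth=2):
    """Grow a tree as nested dicts; a leaf is {'value': mean of y}."""
    split = best_split(xs, ys)
    # Stop at max_depth, or when no split exists (all x values are equal).
    if depth >= max_depth or split is None:
        return {"value": sum(ys) / len(ys)}
    t = split[0]
    left = [(x, y) for x, y in zip(xs, ys) if x <= t]
    right = [(x, y) for x, y in zip(xs, ys) if x > t]
    return {
        "threshold": t,
        "left": build_tree([x for x, _ in left], [y for _, y in left],
                           depth + 1, max_depth),
        "right": build_tree([x for x, _ in right], [y for _, y in right],
                            depth + 1, max_depth),
    }


def predict(tree, x):
    """Follow the yes/no questions down to a leaf."""
    while "value" not in tree:
        tree = tree["left"] if x <= tree["threshold"] else tree["right"]
    return tree["value"]


tree = build_tree([1, 3, 5], [2, 4, 6], max_depth=2)
print(tree["threshold"])   # 2.0
print(predict(tree, 3))    # 4.0
```

With `max_depth=2` the right group is split a second time (at x = 4), so every training point ends up in its own leaf. This is the memorizing behavior that pruning is meant to prevent.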

---

### Splitting Criteria (with math example)

Suppose you have:

| x | y |
| - | - |
| 1 | 2 |
| 3 | 4 |
| 5 | 8 |

#### **Try to split at x = 2**

* **Left group:** x ≤ 2 → (1, 2)
* **Right group:** x > 2 → (3, 4), (5, 8)

**Step 1: Find mean of y for each group**

* Left mean = 2 (only one value)
* Right mean = (4 + 8) / 2 = **6**

**Step 2: Calculate squared error for each group**

| x | y | Pred | (y - pred)² |
| - | - | ---- | ----------- |
| 1 | 2 | 2    | 0           |
| 3 | 4 | 6    | 4           |
| 5 | 8 | 6    | 4           |

* Left group error: 0 (only one value)
* Right group error: 4 + 4 = 8

**Total error for this split:** **8**

---

#### **Try to split at x = 4**

* **Left group:** x ≤ 4 → (1, 2), (3, 4)
* **Right group:** x > 4 → (5, 8)

**Step 1: Mean of y**

* Left mean = (2 + 4) / 2 = **3**
* Right mean = 8

**Step 2: Squared error**

| x | y | Pred | (y - pred)² |
| - | - | ---- | ----------- |
| 1 | 2 | 3    | 1           |
| 3 | 4 | 3    | 1           |
| 5 | 8 | 8    | 0           |

* Left group error: 1 + 1 = 2
* Right group error: 0

**Total error for this split:** **2**

---

#### **Best split?**

* Split at x = 4 has lower error (**2 < 8**).
* So, the tree will split at x = 4!

---

### Tree Pruning (Super Easy)

* If you keep splitting until every group has just one point, the tree "memorizes" the data (overfits).
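* **Pruning** keeps the tree small. It comes in two flavors:
  * **Pre-pruning (early stopping):** "Stop splitting when groups are small, or when splitting doesn't reduce error by much."
  * **Post-pruning:** grow the full tree first, then remove branches that barely reduce the error.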

#### Example:

* After our split at x = 4, each group has only one or two points.
* Splitting the left group again would bring the error down to 0, but only by giving every point its own leaf, so this is a good place to stop.
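
One way to express "stop when groups are small, or when splitting doesn't reduce error by much" in code is a small check run before each split. The function name `worth_splitting` and both thresholds (`min_samples=3`, `min_gain=1.0`) are hypothetical values chosen for this tiny example, not standard defaults:

```python
def sse(ys):
    """Sum of squared errors around the mean of ys."""
    mean = sum(ys) / len(ys)
    return sum((y - mean) ** 2 for y in ys)


def worth_splitting(parent, left, right, min_samples=3, min_gain=1.0):
    """Hypothetical stopping rule; both thresholds are illustrative only."""
    if len(parent) < min_samples:
        return False  # group is already small: stop
    gain = sse(parent) - (sse(left) + sse(right))
    return gain >= min_gain  # split only if the error drops enough


# Root: y = [2, 4, 8] split at x = 4 into [2, 4] and [8]
print(worth_splitting([2, 4, 8], [2, 4], [8]))  # True (error drops from about 18.67 to 2)
# Left group: y = [2, 4] split into [2] and [4]
print(worth_splitting([2, 4], [2], [4]))        # False (only two points left)
```

The first call accepts the split at x = 4; the second rejects a further split of the left group, which matches the decision to stop here.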

---

## Summary Table

| Step       | What Happens                                  |
| ---------- | --------------------------------------------- |
| Try splits | Test all possible splits                      |
| Calculate  | For each split, compute mean & squared errors |
| Pick best  | Choose split with the lowest total error      |
| Prune      | Stop splitting when it's not helpful          |

---


In [10]:
import numpy as np

# tiny dataset
x = np.array([1, 3, 5])
y = np.array([2, 4, 8])

# all possible splits (midpoints between consecutive x values;
# x is assumed sorted in ascending order with no duplicates)
split_points = [(x[i] + x[i+1])/2 for i in range(len(x)-1)]
print("Possible split points:", split_points)

results = []
for split in split_points:
    # split into left and right groups
    left_mask = x <= split
    right_mask = x > split

    # group means
    left_mean = y[left_mask].mean()
    right_mean = y[right_mask].mean()

    # predictions and squared errors
    y_pred = np.where(left_mask, left_mean, right_mean)
    errors = (y - y_pred)**2

    total_error = errors.sum()
    results.append((split, total_error, left_mean, right_mean))
    print(f"\nSplit at x = {split:.1f}:")
    print("  Left group x:", x[left_mask], "y:", y[left_mask], "mean:", left_mean)
    print("  Right group x:", x[right_mask], "y:", y[right_mask], "mean:", right_mean)
    print("  Squared errors:", errors)
    print("  Total error:", total_error)

# find best split (lowest total error)
best = min(results, key=lambda tup: tup[1])
print("\nBest split is at x =", best[0])
print("Predictions: left group =", best[2], ", right group =", best[3])
print("Lowest total squared error:", best[1])


Possible split points: [np.float64(2.0), np.float64(4.0)]

Split at x = 2.0:
  Left group x: [1] y: [2] mean: 2.0
  Right group x: [3 5] y: [4 8] mean: 6.0
  Squared errors: [0. 4. 4.]
  Total error: 8.0

Split at x = 4.0:
  Left group x: [1 3] y: [2 4] mean: 3.0
  Right group x: [5] y: [8] mean: 8.0
  Squared errors: [1. 1. 0.]
  Total error: 2.0

Best split is at x = 4.0
Predictions: left group = 3.0 , right group = 8.0
Lowest total squared error: 2.0
