# Decision Trees: Real-World Example
## Customer Churn Prediction

You already know how decision trees work conceptually. Today, we're seeing them in action on real data.

**Goal:** Understand how trees make decisions by visualizing actual splits and manually tracing predictions through the tree.

---

## Setup: Import Libraries

We'll use pandas for data, scikit-learn for the tree, and matplotlib for visualization.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import seaborn as sns

# Set random seed for reproducibility
np.random.seed(42)

# Make plots look nice
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## Part 1: Load & Explore the Data

We're using a real telecom customer churn dataset. The question: which customers are likely to leave (churn)?

This matters because acquiring new customers is expensive. If a company can predict who's about to leave, they can offer retention deals.

In [None]:
# Load the telecom churn dataset directly from GitHub
url = 'https://raw.githubusercontent.com/TUHHStartupEngineers/dat_sci_ss20/master/13/WA_Fn-UseC_-Telco-Customer-Churn.csv'
df = pd.read_csv(url)

print(f"Dataset shape: {df.shape}")
print(f"\nFirst few rows:")
print(df.head())
print(f"\nColumn names and types:")
print(df.dtypes)

### Quick EDA: Understand the Target

Let's see what we're predicting: Do customers churn or not?

In [None]:
# Check the Churn column
print(f"Churn column unique values: {df['Churn'].unique()}")
print(f"\nChurn distribution:")
print(df['Churn'].value_counts())
churn_rate = (df['Churn'] == 'Yes').mean()
print(f"\nChurn rate: {churn_rate:.1%}")

### Prepare Data for the Tree

scikit-learn's decision trees need numeric features. We'll:
1. Drop columns we don't need (identifiers, etc.)
2. Convert `TotalCharges` to numeric (it is stored as text and has a few blank entries)
3. Encode categorical features (like 'Contract') as numbers
4. Convert the target to binary (Yes=1, No=0)
5. Split into train/test


In [None]:
# Make a copy to avoid warnings
df_clean = df.copy()

# Drop non-informative columns (customerID)
df_clean = df_clean.drop(columns=['customerID'])

# Convert Churn to binary (Yes=1, No=0)
y = (df_clean['Churn'] == 'Yes').astype(int)
X = df_clean.drop(columns=['Churn'])

# TotalCharges is read as text because a few rows are blank.
# Convert it to numeric BEFORE encoding, or it would be label-encoded like a category.
X['TotalCharges'] = pd.to_numeric(X['TotalCharges'], errors='coerce')
X['TotalCharges'] = X['TotalCharges'].fillna(X['TotalCharges'].median())

# Encode the remaining categorical (non-numeric) columns as integer codes
le_dict = {}
for col in X.columns:
    if not pd.api.types.is_numeric_dtype(X[col]):
        le = LabelEncoder()
        X[col] = le.fit_transform(X[col])
        le_dict[col] = le

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")
print(f"\nFeatures we're using:")
print(X.columns.tolist())
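
The order of the cleaning steps matters: a column of numbers stored as text would be treated as a category if it reached the encoder first. Here is a minimal sketch on a made-up four-row frame (the values and the `toy` and `encoders` names are illustrative, not taken from the dataset):

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy rows; the blank string mimics a blank TotalCharges entry.
toy = pd.DataFrame({
    "Contract": ["Month-to-month", "Two year", "One year", "Month-to-month"],
    "TotalCharges": ["29.85", " ", "1889.5", "108.15"],
})

# Numeric text first: the blank becomes NaN, then is filled with the median.
toy["TotalCharges"] = pd.to_numeric(toy["TotalCharges"], errors="coerce")
toy["TotalCharges"] = toy["TotalCharges"].fillna(toy["TotalCharges"].median())

# Only the genuinely categorical columns are left to encode.
encoders = {}
for col in toy.columns:
    if not pd.api.types.is_numeric_dtype(toy[col]):
        encoders[col] = LabelEncoder()
        toy[col] = encoders[col].fit_transform(toy[col])

print(toy)
print("Contract codes follow this order:", list(encoders["Contract"].classes_))
```

`LabelEncoder` assigns codes in sorted order of the labels, so the integers carry no business meaning; the tree only uses them to place thresholds.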

---

## Part 2: Train & Visualize a Decision Tree

Now we train a decision tree on the churn data. We'll keep it shallow (max_depth=5) so we can actually *see* and understand the splits.

In [None]:
# Train a decision tree
tree = DecisionTreeClassifier(max_depth=5, random_state=42)
tree.fit(X_train, y_train)

# Check accuracy
train_acc = tree.score(X_train, y_train)
test_acc = tree.score(X_test, y_test)

print(f"Training accuracy: {train_acc:.3f}")
print(f"Test accuracy: {test_acc:.3f}")
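
Accuracy is easier to judge next to a baseline. If roughly a quarter of customers churn, a model that always predicts "Stay" already scores about 75%. The sketch below uses synthetic stand-in data (not the churn set) and scikit-learn's `DummyClassifier` to show the comparison; the `X_demo`, `y_demo`, `baseline` and `demo_tree` names are made up for this example.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.tree import DecisionTreeClassifier

# Synthetic, imbalanced stand-in data: the label depends on the first
# feature plus noise, and positives are the minority class.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1000, 3))
y_demo = (X_demo[:, 0] + 0.5 * rng.normal(size=1000) > 0.7).astype(int)

# Baseline: always predict the most frequent class.
baseline = DummyClassifier(strategy="most_frequent").fit(X_demo, y_demo)
demo_tree = DecisionTreeClassifier(max_depth=5, random_state=42).fit(X_demo, y_demo)

print(f"Positive rate:     {y_demo.mean():.3f}")
print(f"Baseline accuracy: {baseline.score(X_demo, y_demo):.3f}")
print(f"Tree accuracy:     {demo_tree.score(X_demo, y_demo):.3f}")
```

The same two lines with `X_train`, `y_train` and `X_test`, `y_test` would give the baseline for the churn tree.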

### Visualize the Tree

Here's the actual tree. Read it top-to-bottom:
- Each box with a condition is a **decision node**: a question like "Is tenure <= 24.5?"
- Follow the left arrow when the answer is YES (condition true) and the right arrow when it is NO
- The **leaf nodes** (boxes with no condition) show the Gini impurity, how many training samples reached the leaf, the class breakdown (`value`), and the predicted class

Try to trace one path from root to leaf. That's exactly how the tree makes a decision.

In [None]:
# Visualize the tree
plt.figure(figsize=(25, 12))
plot_tree(
    tree,
    feature_names=list(X.columns),  # a plain list works across scikit-learn versions
    class_names=['Stay', 'Churn'],
    filled=True,
    fontsize=10
)
plt.title('Decision Tree for Customer Churn (max_depth=5)', fontsize=16, pad=20)
plt.tight_layout()
plt.show()

print("How to read the tree:")
print("- Top: Root node (initial split)")
print("- Left arrow = YES (condition true), right arrow = NO")
print("- Color: orange shades = majority 'Stay', blue shades = majority 'Churn'; darker = purer node")
print("- Leaf (bottom): predicted class, samples in leaf, value = [Stay, Churn] (counts, or fractions in scikit-learn >= 1.4)")
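
When the plot is too dense to read, the same rules can be printed as text with scikit-learn's `export_text`. This is a small sketch on hypothetical toy data (six made-up customers with `tenure` and `MonthlyCharges`), not the churn tree itself:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data: [tenure in months, monthly charges]; label 1 = churn.
X_toy = np.array([[1, 80], [2, 75], [3, 90], [30, 60], [40, 85], [50, 70]])
y_toy = np.array([1, 1, 1, 0, 0, 0])

toy_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_toy, y_toy)

# Each indented line is one split; lines starting with "class:" are leaves.
rules = export_text(toy_tree, feature_names=["tenure", "MonthlyCharges"])
print(rules)
```

Only tenure separates the two groups perfectly here, so the tree stops after a single split even though `max_depth=2` would allow more.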

---

## Part 3: Manually Trace a Prediction

Let's pick one customer and trace them through the tree manually. This shows exactly how trees make decisions.

In [None]:
# Pick one customer from the test set
test_idx = 5
customer = X_test.iloc[test_idx]
actual_churn = y_test.iloc[test_idx]

print("=" * 70)
print(f"CUSTOMER #{test_idx}")
print("=" * 70)
print(f"\nCustomer profile (selected features):")
important_features = ['tenure', 'MonthlyCharges', 'TotalCharges', 'Contract', 'InternetService']
for col in important_features:
    if col in X.columns:
        print(f"  {col}: {customer[col]:.2f}")
print(f"\nActual outcome: {'CHURNED' if actual_churn == 1 else 'STAYED'}")

### Trace the Path

Now let's manually walk this customer through the tree, following each split. This is how the tree actually makes the prediction!

In [None]:
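# Make a prediction and show the tree's reasoning.
# Pass a one-row DataFrame so the feature names match those seen during fit.
prediction = tree.predict(X_test.iloc[[test_idx]])[0]
prediction_proba = tree.predict_proba(X_test.iloc[[test_idx]])[0]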

print("\nTREE'S DECISION PATH:")
print("=" * 70)

# Get decision path
decision_path = tree.decision_path(X_test.iloc[[test_idx]]).toarray()[0]
node_index = np.where(decision_path)[0]

print(f"\nFollowing the path through the tree:")
step = 1
for node_id in node_index:
    feature = tree.tree_.feature[node_id]
    threshold = tree.tree_.threshold[node_id]

    if feature != -2:  # -2 means leaf node
        feature_name = X.columns[feature]
        customer_value = customer[feature_name]
        direction = "YES (<=)" if customer_value <= threshold else "NO (>)"
        print(f"\n  Step {step}: Is {feature_name} <= {threshold:.2f}?")
        print(f"            {feature_name} = {customer_value:.2f}")
        print(f"            Answer: {direction}")
        step += 1
    else:
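        # Leaf node. tree_.value holds class counts in scikit-learn < 1.4 and class
        # fractions in >= 1.4, so normalize it and rescale by the leaf's sample count.
        value = tree.tree_.value[node_id][0]
        n_leaf = tree.tree_.n_node_samples[node_id]
        stay_count, churn_count = np.round(value / value.sum() * n_leaf).astype(int)
        print("\n  *** LEAF NODE (Decision Made!) ***")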
        print(f"  In this group: {stay_count} stayed, {churn_count} churned")
        print(f"  Prediction: {'CHURN' if churn_count > stay_count else 'STAY'}")
        print(f"  Confidence: {max(prediction_proba)*100:.1f}%")

print("\n" + "=" * 70)
print(f"TREE PREDICTION: {'CHURNED' if prediction == 1 else 'STAYED'}")
print(f"ACTUAL OUTCOME: {'CHURNED' if actual_churn == 1 else 'STAYED'}")
print(f"✓ CORRECT!" if prediction == actual_churn else "✗ WRONG")
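
To see what `decision_path` does under the hood, the walk can be written by hand from the fitted tree's arrays. The sketch below uses synthetic data and a hypothetical helper named `walk`; it relies on the `children_left`, `children_right`, `feature` and `threshold` arrays of `tree_`, where a left child of -1 marks a leaf.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data, not the churn set.
X_syn, y_syn = make_classification(
    n_samples=300, n_features=4, n_informative=2, n_redundant=0, random_state=0
)
syn_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_syn, y_syn)


def walk(fitted_tree, row):
    """Follow one sample from the root to its leaf; return the visited node ids."""
    t = fitted_tree.tree_
    row = np.asarray(row, dtype=np.float32)  # scikit-learn compares float32 inputs
    node, path = 0, [0]
    while t.children_left[node] != -1:  # -1 marks a leaf
        if row[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]   # condition true: go left
        else:
            node = t.children_right[node]  # condition false: go right
        path.append(int(node))
    return path


path = walk(syn_tree, X_syn[0])
print("Visited nodes:", path)
print("Leaf reported by scikit-learn:", syn_tree.apply(X_syn[:1])[0])
```

The last node in `path` should match the leaf id from `apply`, which is a quick check that the manual walk follows the same rule as the library.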

---

## Part 4: Feature Importance

Which features does the tree rely on most when making decisions?

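A tree splits on the features that best separate churners from non-churners. scikit-learn's `feature_importances_` reports each feature's share of the total impurity reduction across all the splits that use it, weighted by how many samples reach each split. Features used near the top of the tree therefore usually score highest, and features the tree never splits on score zero.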

In [None]:
# Extract feature importance
importances = tree.feature_importances_
feature_importance_df = pd.DataFrame({
    'feature': X.columns,
    'importance': importances
}).sort_values('importance', ascending=False)

print("Feature Importance (Top 10):")
print(feature_importance_df.head(10).to_string(index=False))
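
The importance numbers can be recomputed by hand, which makes the definition concrete. The sketch below, on synthetic data, sums each split's weighted impurity decrease per feature and normalizes the totals; the variable names are made up for this example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data: 2 informative features and 3 noise features.
X_imp, y_imp = make_classification(
    n_samples=400, n_features=5, n_informative=2, n_redundant=0, random_state=1
)
imp_tree = DecisionTreeClassifier(max_depth=4, random_state=1).fit(X_imp, y_imp)

t = imp_tree.tree_
w = t.weighted_n_node_samples
manual = np.zeros(X_imp.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: no split, no contribution
        continue
    # Impurity before the split minus impurity after it, weighted by sample counts.
    gain = w[node] * t.impurity[node] - w[left] * t.impurity[left] - w[right] * t.impurity[right]
    manual[t.feature[node]] += gain

manual /= manual.sum()  # normalize so the importances sum to 1
print("Manual: ", np.round(manual, 3))
print("sklearn:", np.round(imp_tree.feature_importances_, 3))
```

Because the score depends only on the splits this one tree happened to choose, two correlated features can end up with very different importances.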

### Visualize Feature Importance

Which features matter most for predicting churn? The tree learned this automatically!

In [None]:
# Plot top 10 features
top_n = 10
top_features = feature_importance_df.head(top_n)

plt.figure(figsize=(10, 6))
plt.barh(range(len(top_features)), top_features['importance'], color='steelblue')
plt.yticks(range(len(top_features)), top_features['feature'])
plt.xlabel('Importance')
plt.title(f'Top {top_n} Most Important Features for Predicting Churn')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

print(f"\nThe tree relies most on:")
for idx, (_, row) in enumerate(top_features.head(3).iterrows(), 1):
    print(f"  {idx}. {row['feature']} ({row['importance']:.1%})")

---

## YOUR TURN: Pairs Practice (20 minutes)

Now you build your own tree. Work in pairs.

### Task 1: Train & Visualize Your Own Tree (10 min)

- Train a decision tree with max_depth=5 on the same churn data
- Visualize it
- **Compare to the demo tree above:** Are the splits different? If so, why?

### Task 2: Reflect (5 min)

- **What surprised you about the tree's splits?**
- **If you were a telecom company using this tree, what would you do?** (Who would you target for retention offers?)

### Task 3: Debug & Ask Questions (5 min)

Stuck? Ask questions. This is the time to explore.

---

## Student Practice: Code Along Below

Use the cells below to complete the tasks. Start fresh if you want, or modify the demo code.

In [None]:
# TASK 1a: Train your own tree (use max_depth=5, same as demo)
# TODO: Fit a DecisionTreeClassifier on X_train and y_train
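# Hint: Use DecisionTreeClassifier(max_depth=5, random_state=42).fit(X_train, y_train)
# With the same data and settings you should reproduce the demo tree. To see the splits
# change, refit on a different train/test split (e.g. another random_state in train_test_split).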

my_tree = None  # REPLACE THIS with your tree

if my_tree is not None:
    my_train_acc = my_tree.score(X_train, y_train)
    my_test_acc = my_tree.score(X_test, y_test)
    print(f"Your training accuracy: {my_train_acc:.3f}")
    print(f"Your test accuracy: {my_test_acc:.3f}")
else:
    print("TODO: Train your tree in this cell!")

In [None]:
# TASK 1b: Visualize your tree
# TODO: Create a plot_tree visualization of my_tree
# Hint: Copy the visualization code from Part 2, but use my_tree instead of tree

if my_tree is not None:
    plt.figure(figsize=(25, 12))
    # YOUR CODE HERE
    # plot_tree(my_tree, ...)
    plt.show()
else:
    print("Train your tree first (cell above)")

In [None]:
# TASK 1c: Extract and plot feature importance from YOUR tree
# TODO: Get feature_importances_ from my_tree and visualize

if my_tree is not None:
    my_importances = my_tree.feature_importances_
    my_importance_df = pd.DataFrame({
        'feature': X.columns,
        'importance': my_importances
    }).sort_values('importance', ascending=False)

    # Plot top 10 features
    top_features = my_importance_df.head(10)
    plt.figure(figsize=(10, 6))
    plt.barh(range(len(top_features)), top_features['importance'], color='darkgreen')
    plt.yticks(range(len(top_features)), top_features['feature'])
    plt.xlabel('Importance')
    plt.title('Your Tree: Top 10 Most Important Features')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.show()
else:
    print("Train your tree first")

## Reflection Questions (TASK 2)

**Discuss these with your pair partner. Write answers as comments below:**

1. How does your tree compare to the demo tree? Are the top splits different?
2. What surprised you about which features are important?
3. If you were the telecom CEO, how would you use this tree to reduce churn?
4. What could go wrong if you relied on this tree blindly?

In [None]:
# TASK 2: Reflection
# Type your thoughts below as comments

# 1. Comparison to demo tree:
#    YOUR ANSWER HERE

# 2. Surprising features:
#    YOUR ANSWER HERE

# 3. As a CEO, I would:
#    YOUR ANSWER HERE

# 4. What could go wrong:
#    YOUR ANSWER HERE

---

## Summary

**Key takeaways from this notebook:**

1. **Trees are interpretable**: You can see *exactly* how they decide—no black boxes
2. **Manual tracing matters**: Walking through splits shows how a tree thinks about one customer
3. **Feature importance is useful**: But it's not the whole story—you need context
4. **Trees are unstable**: Your tree's splits might differ from the demo tree. Small data changes → different trees. This is a key limitation.

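**This is where Random Forests come in.** By training many trees on resampled data and averaging their predictions, we typically get:
- ✓ Better accuracy (averaging reduces overfitting)
- ✓ More stable predictions
- ✓ Some interpretability (via feature importance), though you can no longer trace a single path by hand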
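
Instability is easy to observe directly. The sketch below, on synthetic data rather than the churn set, refits the same depth-5 tree on five bootstrap resamples and records the root split each time, then fits a `RandomForestClassifier` for comparison. All variable names are illustrative, and the exact numbers will depend on the data and seeds.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data.
X_s, y_s = make_classification(n_samples=600, n_features=8, n_informative=4, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_s, y_s, test_size=0.25, random_state=0, stratify=y_s
)

# Refit the same tree on bootstrap resamples and record each root split.
rng = np.random.default_rng(0)
root_splits = []
for _ in range(5):
    idx = rng.integers(0, len(X_tr), size=len(X_tr))  # sample rows with replacement
    t = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X_tr[idx], y_tr[idx])
    root_splits.append((int(t.tree_.feature[0]), round(float(t.tree_.threshold[0]), 2)))
print("Root (feature, threshold) per resample:", root_splits)

# One tree versus an average over many trees.
single = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X_tr, y_tr)
forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_tr, y_tr)
single_acc = single.score(X_te, y_te)
forest_acc = forest.score(X_te, y_te)
print(f"Single tree test accuracy: {single_acc:.3f}")
print(f"Forest test accuracy:      {forest_acc:.3f}")
```

If the root feature or threshold shifts between resamples, that is the instability described above; the forest averages over exactly this kind of variation.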

Next: Random forests address tree instability and show why ensembles are powerful.