# 03 — Decision Tree Classifier (Fill-in-the-Blanks)

## Objective

Learn the supervised classification workflow from dataset setup through model evaluation.

## Install libraries

Use `pip install -r requirements.txt` from the repo root if needed. This notebook does not run pip commands.

## Imports

In [None]:
# Concept: import libraries for classification workflow
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

## Load / Create dataset

In [None]:
# Concept: create a simple customer dataset
df = pd.DataFrame({
    "Age": [22, 25, 28, 35, 40, 45, 52, 23, 27, 31, 38, 42, 48, 55, 29, 34, 41, 46, 53, 26],
    "Salary": [25000, 28000, 32000, 45000, 50000, 62000, 70000, 26000, 30000, 42000, 52000, 58000, 65000, 80000, 34000, 46000, 56000, 63000, 76000, 31000],
    "Buy": [0,0,0,1,1,1,1,0,0,1,1,1,1,1,0,1,1,1,1,0]
})
df.head()

## Separate features and target

In [None]:
# TODO: create X and y
X = df[["Age", "Salary"]]
y = df["Buy"]
X.head(), y.head()

## Train-test split

In [None]:
# TODO: split train/test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

## Check yourself

In [None]:
print("X shape:", X.shape)
print("y shape:", y.shape)
print("X_train shape:", X_train.shape)
print("X_test shape:", X_test.shape)
print("X type:", type(X))
print("y type:", type(y))
print("NaNs in X:", int(X.isna().sum().sum()))
print("NaNs in y:", int(y.isna().sum()))

## Create model

In [None]:
# TODO: instantiate model with params
# random_state fixes tie-breaking between equally good splits, so reruns match
model = DecisionTreeClassifier(random_state=42)
model

## Train model

In [None]:
# TODO: fit model
model.fit(X_train, y_train)

## Make predictions

In [None]:
# TODO: predict on test set
y_pred = model.predict(X_test)
pd.DataFrame({"Actual": y_test.values, "Predicted": y_pred})

## Evaluate model

In [None]:
# TODO: compute metrics
accuracy = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred)

print(f"Accuracy: {accuracy:.4f}")
print("Confusion Matrix:\n", cm)
print("Classification Report:\n", report)
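
To make the confusion matrix easier to read, the sketch below unpacks a binary matrix into its four counts and recomputes accuracy, precision, and recall by hand. The labels `y_true_demo` and `y_hat_demo` are hypothetical values chosen for illustration, not results from this notebook. Passing `labels=[0, 1]` is an assumption that keeps the matrix 2x2 even when a small test set happens to contain only one class.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical labels for illustration only; not output from the model above.
y_true_demo = [0, 0, 1, 1, 1, 0]
y_hat_demo = [0, 1, 1, 1, 0, 0]

# Rows are actual classes, columns are predicted classes.
# For labels [0, 1], ravel() yields the counts in the order TN, FP, FN, TP.
tn, fp, fn, tp = confusion_matrix(y_true_demo, y_hat_demo, labels=[0, 1]).ravel()

acc_manual = (tp + tn) / (tp + tn + fp + fn)  # share of all predictions that are correct
precision_manual = tp / (tp + fp)             # of predicted buyers, how many bought
recall_manual = tp / (tp + fn)                # of actual buyers, how many were found

print(f"TN={tn} FP={fp} FN={fn} TP={tp}")
print(f"accuracy={acc_manual:.3f} precision={precision_manual:.3f} recall={recall_manual:.3f}")
```

These hand-computed values should agree with the class-1 row of `classification_report` for the same labels.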

## Predict new data

In [None]:
# TODO: predict a new customer
new_customer = pd.DataFrame([[40, 48000]], columns=["Age", "Salary"])
new_prediction = model.predict(new_customer)[0]
print("WILL BUY" if new_prediction == 1 else "WILL NOT BUY")

## Visualization (if applicable)

In [None]:
# Concept: quick scatter of dataset and highlight new customer
plt.figure(figsize=(6, 4))
colors = df["Buy"].map({0: "tomato", 1: "seagreen"})
plt.scatter(df["Age"], df["Salary"], c=colors, alpha=0.75, label="Dataset examples")
plt.scatter(new_customer["Age"], new_customer["Salary"], c="blue", marker="X", s=120, label="New customer")
plt.xlabel("Age")
plt.ylabel("Salary")
plt.title("Age + Salary vs Buy")
plt.legend()
plt.tight_layout()
plt.show()

## Core Concepts

- **Splits and branches**: at each node the tree asks a threshold question of the form `feature <= value` and sends each sample down one of two branches to separate the classes.
- **Interpretability**: path-to-leaf rules are easy to explain.
- **Overfitting risk**: deep trees can memorize training noise.
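
The ideas above can be seen directly on the toy data. The following is a minimal sketch that rebuilds the same 20-row dataset so it runs on its own. `export_text` is scikit-learn's helper for printing a fitted tree as rules. The names `toy`, `feature_cols`, `stump`, and `rules`, and the choice of `max_depth=1`, are illustrative assumptions rather than part of the notebook's workflow.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Same toy data as the dataset cell, repeated so this block is self-contained.
toy = pd.DataFrame({
    "Age": [22, 25, 28, 35, 40, 45, 52, 23, 27, 31, 38, 42, 48, 55, 29, 34, 41, 46, 53, 26],
    "Salary": [25000, 28000, 32000, 45000, 50000, 62000, 70000, 26000, 30000, 42000,
               52000, 58000, 65000, 80000, 34000, 46000, 56000, 63000, 76000, 31000],
    "Buy": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0],
})
feature_cols = ["Age", "Salary"]

# Assumption: max_depth=1 limits the tree to a single threshold question.
stump = DecisionTreeClassifier(max_depth=1, random_state=0)
stump.fit(toy[feature_cols], toy["Buy"])

# Each "|---" line is a branch condition; each "class:" line is a leaf.
rules = export_text(stump, feature_names=feature_cols)
print(rules)
print("Depth:", stump.get_depth(), "Leaves:", stump.get_n_leaves())
print("Training accuracy:", stump.score(toy[feature_cols], toy["Buy"]))
```

In this dataset a single threshold on either feature separates the two classes, so even an unconstrained tree fitted on all 20 rows stops after one split. On noisier data, a depth limit such as `max_depth` is one common way to keep the tree from memorizing the training set.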

## Common Pitfalls

- Feature-name mismatch: the model is fitted on a DataFrame, so predict with a DataFrame whose columns are `Age` and `Salary`; a bare list or array triggers a feature-name warning.
- Shape mismatches: the model expects exactly 2 features, in the same order as during training.
- Plotting gotcha: use scatter for 2-feature classification, not a single fitted line.

## Smoke Tests (must pass)

In [None]:
assert X.shape[1] == 2
assert len(y_pred) == len(y_test)
assert 0.0 <= accuracy <= 1.0
print("✅ Smoke tests passed")

## Further Reading

- [LinearRegression (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html)
- [KNeighborsClassifier (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html)
- [DecisionTreeClassifier (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html)
- [RandomForestClassifier (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)
- [train_test_split (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)
- [accuracy_score (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html)
- [confusion_matrix (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html)
- [classification_report (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html)
- [mean_squared_error (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html)
- [r2_score (scikit-learn)](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html)
- [Pandas indexing guide](https://pandas.pydata.org/docs/user_guide/indexing.html)
- [Matplotlib scatter](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html)
