# Decision Tree

This example shows how to use [scikit-learn](https://scikit-learn.org/stable/) to train a decision tree model on the Titanic dataset. The raw data is preprocessed into engineered features to improve the model's accuracy. For a more detailed explanation of what a decision tree is, see [Decision Tree](../document/decision_tree.md).

## Imports

In [None]:
import polars as pl

from typing import Any

from numpy import ndarray
from polars import LazyFrame
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, f1_score
from seaborn import heatmap

## Process Data
Apply the same processing to the training and testing data.

In [None]:
# Build lazy scans over both CSV splits; the first row of each file holds the
# column names.
train_data: LazyFrame
test_data: LazyFrame
train_data, test_data = (
    pl.scan_csv(csv_path, has_header=True)
    for csv_path in ("data/train.csv", "data/test.csv")
)

In [None]:
# Map each honorific extracted from a passenger's Name to an ordinal code.
# Titles that share a code are treated as equivalent by the model; titles not
# listed here receive a separate fallback code where this map is applied.
title_map: dict[str, int] = {
    # Common titles
    **dict.fromkeys(("Mr",), 1),
    **dict.fromkeys(("Ms", "Mrs", "Mme"), 2),
    **dict.fromkeys(("Master",), 3),
    **dict.fromkeys(("Miss", "Mlle"), 4),
    # Distinguished titles
    **dict.fromkeys(("Capt", "Col", "Dr", "Major", "Rev"), 5),
    # Royal titles
    **dict.fromkeys(("Countess", "Don", "Dona", "Jonkheer", "Lady", "Sir"), 6),
}

### Features

In [None]:
# Fallback ages used when a passenger's Age is missing, chosen by the
# honorific found in their Name. Presumably averages observed in the data —
# TODO confirm their source.
mean_miss_age = 21.5
mean_master_age = 4.5
mean_mrs_age = 35.5
mean_mr_age = 33.0


def build_features(data: LazyFrame) -> LazyFrame:
    """Derive the model's feature columns from a raw Titanic passenger frame.

    The same function is applied to the training and the test split, so both
    are guaranteed to receive identical processing.

    Args:
        data: Lazy frame scanned from one of the Titanic CSV files. It must
            contain the Pclass, SibSp, Parch, Embarked, Name, Cabin, Fare and
            Age columns.

    Returns:
        A lazy frame with the columns sku, n_family, origin, title,
        has_cabin, fare and age, in that order.
    """
    # Impute a missing fare with the column mean *before* bucketing, so the
    # passenger still lands in one of the fare bands 1-4. Substituting the raw
    # mean after bucketing and casting to UInt8 would yield a value like 32,
    # i.e. a band that does not exist.
    fare = pl.col("Fare").fill_null(pl.col("Fare").mean())
    age_missing = pl.col("Age").is_null()
    name = pl.col("Name")

    return data.select(
        # NOTE(review): `rank` is computed per frame, so train and test get the
        # same codes only if both contain the same set of classes/ports.
        sku=pl.col("Pclass").rank(method="dense"),
        # Family size aboard, counting the passenger themself.
        n_family=pl.col("SibSp") + pl.col("Parch") + 1,
        # Port of embarkation: forward-fill gaps, then encode as dense ranks.
        origin=pl.col("Embarked").fill_null(strategy="forward").rank(method="dense"),
        # Honorific between the comma and the next period ("Doe, Mr. John" ->
        # "Mr"); titles missing from `title_map` share one extra code.
        title=name.str.extract(r",\s*(\w+)\.\s*").replace_strict(
            title_map, default=max(title_map.values()) + 1, return_dtype=pl.UInt8
        ),
        has_cabin=pl.col("Cabin").is_not_null().cast(pl.UInt8),
        # Fare bands with thresholds 7.91, 14.454 and 31.0.
        fare=(
            pl.when(fare.le(7.91)).then(1)
            .when(fare.is_between(7.91, 14.454, closed="left")).then(2)
            .when(fare.is_between(14.454, 31.0, closed="left")).then(3)
            .otherwise(4)
            .cast(pl.UInt8)
        ),
        # Impute missing ages from the honorific in the name. "Mrs" must be
        # tested before "Mr" because "Mr" is a substring of "Mrs".
        age=(
            pl.when(name.str.contains("Master") & age_missing).then(mean_master_age)
            .when(name.str.contains("Miss") & age_missing).then(mean_miss_age)
            .when(name.str.contains("Mrs") & age_missing).then(mean_mrs_age)
            .when(name.str.contains("Mr") & age_missing).then(mean_mr_age)
            .when(age_missing).then(pl.col("Age").mean())
            .otherwise(pl.col("Age"))
        ),
    )


train_features: LazyFrame = build_features(train_data)
test_features: LazyFrame = build_features(test_data)

train_output: LazyFrame = train_data.select(y=pl.col("Survived"))

### Train the Decision Tree

In [None]:
# Materialise the lazy queries as NumPy arrays. The target is taken as a
# Series so that `y` has shape (n_samples,) — the 1-D label layout
# scikit-learn expects — instead of a (n_samples, 1) column.
X: ndarray[Any, Any] = train_features.collect().to_numpy()
y: ndarray[Any, Any] = train_output.collect().to_series().to_numpy()

# Hold out 20% of the labelled rows for validation; the fixed seed makes the
# split reproducible.
X_train, X_validate, y_train, y_validate = train_test_split(X, y, test_size=0.2, random_state=73)

In [None]:
# Train the decision tree classifier, capped at depth 3.
# `random_state` is fixed because the tree randomly permutes features at every
# split; without a seed, ties between equally good splits can resolve
# differently on each run. The seed matches the one used for the split.
clf = DecisionTreeClassifier(max_depth=3, random_state=73)
clf.fit(X_train, y_train)

In [None]:
# Export the decision tree to a dot file
from pathlib import Path

from matplotlib.text import Annotation

# Feature names come from the lazy frame's schema. `collect_schema()` avoids
# the performance warning that `LazyFrame.columns` emits.
feature_names: list[str] = train_features.collect_schema().names()
class_names: list[str] = ["0", "1"]

# With out_file=None, export_graphviz returns the DOT source as a string. If a
# path were passed it would return None, so `dot_data` would not actually be a
# str. The file is written here instead, which also lets us create the output
# directory when it does not exist yet.
dot_data: str = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    rounded=True,
    special_characters=True,
)
dot_path = Path("resource/decision-tree.dot")
dot_path.parent.mkdir(parents=True, exist_ok=True)
dot_path.write_text(dot_data, encoding="utf-8")

# View the decision tree, labelled the same way as the DOT export.
plot_result: list[Annotation] = tree.plot_tree(
    clf,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    rounded=True,
    proportion=True,
)

In [None]:
# Score the held-out validation split with standard binary-classification
# metrics and print each one on its own line.
y_pred: ndarray = clf.predict(X_validate)

accuracy: float = accuracy_score(y_validate, y_pred)
precision: float = precision_score(y_validate, y_pred)
recall: float = recall_score(y_validate, y_pred)
f1: float = f1_score(y_validate, y_pred)

for metric_label, metric_value in (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1", f1),
):
    print(f"{metric_label}: {metric_value}")

In [None]:
# Confusion matrix of the validation predictions. Following the scikit-learn
# convention, rows are true labels and columns are predicted labels.
confusion: ndarray = confusion_matrix(y_validate, y_pred)
print(f"Confusion: {confusion}")
# fmt="d": the cells are integer counts. Seaborn's default annotation format
# ".2g" would render any count of 100 or more in scientific notation
# (e.g. "1e+02").
heatmap(confusion, annot=True, fmt="d")

----
Go back to [index](_index.ipynb).