# Decision Tree

This example shows how to use [scikit-learn](https://scikit-learn.org/stable/) to train a Decision Tree model on the Titanic dataset. The data is preprocessed beforehand to improve the accuracy of the model. For a more detailed explanation of what a Decision Tree is, see [Decision Tree](../document/decision_tree.md).

## Imports

In [None]:
import polars as pl
import seaborn as sns

from typing import Any

from matplotlib.axes import Axes
from matplotlib.text import Annotation
from numpy import ndarray
from polars import DataFrame, LazyFrame
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, f1_score
from seaborn import heatmap

## Access the Preprocessed Data

The data is preprocessed in the [Data Preprocessing](./data_preprocessing.ipynb) notebook.

In [None]:
train_Xs: LazyFrame = pl.scan_csv("../data/train_Xs.csv")
train_ys: LazyFrame = pl.scan_csv("../data/train_ys.csv")
test_Xs: LazyFrame = pl.scan_csv("../data/test_Xs.csv")

## Train the Decision Tree

In [None]:
X: ndarray[Any, Any] = train_Xs.collect().to_numpy()
y: ndarray[Any, Any] = train_ys.collect().to_numpy()

X_train, X_validate, y_train, y_validate = train_test_split(X, y, test_size=0.2, random_state=73)
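
The split above is random with respect to the labels. As a hedged aside, `train_test_split` also accepts a `stratify` argument that keeps the class ratio roughly equal in both parts, which can help when the target is imbalanced. The sketch below uses made-up stand-in data (`X_demo`, `y_demo`), not the Titanic features.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data: 100 rows, 40% positive labels.
X_demo = np.arange(200).reshape(100, 2)
y_demo = np.array([1] * 40 + [0] * 60)

X_tr, X_val, y_tr, y_val = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=73, stratify=y_demo
)

# Both parts keep approximately the original share of positive labels.
print(f"train positives: {y_tr.mean():.2f}, validation positives: {y_val.mean():.2f}")
```

Whether stratification changes the results on the real data is not something this notebook measures; it is only an option to try.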

### Find the Best Parameters (Optional)

Use a grid search to look for the best parameters for the Decision Tree model. In my runs this turned out to be sub-optimal: a simple `DecisionTreeClassifier(max_depth=3)` worked better, as the tuned Decision Tree seems to over-fit on the training and validation data.

Even so, running the search remains an interesting exercise, even if the parameters it finds are not the best ones in practice.

In [None]:
parameter_grid: dict[str, list[int]] = {
    "max_depth": [2, 3, 4, 5, 6],
    "max_features": [3, 4, 5, 6, 7, 8, 9],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [2, 5, 10],
    "random_state": [7, 19, 37, 53, 73],
}

template_dtc = DecisionTreeClassifier()
grid_search = GridSearchCV(template_dtc, param_grid=parameter_grid, cv=10, scoring="accuracy")
grid_search.fit(X_train, y_train)

print(f"Best parameters: {grid_search.best_params_}")
print(f"Best score: {grid_search.best_score_}")

In [None]:
# GridSearchCV refits the best estimator on the data passed to fit() by default
# (refit=True), so no additional call to fit() is needed here.
dtc: DecisionTreeClassifier = grid_search.best_estimator_
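
One way to look at the over-fitting mentioned above is to compare training and validation accuracy as `max_depth` grows. The following is a minimal sketch on synthetic data from `make_classification`, which stands in for the preprocessed Titanic features (an assumption made so the example runs on its own); the numbers it prints say nothing about the real dataset.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data with some label noise (flip_y), not the Titanic data.
X_demo, y_demo = make_classification(
    n_samples=800, n_features=9, n_informative=4, flip_y=0.2, random_state=73
)
X_tr, X_val, y_tr, y_val = train_test_split(X_demo, y_demo, test_size=0.2, random_state=73)

# Map each depth to (training accuracy, validation accuracy).
scores: dict[int, tuple[float, float]] = {}
for depth in (1, 2, 3, 5, 8, 12, 20):
    model = DecisionTreeClassifier(max_depth=depth, random_state=73).fit(X_tr, y_tr)
    scores[depth] = (model.score(X_tr, y_tr), model.score(X_val, y_val))
    print(f"max_depth={depth:>2}: train={scores[depth][0]:.3f}, validate={scores[depth][1]:.3f}")
```

A widening gap between the two columns at larger depths is the usual sign that the tree is memorizing the training split rather than generalizing.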

### Use the Best Parameters

In [None]:
# Replace the grid-search result with the shallow tree that worked best in practice.
dtc = DecisionTreeClassifier(max_depth=3)
dtc.fit(X_train, y_train)

### Plot the Decision Tree

In [None]:
sns.set_theme(font_scale=2, rc={"figure.figsize": (40, 20)})
tree_plot: list[Annotation] = tree.plot_tree(
    dtc,
    feature_names=train_Xs.collect_schema().names(),
    class_names=True,
    filled=True,
    rounded=True,
    proportion=True,
)

### Export the Decision Tree

In [None]:
# Export the decision tree to a DOT file (uncomment to run).
# dot_data = tree.export_graphviz(
#     dtc,
#     out_file="resource/decision-tree.dot",
#     feature_names=train_Xs.collect_schema().names(),
#     class_names=["0", "1"],
#     filled=True,
#     rounded=True,
#     special_characters=True,
# )
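
If writing a file is not wanted, `export_graphviz` returns the DOT source as a string when `out_file=None`, and `export_text` produces a plain-text rendering that needs no Graphviz installation. A minimal sketch, using the iris dataset bundled with scikit-learn as a stand-in for the Titanic features (`demo_tree` is a hypothetical name used only here):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text

# Stand-in data and model; not the Titanic tree trained above.
iris = load_iris()
demo_tree = DecisionTreeClassifier(max_depth=2, random_state=73).fit(iris.data, iris.target)

# With out_file=None the DOT source is returned instead of written to disk.
dot_source: str = export_graphviz(
    demo_tree,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,
)

# Plain-text rules, one indented line per split or leaf.
text_rules: str = export_text(demo_tree, feature_names=iris.feature_names)
print(text_rules)
```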

### Evaluate the Model

In [None]:
# Evaluate the model
y_pred: ndarray = dtc.predict(X_validate)
accuracy: float = accuracy_score(y_validate, y_pred)
precision: float = precision_score(y_validate, y_pred)
recall: float = recall_score(y_validate, y_pred)
f1: float = f1_score(y_validate, y_pred)

print(f"Accuracy: {100 * accuracy:.2f}%")
print(f"Precision: {100 * precision:.2f}%")
print(f"Recall: {100 * recall:.2f}%")
print(f"F1: {100 * f1:.2f}%")
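
For reference, the four scores above can be derived by hand from the counts of true and false positives and negatives: accuracy is `(tp + tn) / total`, precision is `tp / (tp + fp)`, recall is `tp / (tp + fn)`, and F1 is the harmonic mean of precision and recall. The sketch below checks this against scikit-learn on a small made-up label vector (`y_true` and `y_hat` are illustrative, not the validation data).

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Made-up labels for illustration only.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_hat = np.array([1, 0, 0, 1, 0, 1, 1, 0])

# Count each outcome with boolean masks.
tp = int(((y_true == 1) & (y_hat == 1)).sum())
tn = int(((y_true == 0) & (y_hat == 0)).sum())
fp = int(((y_true == 0) & (y_hat == 1)).sum())
fn = int(((y_true == 1) & (y_hat == 0)).sum())

manual_accuracy = (tp + tn) / len(y_true)
manual_precision = tp / (tp + fp)
manual_recall = tp / (tp + fn)
manual_f1 = 2 * manual_precision * manual_recall / (manual_precision + manual_recall)

# The hand-computed values agree with the library functions.
print(np.isclose(manual_accuracy, accuracy_score(y_true, y_hat)))
print(np.isclose(manual_precision, precision_score(y_true, y_hat)))
print(np.isclose(manual_recall, recall_score(y_true, y_hat)))
print(np.isclose(manual_f1, f1_score(y_true, y_hat)))
```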

### Plot the Confusion Matrix

In [None]:
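# set_style applies the settings globally; axes_style only returns them
# (or acts as a context manager) and has no effect when called on its own.
sns.set_style(rc={"xtick.top": True, "axes.spines.top": True})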

confusion: ndarray = confusion_matrix(y_validate, y_pred)

plot: Axes = heatmap(
    confusion, annot=True, fmt="d", xticklabels=["Foundered", "Survived"], yticklabels=["Foundered", "Survived"]
)
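
When reading the heatmap, it helps to know that `confusion_matrix` puts the true labels on the rows and the predicted labels on the columns, so for binary labels the cells are `[[tn, fp], [fn, tp]]`. A small sketch with made-up labels (not the validation data) shows the layout and the optional `normalize="true"` argument, which turns each row into shares of that true class:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up labels for illustration only.
y_true = np.array([0, 0, 0, 1, 1, 1, 1, 0])
y_hat = np.array([0, 1, 0, 1, 0, 1, 1, 0])

# Rows are true classes, columns are predicted classes.
counts = confusion_matrix(y_true, y_hat)
tn, fp, fn, tp = counts.ravel()
print(f"tn={tn}, fp={fp}, fn={fn}, tp={tp}")

# Each row sums to 1: the share of a true class assigned to each prediction.
row_shares = confusion_matrix(y_true, y_hat, normalize="true")
print(row_shares)
```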

### Generate Prediction List

In [None]:
predictions: ndarray = dtc.predict(test_Xs.collect().to_numpy())
prediction_list = pl.DataFrame(
    {
        "PassengerId": pl.Series(range(892, 1310)),
        "Survived": pl.Series(predictions),
    }
)
prediction_list.write_csv("../data/decision_tree_predictions.csv")

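### Compare the Predictions with the Gender Baseline

Kaggle's `gender_submission.csv` is a sample submission that assumes all female passengers, and only they, survived. The percentage below therefore measures disagreement with that baseline, not accuracy against the true labels.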

In [None]:
source = pl.read_csv("../data/decision_tree_predictions.csv")
target = pl.read_csv("../data/gender_submission.csv")

y_source = source["Survived"]
y_target = target["Survived"]

num_differences = (y_source != y_target).sum()
num_difference_percentage = (num_differences / len(y_source)) * 100
num_difference_percentage

----
Go back to [index](_index.ipynb).