<img src="datamecum_logo.png" align="right" style="float: right" width="400">
<font color="#CA3532"><h1 align="left">Programa técnico intensivo en data science. Datamecum.</h1></font>
<font color="#6E6E6E"><h2 align="left">Decision Trees</h2></font> 
<font color="#6E6E6E"><h2 align="left">Podado del árbol</h2></font>

In [1]:
import os

import pandas as pd
import matplotlib.pyplot as plt
import joblib

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import ParameterGrid, GridSearchCV
from sklearn.metrics import make_scorer, accuracy_score, classification_report

In [2]:
# Input locations, all relative to the notebook's working directory.
DATA_DIR = os.path.join(".", "data")
TRAIN_CSV_PATH, TEST_CSV_PATH, CLASS_NAMES_PATH = (
    os.path.join(DATA_DIR, file_name)
    for file_name in ("train.csv", "test.csv", "class_names.joblib")
)

In [3]:
# Training split: the "target" column is the label, every other column is a feature.
train_df = pd.read_csv(TRAIN_CSV_PATH)
y_train = train_df["target"].values
X_train = train_df.drop(columns="target").values

In [4]:
# Held-out test split, split into features and label the same way as the training split.
test_df = pd.read_csv(TEST_CSV_PATH)
y_test = test_df["target"].values
X_test = test_df.drop(columns="target").values

In [5]:
# Feature column names, taken from the training frame the trees are fit on; plot_tree uses
# them to label split nodes, so they must line up with the columns of X_train.
feature_names = train_df.drop(columns="target").columns.tolist()
# X_train / X_test are positional arrays, so both splits must share the same column order.
assert feature_names == test_df.drop(columns="target").columns.tolist(), (
    "train/test feature columns differ"
)

In [6]:
# Class labels used by plot_tree (presumably ordered to match the sorted target values — TODO confirm).
# NOTE: joblib.load unpickles the file, which can execute arbitrary code; only load trusted files.
class_names = joblib.load(filename=CLASS_NAMES_PATH)

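### Full Tree

Unrestricted tree fit with scikit-learn's default settings (no depth limit); it is the baseline that the two pruning strategies below are compared against.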

In [7]:
# Unrestricted baseline tree; random_state is fixed so the fit is reproducible.
full_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
full_tree

In [None]:
# Draw the unrestricted tree on an explicit Axes, with a title so the figure stands alone.
# figsize=(4, 4) at dpi=1000 renders a 4000x4000 px image.
fig, ax = plt.subplots(figsize=(4, 4), dpi=1000)
plot_tree(
    full_tree,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    ax=ax,
)
ax.set_title(f"Full tree (depth={full_tree.get_depth()})")
plt.show()

In [8]:
# Per-class precision / recall / F1 of the full tree on the held-out test split.
full_tree_test_pred = full_tree.predict(X_test)
print(classification_report(y_true=y_test, y_pred=full_tree_test_pred))

              precision    recall  f1-score   support

           0       0.85      0.93      0.89        42
           1       0.96      0.90      0.93        72

    accuracy                           0.91       114
   macro avg       0.90      0.92      0.91       114
weighted avg       0.92      0.91      0.91       114



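### Best Depth Tree

Pre-pruning: `max_depth` is tuned by cross-validated grid search over every depth from 1 up to the depth of the full tree.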

In [9]:
# Depth of the unrestricted tree; used as the upper limit of the max_depth search below.
max_depth: int = full_tree.get_depth()

In [10]:
# Show the depth of the unrestricted tree (the largest candidate in the depth search).
max_depth

7

In [12]:
# Pre-pruning search: try every max_depth from 1 up to the full tree's depth and pick the
# one with the best cross-validated accuracy on the training set (GridSearchCV's default cv).
max_depth_grid_search = GridSearchCV(
    estimator=DecisionTreeClassifier(random_state=42),
    scoring="accuracy",  # the built-in scorer, equivalent to make_scorer(accuracy_score)
    # GridSearchCV expects a dict (or list of dicts) mapping each parameter to its candidates;
    # newer scikit-learn releases validate this and reject a ParameterGrid object here.
    param_grid={"max_depth": list(range(1, max_depth + 1))},
)

In [13]:
# Run the cross-validated search; with the default refit=True the best depth is then
# refit on the whole training set and exposed as best_estimator_.
max_depth_grid_search.fit(X=X_train, y=y_train)

In [14]:
# Show the max_depth value that achieved the best cross-validated accuracy.
max_depth_grid_search.best_params_

{'max_depth': 4}

In [15]:
# Tree refit on the full training set with the selected max_depth.
best_max_depth_tree: DecisionTreeClassifier = max_depth_grid_search.best_estimator_

In [16]:
# Actual depth of the selected tree. max_depth is only an upper bound, so the fitted tree
# may be shallower than the tuned value but never deeper.
best_max_depth = best_max_depth_tree.get_depth()
assert best_max_depth <= max_depth_grid_search.best_params_["max_depth"]
best_max_depth

In [None]:
# Draw the depth-limited tree on an explicit Axes, with a title so the figure stands alone.
# figsize=(4, 4) at dpi=1000 renders a 4000x4000 px image.
fig, ax = plt.subplots(figsize=(4, 4), dpi=1000)
plot_tree(
    best_max_depth_tree,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    ax=ax,
)
ax.set_title(f"Best max_depth tree (depth={best_max_depth_tree.get_depth()})")
plt.show()

In [18]:
# Per-class test-set metrics of the depth-limited tree, for comparison with the full tree.
best_max_depth_test_pred = best_max_depth_tree.predict(X_test)
print(classification_report(y_true=y_test, y_pred=best_max_depth_test_pred))

              precision    recall  f1-score   support

           0       0.91      0.93      0.92        42
           1       0.96      0.94      0.95        72

    accuracy                           0.94       114
   macro avg       0.93      0.94      0.93       114
weighted avg       0.94      0.94      0.94       114



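### Pruned Tree

Post-pruning with minimal cost-complexity pruning: the candidate `ccp_alpha` values are the effective alphas of the full tree's pruning path, and the best one is chosen by cross-validated grid search.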

In [20]:
# Effective alphas of the minimal cost-complexity pruning path on the training set;
# these are the candidate ccp_alpha values for the search below.
pruning_path = full_tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = pruning_path.ccp_alphas

In [21]:
# Show the candidate ccp_alpha values from the pruning path.
ccp_alphas

array([0.        , 0.00218083, 0.0028662 , 0.0029304 , 0.00395604,
       0.00425059, 0.00502355, 0.00527473, 0.00593407, 0.00764113,
       0.01439037, 0.02038595, 0.05433359, 0.32661707])

In [22]:
# Post-pruning search: try every effective alpha from the pruning path and keep the one
# with the best cross-validated accuracy on the training set (GridSearchCV's default cv).
ccp_alpha_grid_search = GridSearchCV(
    estimator=DecisionTreeClassifier(random_state=42),
    scoring="accuracy",  # the built-in scorer, equivalent to make_scorer(accuracy_score)
    # GridSearchCV expects a dict (or list of dicts) mapping each parameter to its candidates;
    # newer scikit-learn releases validate this and reject a ParameterGrid object here.
    # tolist() yields plain Python floats, so best_params_ displays cleanly.
    param_grid={"ccp_alpha": ccp_alphas.tolist()},
)

In [23]:
# Run the cross-validated search; with the default refit=True the best alpha is then
# refit on the whole training set and exposed as best_estimator_.
ccp_alpha_grid_search.fit(X=X_train, y=y_train)

In [24]:
# Show the ccp_alpha value that achieved the best cross-validated accuracy.
ccp_alpha_grid_search.best_params_

{'ccp_alpha': 0.005934065934065933}

In [25]:
# Tree refit on the full training set with the selected ccp_alpha.
best_ccp_alpha_tree: DecisionTreeClassifier = ccp_alpha_grid_search.best_estimator_

In [None]:
# Draw the cost-complexity pruned tree on an explicit Axes, with a title so the figure
# stands alone. figsize=(4, 4) at dpi=1000 renders a 4000x4000 px image.
fig, ax = plt.subplots(figsize=(4, 4), dpi=1000)
plot_tree(
    best_ccp_alpha_tree,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
    ax=ax,
)
ax.set_title(f"Cost-complexity pruned tree (ccp_alpha={best_ccp_alpha_tree.ccp_alpha:.4f})")
plt.show()

In [27]:
# Per-class test-set metrics of the cost-complexity pruned tree.
best_ccp_alpha_test_pred = best_ccp_alpha_tree.predict(X_test)
print(classification_report(y_true=y_test, y_pred=best_ccp_alpha_test_pred))

              precision    recall  f1-score   support

           0       0.93      0.90      0.92        42
           1       0.95      0.96      0.95        72

    accuracy                           0.94       114
   macro avg       0.94      0.93      0.93       114
weighted avg       0.94      0.94      0.94       114

