# Decision Tree Regression

In [38]:
import pandas as pd
from sklearn.model_selection import KFold, GridSearchCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score, mean_squared_error
from prettytable import PrettyTable

In [39]:
steel_data = pd.read_csv("steel.csv")

In [40]:
# Features of the dataset, all cols bar the last
features = steel_data.values[:, :-1]
# Ground truths, the last column
ground_truths = steel_data.values[:, -1]

kf = KFold(n_splits=10, shuffle=False)

model = DecisionTreeRegressor(random_state=4)

param_grid={
    "max_depth" : [3,4,5,6,7,8],
    "min_samples_leaf" : [3,4,5,6,7,8]}

scores_headers = ["Fold", "R2 Score", "Mean Squared Error"]
scores_list = []

In [41]:
for i, (train_index, test_index) in enumerate(kf.split(features)):
    # Features and ground truths for the ith fold
    training_features, test_features = features[train_index], features[test_index]
    training_ground_truths, test_ground_truths = ground_truths[train_index], ground_truths[test_index]

    model.fit(training_features, training_ground_truths)

    prediction = model.predict(test_features)

    r2 = r2_score(test_ground_truths, prediction)

    mse = mean_squared_error(test_ground_truths, prediction)

    scores_list.append([f"{i+1}", f"{r2:.2f}", f"{mse:.2f}"])

In [42]:
table = PrettyTable()

table.title = "Error Scores with Default Params"
table.field_names = scores_headers
table.add_rows(scores_list)
print(table)

+--------------------------------------+
|   Error Scores with Default Params   |
+------+----------+--------------------+
| Fold | R2 Score | Mean Squared Error |
+------+----------+--------------------+
|  1   |   0.45   |      2307.94       |
|  2   |  -0.24   |      5851.16       |
|  3   |   0.60   |      2747.84       |
|  4   |   0.74   |      1587.71       |
|  5   |   0.64   |      1807.87       |
|  6   |   0.70   |      2304.31       |
|  7   |   0.50   |      2074.93       |
|  8   |   0.38   |      3719.86       |
|  9   |   0.33   |      7637.25       |
|  10  |   0.33   |      5687.03       |
+------+----------+--------------------+


In [43]:
hyperparam_score_headers = ["Max Tree Depth", "Min Samples Leaf", "R2 Score", "Mean Squared Error"]
hyperparam_scores_list = []

grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=kf, scoring=["r2", "neg_mean_squared_error"], refit="r2", n_jobs=-1)

grid_search.fit(features, ground_truths)

results = grid_search.cv_results_
for mean_r2, mean_mse, params in zip(results['mean_test_r2'], results['mean_test_neg_mean_squared_error'], results['params']):

    min_samples_leaf = params['min_samples_leaf']
    max_depth = params['max_depth']
    hyperparam_scores_list.append([f"{min_samples_leaf}", f"{max_depth}", f"{mean_r2:.2f}", f"{(-1)*mean_mse:.2f}"])

hyperparam_table = PrettyTable()
hyperparam_table.title = "Hyperparameters & Values"
hyperparam_table.field_names = hyperparam_score_headers
hyperparam_table.add_rows(hyperparam_scores_list)
hyperparam_table.sortby = "Max Tree Depth"
print(hyperparam_table)

print(grid_search.best_params_)

+-------------------------------------------------------------------+
|                      Hyperparameters & Values                     |
+----------------+------------------+----------+--------------------+
| Max Tree Depth | Min Samples Leaf | R2 Score | Mean Squared Error |
+----------------+------------------+----------+--------------------+
|       3        |        3         |   0.12   |      5424.82       |
|       3        |        4         |   0.30   |      4304.26       |
|       3        |        5         |   0.40   |      3756.57       |
|       3        |        6         |   0.41   |      3701.77       |
|       3        |        7         |   0.44   |      3466.84       |
|       3        |        8         |   0.42   |      3564.45       |
|       4        |        3         |   0.12   |      5424.82       |
|       4        |        4         |   0.30   |      4314.97       |
|       4        |        5         |   0.42   |      3631.10       |
|       4        |  