In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeRegressor

from Features import FEATURE_LIST

In [2]:
# Load the training and held-out test CSVs, keeping only the columns named in
# FEATURE_LIST (the "quality" target is read from the selected columns below).
df = pd.read_csv('../data/wine_data_train.csv')[FEATURE_LIST]
df_test = pd.read_csv('../data/wine_data_test.csv')[FEATURE_LIST]

# Hold out 20% of the training rows for validation; the split is seeded so it
# is reproducible across runs.
X_train, X_val, y_train, y_val = train_test_split(
    df.drop(columns="quality"),
    df["quality"],
    test_size=0.2,
    random_state=42,
)

# Candidate hyper-parameters; every combination is evaluated.
param_grid = {
    'max_leaf_nodes': range(2, 100, 5),
    'max_depth': range(1, 20),
}

# 5-fold cross-validated exhaustive search, scored by negative MSE
# (scikit-learn maximises scores, hence the sign flip).
reg = DecisionTreeRegressor(random_state=42)
grid_search = GridSearchCV(
    estimator=reg,
    param_grid=param_grid,
    scoring='neg_mean_squared_error',
    cv=5,
)
grid_search.fit(X_train, y_train)

In [5]:
# Evaluate the grid-search winner on the validation split and on the held-out
# test set, print a summary, and persist the same summary to disk.
best_params = grid_search.best_params_
reg = grid_search.best_estimator_

# Validation metrics. scikit-learn metric functions take (y_true, y_pred) in
# that order; MAE/MSE are symmetric, but keeping the documented order avoids
# silent errors if an asymmetric metric is added later.
y_pred = reg.predict(X_val)
MAE = mean_absolute_error(y_val, y_pred)
MSE = mean_squared_error(y_val, y_pred)
r_2 = reg.score(X_val, y_val)  # regressor .score() is R^2

# Performance on the test set.
X_test = df_test.drop("quality",axis=1)
y_test = df_test["quality"]
y_pred_test = reg.predict(X_test)
MAE_t = mean_absolute_error(y_test, y_pred_test)
MSE_t = mean_squared_error(y_test, y_pred_test)
r_2_t = reg.score(X_test, y_test)

message = (f'best parameter{best_params}\n' +
           f'Validation R^2: {r_2}\n' + f'Validation MAE: {MAE}\n' + f'Validation MSE: {MSE}\n' 
           + f'Test R^2: {r_2_t}\n' + f'Test MAE: {MAE_t}\n' + f'Test MSE: {MSE_t}')
print(message)

# Make sure the output directory exists; otherwise open() raises
# FileNotFoundError on a fresh checkout where "results/" was never created.
results_path = Path("results/DecisionTree_results.txt")
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, "w") as file:
    file.write(message)

best parameter{'max_depth': 7, 'max_leaf_nodes': 22}
Validation R^2: 0.31762008174570555
Validation MAE: 0.5787563970026884
Validation MSE: 0.5704780657017372
Test R^2: 0.2855867708473492
Test MAE: 0.5679650491739966
Test MSE: 0.5276275653772226


In [4]:
# Tabulate the mean cross-validated MSE for every (max_depth, max_leaf_nodes)
# combination tried by the grid search.
cv_results = grid_search.cv_results_
mean_mse = cv_results['mean_test_score']  # negative MSE, one entry per candidate

max_depth_values = param_grid['max_depth']
max_leaf_nodes_values = param_grid['max_leaf_nodes']

# Label each score with the parameters that actually produced it instead of
# reshaping the flat score array: a reshape silently mislabels the table if
# the candidate enumeration order differs from (max_depth outer,
# max_leaf_nodes inner) or if param_grid was edited after fitting.
cv_scores = pd.DataFrame({
    'Max Depth': [params['max_depth'] for params in cv_results['params']],
    'Max Leaf Nodes': [params['max_leaf_nodes'] for params in cv_results['params']],
    'mse': -mean_mse,  # flip the sign back to a plain (positive) MSE
})

# Rows = max depth, columns = max leaf nodes; reindex to keep the grid's order.
mse_df = (
    cv_scores
    .pivot(index='Max Depth', columns='Max Leaf Nodes', values='mse')
    .reindex(index=list(max_depth_values), columns=list(max_leaf_nodes_values))
)

# Kept for any later cell that still reads the raw (negative-MSE) matrix.
mse_matrix = -mse_df.to_numpy()

# Display the table
mse_df

Max Leaf Nodes,2,7,12,17,22,27,32,37,42,47,52,57,62,67,72,77,82,87,92,97
Max Depth,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
1,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285,0.6285
2,0.6285,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755,0.576755
3,0.6285,0.558757,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373,0.557373
4,0.6285,0.561831,0.545241,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407,0.536407
5,0.6285,0.561831,0.544608,0.539506,0.54281,0.541589,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371,0.544371
6,0.6285,0.561831,0.544608,0.53954,0.535276,0.540713,0.539257,0.540841,0.541491,0.541647,0.543423,0.546137,0.54614,0.54614,0.54614,0.54614,0.54614,0.54614,0.54614,0.54614
7,0.6285,0.561831,0.544608,0.53954,0.533841,0.537624,0.536386,0.539205,0.547076,0.550022,0.55056,0.554424,0.555663,0.558965,0.560207,0.562293,0.563486,0.56679,0.568177,0.57071
8,0.6285,0.561831,0.544608,0.53954,0.534756,0.538538,0.538072,0.538383,0.536665,0.54621,0.548037,0.550397,0.551113,0.555097,0.556794,0.554359,0.55743,0.556212,0.556539,0.558203
9,0.6285,0.561831,0.544608,0.53954,0.534756,0.538538,0.537976,0.538257,0.534317,0.545601,0.547514,0.548191,0.548606,0.552463,0.55322,0.558788,0.563628,0.561574,0.563832,0.561603
10,0.6285,0.561831,0.544608,0.53954,0.534756,0.538538,0.539203,0.53839,0.535097,0.54562,0.546639,0.550262,0.548851,0.552441,0.556623,0.556952,0.562566,0.566336,0.566113,0.565266
