# Decision Tree

#### Imports and plotting helpers

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from dataset import x_train, x_test, y_train, y_test
from sklearn.metrics import mean_absolute_error

In [None]:
def set_color(_fig, _ax):
    """Apply the dark report theme to a figure and one of its axes, in place.

    The figure and axes backgrounds become '#1b212c'. All four spines, both
    axis labels, the title, and the x/y tick marks and tick labels become
    white. A faint grid (alpha 0.1) is drawn.

    Parameters
    ----------
    _fig : matplotlib Figure whose background patch is recoloured.
    _ax : matplotlib Axes to restyle.
    """
    background = '#1b212c'
    foreground = 'white'

    for owner in (_fig, _ax):
        owner.patch.set_facecolor(background)

    for side in ('bottom', 'top', 'left', 'right'):
        _ax.spines[side].set_color(foreground)

    for axis in (_ax.xaxis, _ax.yaxis):
        axis.label.set_color(foreground)

    _ax.grid(alpha=0.1)
    _ax.title.set_color(foreground)

    for which in ('x', 'y'):
        _ax.tick_params(axis=which, colors=foreground)

#### Show test-set MAE as a function of the tree's max depth

In [None]:
# Fit one tree per max_depth in 1..25 and record its test-set MAE.
# random_state is fixed because DecisionTreeRegressor randomly permutes the
# features at each split. Without it, ties between equally good splits can make
# the curve (and the saved image) differ from run to run.
depths = range(1, 26)
results = {}
for depth in depths:
    regr = DecisionTreeRegressor(max_depth=depth, random_state=0)
    regr.fit(x_train, y_train)
    y_pred = regr.predict(x_test)
    # Key by the depth itself (not a 0-based loop index) so the x values
    # plotted below match the "Max Depth" axis label.
    results[depth] = mean_absolute_error(y_test, y_pred)

fig, ax1 = plt.subplots(ncols=1, figsize=(10, 7), dpi=300)

sns.lineplot(x=list(results.keys()), y=list(results.values()), ax=ax1)
ax1.set_xlabel('Max Depth')
ax1.set_ylabel('Mean Absolute Error')

set_color(fig, ax1)
plt.savefig('../images/decision_tree/max_depth.png', dpi=300)
plt.show()


#### Show the MAE and the distribution of absolute errors |y_test - y_pred|

In [None]:
# Refit at the deepest setting from the sweep and inspect its prediction errors.
# random_state is fixed so the fitted tree (and the saved figure) is reproducible.
regr = DecisionTreeRegressor(max_depth=25, random_state=0)
regr.fit(x_train, y_train)
y_pred = regr.predict(x_test)

mae = mean_absolute_error(y_test, y_pred)
print("Mean Absolute Error:", mae)

df = pd.DataFrame()
df['actual'] = y_test
df['predicted'] = y_pred
# Absolute error per test row, computed from the frame's own columns.
df['diff'] = (df['actual'] - df['predicted']).abs()
# Keep only deviations below 30000 for the histogram so the long tail does not
# flatten the plot. NOTE(review): threshold presumably in price units — confirm.
df2 = df[df["diff"] < 30000]

sns.set_theme(style="ticks", palette="pastel")
fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
# Density-normalised histogram plus KDE curve. This replaces the deprecated
# sns.distplot (its fit_kws had no effect because no fit= was given).
sns.histplot(df2["diff"], kde=True, stat="density", color="#81acc3", alpha=0.4, ax=ax)

# The mean of |actual - predicted| over the whole test set is the MAE printed
# above, so it is labelled as the mean (the old 'Median' label was wrong).
# It includes the rows filtered out of the histogram.
ax.axvline(mae, color='r', linestyle='--', label='Mean (MAE)')

sns.despine(offset=10, trim=True)
set_color(fig, ax)
# Show the line's label; white text so it stays readable on the dark theme.
ax.legend(frameon=False, labelcolor='white')
ax.set_xlabel('Price deviation')
ax.set_xlim(0, None)
plt.savefig('../images/decision_tree/tree_error_dist.png', dpi=300)
plt.show()