In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score
from sklearn import tree
from utils.utils import find_best_hyperparameters


In [3]:
import warnings
#suppress warnings
warnings.filterwarnings('ignore')

In [5]:
dataset = pd.read_csv("50_Startups.csv")

In [7]:
dataset = pd.get_dummies(dataset, drop_first=True)

In [9]:
independent_vars = dataset[['R&D Spend', 'Administration', 'Marketing Spend','State_Florida', 'State_New York']]

In [11]:
dependent_var = dataset[['Profit']]

In [13]:
x_train,x_test,y_train,y_test=train_test_split(independent_vars,dependent_var,test_size=0.30,random_state=0)

In [15]:
param_dict = {
    "criterion": ["squared_error", "friedman_mse", "absolute_error", "poisson"],
    "splitter" : ["best","random"],
    "max_features": ["sqrt", "log2",None]
}


In [17]:
#print csv header
print("criterion,splitter,max_features,r_score")

best_combo = find_best_hyperparameters(
    param_dict,
    x_train, y_train, x_test, y_test,
    create_regressor_callback=lambda combo: DecisionTreeRegressor(**combo),
    print_combo_callback=lambda combo: print(f'{combo["criterion"]},{combo["splitter"]},{combo["max_features"]},{combo["r_score"]}')
)

print("\nBest combination:")
print(f'criterion={best_combo["criterion"]}, splitter={best_combo["splitter"]}, max_features={best_combo["max_features"]},r_score={best_combo["r_score"]}')


criterion,splitter,max_features,r_score
squared_error,best,sqrt,0.6235470054758422
squared_error,best,log2,0.0596361108392639
squared_error,best,None,0.9048609445049367
squared_error,random,sqrt,0.9112773986985244
squared_error,random,log2,-1.4146258801818816
squared_error,random,None,0.9272167871134455
friedman_mse,best,sqrt,0.12713675132395597
friedman_mse,best,log2,0.7874559294924994
friedman_mse,best,None,0.921072540271293
friedman_mse,random,sqrt,-0.17850568937917966
friedman_mse,random,log2,0.3316809042126361
friedman_mse,random,None,0.9476955033291437
absolute_error,best,sqrt,0.6767778507244208
absolute_error,best,log2,0.8889038289724436
absolute_error,best,None,0.9668594647644447
absolute_error,random,sqrt,-0.053623822648428465
absolute_error,random,log2,0.5866858551960833
absolute_error,random,None,0.5592070182811405
poisson,best,sqrt,0.4653405606395379
poisson,best,log2,0.7656250610762741
poisson,best,None,0.920902868420011
poisson,random,sqrt,0.6748402515634971
poisson,rando