In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score
from sklearn import tree
from utils import find_best_hyperparameters


In [2]:
import warnings
#suppress warnings
warnings.filterwarnings('ignore')

In [3]:
dataset = pd.read_csv("50_Startups.csv")

In [4]:
dataset = pd.get_dummies(dataset, drop_first=True)

In [5]:
independent_vars = dataset[['R&D Spend', 'Administration', 'Marketing Spend','State_Florida', 'State_New York']]

In [6]:
dependent_var = dataset[['Profit']]

In [7]:
x_train,x_test,y_train,y_test=train_test_split(independent_vars,dependent_var,test_size=0.30,random_state=0)

In [8]:
param_dict = {
    "criterion": ["squared_error", "friedman_mse", "absolute_error", "poisson"],
    "splitter" : ["best","random"],
    "max_features": ["sqrt", "log2",None]
}


In [21]:
#print csv header
print("criterion,splitter,max_features,r_score")

best_combo = find_best_hyperparameters(
    param_dict,
    x_train, y_train, x_test, y_test,
    create_regressor_callback=lambda combo: DecisionTreeRegressor(**combo),
    print_combo_callback=lambda combo: print(f'{combo["criterion"]},{combo["splitter"]},{combo["max_features"]},{combo["r_score"]}')
)

print("\nBest combination:")
print(f'criterion={best_combo["criterion"]}, splitter={best_combo["splitter"]}, max_features={best_combo["max_features"]},r_score={best_combo["r_score"]}')


criterion,splitter,max_features,r_score
squared_error,best,sqrt,0.46040282136675437
squared_error,best,log2,0.09466567626283551
squared_error,best,None,0.9068851233484354
squared_error,random,sqrt,0.8495575815016504
squared_error,random,log2,0.34316288254812377
squared_error,random,None,0.8436252064253887
friedman_mse,best,sqrt,0.7676819546076277
friedman_mse,best,log2,-0.7836351600475915
friedman_mse,best,None,0.9420041821648154
friedman_mse,random,sqrt,0.37126398216405543
friedman_mse,random,log2,0.653874890469893
friedman_mse,random,None,0.9140932756005998
absolute_error,best,sqrt,0.7177074797778329
absolute_error,best,log2,0.3594001512559779
absolute_error,best,None,0.9381746066956602
absolute_error,random,sqrt,0.5973789572538233
absolute_error,random,log2,0.017161113967351427
absolute_error,random,None,0.8909901511287408
poisson,best,sqrt,0.9045096954717673
poisson,best,log2,0.5362320550967974
poisson,best,None,0.9088008328386012
poisson,random,sqrt,-0.08948670183305163
poisson,ra