In [21]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import r2_score
from sklearn import tree
from utils import generate_parameter_combinations


In [23]:
import warnings
#suppress warnings
warnings.filterwarnings('ignore')

In [25]:
dataset = pd.read_csv("50_Startups.csv")

In [27]:
dataset = pd.get_dummies(dataset, drop_first=True)

In [29]:
independent_vars = dataset[['R&D Spend', 'Administration', 'Marketing Spend','State_Florida', 'State_New York']]

In [31]:
dependent_var = dataset[['Profit']]

In [33]:
x_train,x_test,y_train,y_test=train_test_split(independent_vars,dependent_var,test_size=0.30,random_state=0)

In [39]:
param_dict = {
    "criterion": ["squared_error", "friedman_mse", "absolute_error", "poisson"],
    "splitter" : ["best","random"],
    "max_features": ["sqrt", "log2",None]
}


In [45]:
combinations = generate_parameter_combinations(param_dict)

best_combo = None
max_r_score = float('-inf')

for combo in combinations:
    regressor = DecisionTreeRegressor(**combo)
    regressor.fit(x_train,y_train)
    y_predict = regressor.predict(x_test)
    r_score = r2_score(y_test,y_predict)
    combo["r_score"] = r_score
    print(f'{combo["criterion"]},{combo["splitter"]},{combo["max_features"]},{combo["r_score"]}')

    if r_score > max_r_score:
        max_r_score = r_score
        best_combo = combo.copy()
        
print("\nBest combination:")
print(f'criterion={best_combo["criterion"]}, splitter={best_combo["splitter"]}, max_features={best_combo["max_features"]},r_score={best_combo["r_score"]}')


squared_error,best,sqrt,0.18597602120485812
squared_error,best,log2,0.617446280976921
squared_error,best,None,0.9255271251892302
squared_error,random,sqrt,-0.0012120196444767029
squared_error,random,log2,-0.923900612587405
squared_error,random,None,0.6176420835295096
friedman_mse,best,sqrt,0.09047722393548108
friedman_mse,best,log2,0.5876792741239747
friedman_mse,best,None,0.9222437938612578
friedman_mse,random,sqrt,0.5933237244111906
friedman_mse,random,log2,0.08643030491581771
friedman_mse,random,None,0.6858953593414545
absolute_error,best,sqrt,0.7560728149092735
absolute_error,best,log2,-0.5485786244340098
absolute_error,best,None,0.9734719142636618
absolute_error,random,sqrt,-0.11211007959041575
absolute_error,random,log2,0.7580894354094667
absolute_error,random,None,0.5672570780099401
poisson,best,sqrt,0.3793208710256508
poisson,best,log2,0.8162345206432022
poisson,best,None,0.9403019035364156
poisson,random,sqrt,0.7387428820770441
poisson,random,log2,0.8019189828387029
poisson,ra