In [1]:
import numpy as np
import pandas as pd

from pydataset import data

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import precision_score, make_scorer

In [2]:
tips = data("tips")
tips.head()

Unnamed: 0,total_bill,tip,sex,smoker,day,time,size
1,16.99,1.01,Female,No,Sun,Dinner,2
2,10.34,1.66,Male,No,Sun,Dinner,3
3,21.01,3.5,Male,No,Sun,Dinner,3
4,23.68,3.31,Male,No,Sun,Dinner,2
5,24.59,3.61,Female,No,Sun,Dinner,4


In [3]:
X = tips[["tip", "total_bill", "size"]]
y = tips.time

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123, test_size=.2)

In [4]:
tree = DecisionTreeClassifier(max_depth=4)

In [5]:
cross_val_score(tree, X_train, y_train, cv=4).mean()

0.6973852040816326

In [6]:
tree.fit(X_train, y_train)

DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=4, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')

In [7]:
predicted = tree.predict(X_train)
actual = y_train

precision_score(actual, predicted, pos_label="Dinner")

0.8571428571428571

In [8]:
precision_scorer = make_scorer(precision_score, pos_label="Dinner")

# decision tree with max depth of 4
cross_val_score(tree, X_train, y_train, cv=4, scoring=precision_scorer).mean()

0.7371597454477888

In [9]:
cross_val_score(DecisionTreeClassifier(max_depth=3), X_train, y_train, scoring=precision_scorer).mean()

0.7396049896049897

---
## Grid Search

In [10]:
from sklearn.model_selection import GridSearchCV

# keys are names of hyperparameters
# values are a list of values to try with hyperparameters
# different values of c for logistic regression
# different value of k for KNN
params = {
    "max_depth": range(1, 11),
    "criterion": ["gini", "entropy"]
}

# cv=4 means four-fold cross validation, i.e. k=4
grid = GridSearchCV(tree, params, cv=4)
grid.fit(X_train, y_train)

# best accuracy for out-of-sample data
grid.best_params_

{'criterion': 'entropy', 'max_depth': 3}

In [11]:
# .best_estimator_ gives us a model that is prefit with the best hyperparameters
model = grid.best_estimator_
model.score(X_test, y_test)

0.6530612244897959

In [12]:
grid.best_score_

0.7387329931972789

In [13]:
# cv_results_ gives us a dictionary with a params key that containts a list of dictionaries that represent the
# params that were used for the model

results = grid.cv_results_
results

{'mean_fit_time': array([0.00259274, 0.00205475, 0.00190198, 0.00236416, 0.00212562,
        0.00268519, 0.00196373, 0.0021407 , 0.00194359, 0.00194949,
        0.00203609, 0.00170219, 0.00191861, 0.00205612, 0.00183719,
        0.00216818, 0.00204152, 0.00202894, 0.00193447, 0.00198853]),
 'std_fit_time': array([1.91043394e-04, 3.91550361e-04, 3.01347526e-04, 1.70688638e-04,
        3.17247680e-04, 6.97003953e-04, 2.09163391e-04, 1.67477417e-04,
        2.10845683e-04, 1.19461315e-04, 3.58265586e-04, 4.32900456e-05,
        1.97211036e-04, 2.15315936e-04, 1.18746021e-04, 2.80786290e-04,
        1.74107800e-04, 1.86231849e-04, 1.22372778e-04, 1.64723990e-04]),
 'mean_score_time': array([0.00163978, 0.0009734 , 0.00096595, 0.00115192, 0.00101167,
        0.00124007, 0.00100148, 0.00123084, 0.00088692, 0.00092179,
        0.00097764, 0.00085992, 0.00090504, 0.00099993, 0.00090915,
        0.00093889, 0.00087804, 0.00100559, 0.00091189, 0.00089437]),
 'std_score_time': array([5.31991603e-

In [14]:
# modify each parameter dictionary such that it also contains the model's average performance after cross validation
for score, p in zip(results["mean_test_score"], results["params"]):
    p["score"] = score

df = pd.DataFrame(results["params"])
df

Unnamed: 0,criterion,max_depth,score
0,gini,1,0.733418
1,gini,2,0.733418
2,gini,3,0.723214
3,gini,4,0.697385
4,gini,5,0.620642
5,gini,6,0.676764
6,gini,7,0.661352
7,gini,8,0.661458
8,gini,9,0.640838
9,gini,10,0.65625
