# Hyperparameter Search

This notebook demonstrates the different options for hyperparameter search available in `scikit-learn`.

In [17]:
import numpy as np
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from pandas import DataFrame

# Load the dataset
X, y = load_digits(return_X_y=True)

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

In [18]:
model = SVC(random_state=1)
param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
param_grid = [{
    'C': param_range,
    'kernel': ['linear']
},
{
    'C': param_range,
    'gamma': param_range,
    'kernel': ['rbf']
}]

gs = GridSearchCV(estimator=model,
                  param_grid=param_grid,
                  scoring='accuracy',
                  cv=10,
                  refit=True,
                  n_jobs=-1) # run on all cores


gs = gs.fit(X_train, y_train)

# Retrieve the best estimator -- GridSearchCV will refit it on all of the training data
clf = gs.best_estimator_

print(f'Test accuracy: {clf.score(X_test, y_test):.3f}')

# Display results in a table
print(gs.best_params_)
param_results = DataFrame(gs.cv_results_)
param_results

Test accuracy: 0.992
{'C': 10.0, 'gamma': 0.001, 'kernel': 'rbf'}


Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_C,param_kernel,param_gamma,params,split0_test_score,split1_test_score,split2_test_score,split3_test_score,split4_test_score,split5_test_score,split6_test_score,split7_test_score,split8_test_score,split9_test_score,mean_test_score,std_test_score,rank_test_score
0,0.063262,0.002840,0.008054,0.000472,0.0001,linear,,"{'C': 0.0001, 'kernel': 'linear'}",0.965278,0.937500,0.958333,0.937500,0.972222,0.965278,0.979167,0.951049,0.951049,0.944056,0.956143,0.013614,17
1,0.030964,0.000787,0.004453,0.000332,0.001,linear,,"{'C': 0.001, 'kernel': 'linear'}",0.986111,0.972222,0.958333,0.993056,0.986111,0.993056,0.986111,0.986014,0.986014,0.979021,0.982605,0.009941,8
2,0.029462,0.001561,0.003519,0.000472,0.01,linear,,"{'C': 0.01, 'kernel': 'linear'}",0.965278,0.972222,0.972222,0.979167,0.986111,0.993056,0.993056,0.972028,0.979021,0.965035,0.977720,0.009763,14
3,0.031259,0.003756,0.003861,0.001169,0.1,linear,,"{'C': 0.1, 'kernel': 'linear'}",0.965278,0.979167,0.972222,0.979167,0.986111,0.993056,0.993056,0.972028,0.979021,0.979021,0.979813,0.008497,9
4,0.031050,0.004100,0.003340,0.000105,1.0,linear,,"{'C': 1.0, 'kernel': 'linear'}",0.965278,0.979167,0.972222,0.979167,0.986111,0.993056,0.993056,0.972028,0.979021,0.979021,0.979813,0.008497,9
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,0.250653,0.006713,0.020006,0.000950,1000.0,rbf,0.1,"{'C': 1000.0, 'gamma': 0.1, 'kernel': 'rbf'}",0.111111,0.104167,0.104167,0.104167,0.111111,0.104167,0.104167,0.104895,0.104895,0.118881,0.107173,0.004729,23
68,0.269259,0.014311,0.021390,0.000638,1000.0,rbf,1.0,"{'C': 1000.0, 'gamma': 1.0, 'kernel': 'rbf'}",0.111111,0.104167,0.104167,0.104167,0.104167,0.104167,0.104167,0.104895,0.104895,0.111888,0.105779,0.002879,27
69,0.253994,0.011145,0.021665,0.001241,1000.0,rbf,10.0,"{'C': 1000.0, 'gamma': 10.0, 'kernel': 'rbf'}",0.111111,0.104167,0.104167,0.104167,0.104167,0.104167,0.104167,0.104895,0.104895,0.111888,0.105779,0.002879,27
70,0.244863,0.014733,0.018634,0.003466,1000.0,rbf,100.0,"{'C': 1000.0, 'gamma': 100.0, 'kernel': 'rbf'}",0.111111,0.104167,0.104167,0.104167,0.104167,0.104167,0.104167,0.104895,0.104895,0.111888,0.105779,0.002879,27
