In [None]:
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Hyperparameter tuning

## Grid search

In [None]:
from sklearn import svm
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from pprint import pprint

# Exhaustive grid search: every combination in `parameters`
# (2 kernels x 2 values of C = 4 candidates) is scored with 5-fold CV.
iris = datasets.load_iris()

estimator = svm.SVC()
parameters = {'kernel': ('linear', 'rbf'),
              'C': [1, 10]}
# Multi-metric evaluation. For single-label multiclass data such as iris,
# micro-averaged precision is identical to accuracy, so it would only
# duplicate the 'accuracy' column. 'precision_macro' (the unweighted mean
# of per-class precision) reports something accuracy does not.
scoring = ['accuracy',
           'precision_macro']

# With several metrics, `refit` names the one used to choose best_params_ /
# best_estimator_; best_score_ then reports that metric.
clf = GridSearchCV(estimator=estimator,
                   param_grid=parameters,
                   cv=5,
                   scoring=scoring,
                   refit='accuracy')
# fit() returns the fitted search object itself, so `search is clf`.
search = clf.fit(iris.data, iris.target)

print(search.best_estimator_)
print(search.best_score_)  # mean CV accuracy of the best candidate
print(search.best_params_)
print(search.n_splits_)
# Per-candidate scores for every metric (e.g. 'mean_test_precision_macro')
# live in search.cv_results_; use pprint(search.cv_results_) to inspect them.

## Random search

In [None]:
from scipy.stats import uniform
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from pprint import pprint

# Randomized search: rather than trying every combination, candidate
# settings are sampled from the distributions defined below.
iris = load_iris()

# 'saga' is a solver that supports both the 'l1' and 'l2' penalties
# sampled by the search.
estimator = LogisticRegression(
    solver='saga',
    tol=1e-2,
    max_iter=200,
    random_state=0,
)

# C comes from a continuous uniform distribution on [loc, loc + scale] = [0, 4];
# penalty is drawn from the given list.
distributions = {
    'C': uniform(loc=0, scale=4),
    'penalty': ['l2', 'l1'],
}

# A fixed random_state makes the sampled candidates reproducible across runs.
clf = RandomizedSearchCV(estimator=estimator,
                         param_distributions=distributions,
                         random_state=0)

search = clf.fit(iris.data, iris.target)

# One value per line (same output as four separate print() calls).
print(search.best_estimator_,
      search.best_score_,
      search.best_params_,
      search.n_splits_,
      sep='\n')
# pprint(search.cv_results_) shows the per-candidate scores.