In [20]:
# example of bayesian optimization with scikit-optimize
from numpy import mean
from sklearn.datasets import make_blobs
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from skopt.space import Integer
from skopt.utils import use_named_args
from skopt import gp_minimize
from warnings import catch_warnings
from warnings import simplefilter

In [13]:
# generate 2d classification dataset
X, y = make_blobs(n_samples=500, centers=3, n_features=2)
# define the model
model = KNeighborsClassifier()

In [16]:
# define the space of hyperparameters to search
search_space = [Integer(1, 5, name='n_neighbors'), Integer(1, 2, name='p')]

In [18]:
# define the funciton used to evaluate a given configuration
@use_named_args(search_space)
def evaluate_model(**params):
    # something
    model.set_params(**params)
    # calculate 5-fold corss validation
    result = cross_val_score(model, X, y, cv=5, n_jobs=-1, scoring='accuracy')
    # calculate the mean of the scores
    estimate = mean(result)
    return 1.0 - estimate


In [21]:
# perform optimization
with catch_warnings():
    # ignore generated warnings
    simplefilter('ignore')
    result = gp_minimize(evaluate_model, search_space)

In [22]:
# summarizing finding:
print('Best Accuracy: %.3f' % (1.0 - result.fun))
print('Best Parameters: n_neighbors=%d, p=%d' % (result.x[0], result.x[1]))

Best Accuracy: 1.000
Best Parameters: n_neighbors=3, p=2


In [12]:
X

array([[ 2.57596936e-01, -7.98665409e+00],
       [ 9.70414145e-01, -9.14324205e+00],
       [ 2.43616957e+00, -6.54524747e+00],
       [ 7.60994147e+00, -1.03918468e+01],
       [ 1.46397965e+00, -9.19287870e+00],
       [ 8.88307121e-01, -7.98525026e+00],
       [ 2.21089996e+00, -8.40044269e+00],
       [ 2.33107531e-01, -8.07701147e+00],
       [ 4.02570340e+00, -8.10368838e+00],
       [ 2.18996820e+00, -7.31337241e+00],
       [-1.26407998e+00, -8.32016628e+00],
       [ 5.65317229e+00, -8.65638508e+00],
       [ 1.90962998e+00, -7.09761376e+00],
       [ 2.81203416e+00,  6.01357353e+00],
       [ 6.51817416e+00, -9.05183316e+00],
       [ 6.46764888e-01, -6.12394272e+00],
       [ 1.40253802e+00, -9.27013866e+00],
       [-4.56943328e-01, -8.05028070e+00],
       [ 6.98447654e+00, -8.59729677e+00],
       [ 7.07862828e+00, -1.01314756e+01],
       [ 2.62857047e+00,  7.49942872e+00],
       [ 8.44624523e-01, -6.90727336e+00],
       [ 5.22583152e-01,  4.47059520e+00],
       [ 7.