In [1]:
import numpy as np
import pandas as pd
import json
import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.svm import SVC
from bayes_opt import BayesianOptimization
from bayes_opt.util import Colours
from bayes_opt import UtilityFunction
import copy
from functools import partial

In [2]:
def get_data(n_samples=5000, n_features=100, n_informative=80, n_classes=5,
             random_state=None):
    """Generate a synthetic multi-class classification dataset.

    Thin wrapper around ``sklearn.datasets.make_classification``. The defaults
    reproduce the dataset this notebook was originally run on: 5000 samples,
    100 features (80 informative) and 5 classes. This is a multi-class
    problem, not a binary one.

    Parameters
    ----------
    n_samples : int
        Number of rows to generate.
    n_features : int
        Total number of feature columns.
    n_informative : int
        Number of informative features. Note that ``make_classification`` also
        adds redundant features by default, and those must fit within
        ``n_features`` too.
    n_classes : int
        Number of target classes.
    random_state : int, numpy RandomState or None
        Forwarded to ``make_classification``. With ``None`` (the default) the
        data is drawn from numpy's global RNG. It therefore changes between
        runs unless that RNG has been seeded.

    Returns
    -------
    data : ndarray
        Feature matrix with ``n_samples`` rows and ``n_features`` columns.
    targets : ndarray
        Class label for each row.
    """
    data, targets = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_classes=n_classes,
        random_state=random_state,
    )
    return data, targets


def svc_cv(expC, expGamma, X, Y, scoring='f1_weighted', cv=4):
    """Mean cross-validated score of an SVC, parameterized on a log scale.

    This function is the black-box objective for Bayesian optimization. The
    optimizer proposes base-10 exponents, and they are converted back here:
    ``C = 10 ** expC`` and ``gamma = 10 ** expGamma``. Searching in log space
    is not strictly required, but it spreads the search evenly across orders
    of magnitude, which greatly improves the optimizer's performance.

    The goal is to find the (C, gamma) pair that MAXIMIZES ``scoring``. The
    default is the weighted F1 score, so higher is better.

    Parameters
    ----------
    expC : float
        Base-10 exponent of the SVC regularization parameter ``C``.
    expGamma : float
        Base-10 exponent of the SVC kernel coefficient ``gamma``.
    X, Y : array-like
        Features and target labels used for cross-validation.
    scoring : str, optional
        scikit-learn scoring name. Defaults to ``'f1_weighted'``.
    cv : int, optional
        Number of cross-validation folds. Defaults to 4.

    Returns
    -------
    float
        Mean score across the ``cv`` folds.
    """
    C = 10 ** expC
    gamma = 10 ** expGamma
    estimator = SVC(C=C, gamma=gamma, random_state=2)
    cval = cross_val_score(estimator, X, Y, scoring=scoring, cv=cv)
    return cval.mean()


def rfc_cv(n_estimators, min_samples_split, max_features, X, Y,
           scoring='f1_weighted', cv=4):
    """Mean cross-validated score of a random forest classifier.

    This function is the black-box objective for Bayesian optimization. The
    goal is to find the combination of ``n_estimators``, ``min_samples_split``
    and ``max_features`` that MAXIMIZES ``scoring``. The default is the
    weighted F1 score, so higher is better.

    Parameters
    ----------
    n_estimators : float
        Number of trees. The optimizer proposes continuous values, so this is
        truncated with ``int()``.
    min_samples_split : float
        Minimum number of samples needed to split a node. It is also truncated
        with ``int()``.
    max_features : float
        Fraction of features considered at each split. It is passed through
        unchanged as a float.
    X, Y : array-like
        Features and target labels used for cross-validation.
    scoring : str, optional
        scikit-learn scoring name. Defaults to ``'f1_weighted'``.
    cv : int, optional
        Number of cross-validation folds. Defaults to 4.

    Returns
    -------
    float
        Mean score across the ``cv`` folds.
    """
    estimator = RFC(
        n_estimators=int(n_estimators),
        min_samples_split=int(min_samples_split),
        max_features=max_features,
        random_state=2,
    )
    cval = cross_val_score(estimator, X, Y, scoring=scoring, cv=cv)
    return cval.mean()


In [3]:
# get_data() does not pass a random_state to make_classification, so it draws
# from numpy's global RNG. Seed that RNG so Restart & Run All reproduces the
# same dataset, and therefore the same scores.
np.random.seed(134985745)
X, Y = get_data()

n_iter = 10       # guided Bayesian-optimization steps
init_points = 30  # random exploration points evaluated before the guided steps

# Parallel lists: one objective, its search bounds, and one hand-picked point
# that is probed before the random warm-up, for each model.
black_box_funs = [svc_cv, rfc_cv]
pbounds_lst = [
    {
        "expC": (-3, 2),
        "expGamma": (-4, -1),
    },
    {
        "n_estimators": (10, 250),
        "min_samples_split": (2, 25),
        "max_features": (0.1, 0.999),
    },
]
probe_points = [
    {"expC": -1.002, "expGamma": -2.002},
    {"n_estimators": 123, "min_samples_split": 21, "max_features": 0.8888},
]

# Only needed for a manual optimizer.suggest(utility) / register() loop.
# It is not passed to maximize() below.
utility = UtilityFunction(kind="ucb", kappa=2.5, xi=0.0)

best_results = {}  # objective name -> optimizer.max for that model
for black_box_fptr, pbounds, probe_params in zip(
        black_box_funs, pbounds_lst, probe_points):
    name = black_box_fptr.__name__
    print(Colours.yellow("--- Optimizing {} ---".format(name)))
    optimizer = BayesianOptimization(
        f=partial(black_box_fptr, X=X, Y=Y),
        pbounds=pbounds,
        verbose=2,  # 1 prints only when a new maximum is observed, 0 is silent
        random_state=65535,
    )
    # lazy=True queues the point; it is evaluated as the first step of maximize().
    optimizer.probe(params=probe_params, lazy=True)
    optimizer.maximize(init_points=init_points, n_iter=n_iter)
    best_results[name] = optimizer.max
    print(optimizer.max)


[93m--- Optimizing SVM ---[0m
---------- Optimizing <function svc_cv at 0x7f347f0ab2f0>--------------
|   iter    |  target   |   expC    | expGamma  |
-------------------------------------------------


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 1       [0m | [0m 0.06716 [0m | [0m-1.002   [0m | [0m-2.002   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 2       [0m | [0m 0.06716 [0m | [0m-2.033   [0m | [0m-3.207   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 3       [0m | [0m 0.06716 [0m | [0m 1.802   [0m | [0m-2.146   [0m |
| [95m 4       [0m | [95m 0.8717  [0m | [95m 1.511   [0m | [95m-3.238   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 5       [0m | [0m 0.06716 [0m | [0m-0.3827  [0m | [0m-2.296   [0m |
| [0m 6       [0m | [0m 0.5075  [0m | [0m-1.06    [0m | [0m-3.696   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 7       [0m | [0m 0.06716 [0m | [0m-1.595   [0m | [0m-2.264   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 8       [0m | [0m 0.06716 [0m | [0m 0.9716  [0m | [0m-2.155   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 9       [0m | [0m 0.07003 [0m | [0m-1.349   [0m | [0m-3.265   [0m |
| [0m 10      [0m | [0m 0.7779  [0m | [0m 1.243   [0m | [0m-2.778   [0m |
| [0m 11      [0m | [0m 0.8713  [0m | [0m 1.27    [0m | [0m-3.235   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 12      [0m | [0m 0.06716 [0m | [0m 0.8516  [0m | [0m-2.04    [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 13      [0m | [0m 0.06716 [0m | [0m-0.6569  [0m | [0m-1.993   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 14      [0m | [0m 0.06716 [0m | [0m-2.026   [0m | [0m-1.25    [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 15      [0m | [0m 0.06716 [0m | [0m-2.033   [0m | [0m-1.146   [0m |
| [0m 16      [0m | [0m 0.4665  [0m | [0m-1.172   [0m | [0m-3.597   [0m |
| [0m 17      [0m | [0m 0.7514  [0m | [0m 0.0768  [0m | [0m-3.897   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 18      [0m | [0m 0.06716 [0m | [0m 0.7598  [0m | [0m-1.928   [0m |
| [0m 19      [0m | [0m 0.7359  [0m | [0m 0.07078 [0m | [0m-3.948   [0m |
| [0m 20      [0m | [0m 0.4569  [0m | [0m-1.125   [0m | [0m-3.367   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 21      [0m | [0m 0.06716 [0m | [0m 0.1577  [0m | [0m-1.413   [0m |
| [0m 22      [0m | [0m 0.6151  [0m | [0m-0.6614  [0m | [0m-3.948   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 23      [0m | [0m 0.06716 [0m | [0m-2.41    [0m | [0m-2.252   [0m |
| [0m 24      [0m | [0m 0.8653  [0m | [0m 0.6662  [0m | [0m-3.815   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 25      [0m | [0m 0.06716 [0m | [0m-0.6346  [0m | [0m-1.306   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 26      [0m | [0m 0.06716 [0m | [0m-0.3186  [0m | [0m-1.296   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 27      [0m | [0m 0.06716 [0m | [0m-2.381   [0m | [0m-2.433   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 28      [0m | [0m 0.06716 [0m | [0m-0.1805  [0m | [0m-1.72    [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 29      [0m | [0m 0.06716 [0m | [0m 0.5273  [0m | [0m-1.041   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 30      [0m | [0m 0.06716 [0m | [0m-0.4946  [0m | [0m-1.625   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 31      [0m | [0m 0.06716 [0m | [0m-2.109   [0m | [0m-3.003   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 32      [0m | [0m 0.06716 [0m | [0m-3.0     [0m | [0m-4.0     [0m |
| [0m 33      [0m | [0m 0.8712  [0m | [0m 2.0     [0m | [0m-4.0     [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 34      [0m | [0m 0.06716 [0m | [0m-3.0     [0m | [0m-1.0     [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 35      [0m | [0m 0.06716 [0m | [0m 2.0     [0m | [0m-1.0     [0m |
| [0m 36      [0m | [0m 0.8706  [0m | [0m 1.351   [0m | [0m-4.0     [0m |
| [0m 37      [0m | [0m 0.8504  [0m | [0m 2.0     [0m | [0m-3.075   [0m |
| [0m 38      [0m | [0m 0.8579  [0m | [0m 0.4609  [0m | [0m-3.124   [0m |
| [0m 39      [0m | [0m 0.7826  [0m | [0m-0.3339  [0m | [0m-3.335   [0m |
| [95m 40      [0m | [95m 0.8818  [0m | [95m 2.0     [0m | [95m-3.564   [0m |


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


| [0m 41      [0m | [0m 0.06716 [0m | [0m-3.0     [0m | [0m-3.055   [0m |
{'target': 0.8817805859576646, 'params': {'expC': 2.0, 'expGamma': -3.5637399929825033}}
---------- Optimizing <function rfc_cv at 0x7f347f0ab158>--------------
|   iter    |  target   | max_fe... | min_sa... | n_esti... |
-------------------------------------------------------------
| [0m 1       [0m | [0m 0.6199  [0m | [0m 0.8888  [0m | [0m 21.0    [0m | [0m 123.0   [0m |
| [95m 2       [0m | [95m 0.6859  [0m | [95m 0.2738  [0m | [95m 8.082   [0m | [95m 240.5   [0m |
| [0m 3       [0m | [0m 0.5912  [0m | [0m 0.6555  [0m | [0m 22.75   [0m | [0m 70.96   [0m |
| [0m 4       [0m | [0m 0.6367  [0m | [0m 0.5706  [0m | [0m 15.06   [0m | [0m 103.1   [0m |
| [0m 5       [0m | [0m 0.6641  [0m | [0m 0.191   [0m | [0m 8.461   [0m | [0m 148.9   [0m |
| [0m 6       [0m | [0m 0.6166  [0m | [0m 0.8141  [0m | [0m 16.14   [0m | [0m 89.26   [0m |
| [0m 7       