# ImageNet - Bayesian Optimization 
## 02463 Active ML and Agency - Group BO 2

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from train_model import train_model

In [None]:
import skopt
from skopt import gp_minimize

# Search space for the optimisers: one continuous hyper-parameter, the dropout
# probability, searched over [0, 1]. The dimension name is the same key that
# objective() puts into the params dict it passes to train_model.
# NOTE(review): the upper bound 1.0 is reachable; if this value feeds a
# standard dropout layer, p=1 would drop every activation — consider capping.
space = [skopt.space.Real(low=0, high=1, name='module__dropout')]

def objective(x, num_epochs=10, n_training_samples=50000):
    """Train the model at one point of ``space`` and return the negated accuracy.

    skopt minimises its objective, so the accuracy reported by ``train_model``
    is negated before being returned.

    Args:
        x: A point of ``space`` as supplied by skopt; ``x[0]`` is the value
            of the ``module__dropout`` dimension.
        num_epochs: Epochs passed to ``train_model`` (default 10, unchanged).
        n_training_samples: Training-sample count passed to ``train_model``
            (default 50000, unchanged).

    Returns:
        ``-accuracy`` as a plain Python float.
    """
    # Map the positional skopt point onto the keyword names train_model
    # expects; float() keeps numpy scalars out of the training config.
    params = {
        'module__dropout': float(x[0]),
    }
    print("Param values: ", params)
    accuracy = train_model(params, num_epochs=num_epochs,
                           n_training_samples=n_training_samples)
    # Return a plain float so skopt receives a scalar even if train_model
    # hands back a numpy/torch scalar (return type not visible here — TODO
    # confirm); negate because skopt minimises.
    return -float(accuracy)

# Bayesian optimisation with a Gaussian-process surrogate and the Expected
# Improvement acquisition: 20 objective evaluations in total, the first 5
# chosen at random to seed the GP. xi=0.1 asks for more improvement over the
# incumbent (more exploration); noise=0.01**2 fixes the assumed Gaussian
# observation-noise variance of the objective values.
opt_bo = gp_minimize(
    objective,
    space,
    acq_func='EI',
    n_calls=20,
    # NOTE(review): `n_random_starts` is deprecated in scikit-optimize >= 0.8
    # in favour of `n_initial_points`; switch once the installed version is
    # confirmed.
    n_random_starts=5,
    # Same seed as the random-search baseline so both runs are reproducible
    # and the comparison below is like-for-like.
    random_state=42,
    verbose=True,
    xi=0.1,
    noise=0.01**2,
)

In [None]:
# Random-search baseline: evaluate the same objective at 20 points sampled
# from `space`, with a fixed seed so the run can be repeated.
from skopt import dummy_minimize

opt_random = dummy_minimize(
    func=objective,
    dimensions=space,
    n_calls=20,
    random_state=42,
    verbose=True,
)

In [None]:
## Comparison between random search and Bayesian optimisation:
## plot the best accuracy found so far after each objective evaluation.

# Both optimisers minimise -accuracy, so negate func_vals to recover the
# accuracy, then take the running maximum ("best so far") over the sequence.
y_bo = np.maximum.accumulate(-np.asarray(opt_bo.func_vals))
y_random = np.maximum.accumulate(-np.asarray(opt_random.func_vals))
# Separate iteration axes so the plot still works if the two runs used a
# different number of calls.
xs = range(1, len(y_bo) + 1)
xs_random = range(1, len(y_random) + 1)

plt.plot(xs_random, y_random, 'o-', color='red', label='Random Search')
plt.plot(xs, y_bo, 'o-', color='blue', label='Bayesian Optimization')
plt.legend()
plt.xlabel('Iterations')
plt.ylabel('Best accuracy so far')
plt.title('Random Search vs. Bayesian Optimization')
plt.show()

In [None]:
# Accuracy obtained at each dropout value tried by Bayesian optimisation.
# func_vals stores -accuracy (the objective is minimised), so negate it back;
# each entry of x_iters is a one-element point [dropout].
bo_accuracy = -np.asarray(opt_bo.func_vals)
bo_dropout = [point[0] for point in opt_bo.x_iters]
plt.scatter(bo_accuracy, bo_dropout, color='blue')
plt.xlabel("Accuracy")
# Same fixed x-range as the random-search scatter so the plots are directly
# comparable. NOTE(review): evaluations outside [0.55, 0.7] are not shown.
plt.xlim(0.55, 0.7)
plt.ylabel("Dropout value")
plt.title("Bayesian Optimization")
plt.grid()
plt.show()

In [None]:
# Accuracy obtained at each dropout value sampled by random search.
# func_vals stores -accuracy (the objective is minimised), so negate it back;
# each entry of x_iters is a one-element point [dropout]. Red matches the
# random-search colour used in the convergence plot.
random_accuracy = -np.asarray(opt_random.func_vals)
random_dropout = [point[0] for point in opt_random.x_iters]
plt.scatter(random_accuracy, random_dropout, color='red')
plt.xlabel("Accuracy")
# Same fixed x-range as the Bayesian-optimisation scatter so the plots are
# directly comparable. NOTE(review): evaluations outside [0.55, 0.7] are not shown.
plt.xlim(0.55, 0.7)
plt.ylabel("Dropout value")
plt.title("Random Search")
plt.grid()
plt.show()