In [1]:
%gui qt
import GPy
import numpy as np
from matplotlib import pyplot as plt
from pymodaq.utils.data import DataRaw, Axis
from pymodaq.utils import math_utils as mutils
from scipy.optimize import minimize

In [24]:
def black_box(x: "float | np.ndarray", slope: float = 0.1, coeff: float = 1, x0s: float = np.pi/6):
    """Lorentzian-shaped test objective with a single maximum at ``x0s``.

    Evaluates ``w**2 / (w**2 + (x - x0s)**2)`` with ``w = coeff * slope``. The result
    is exactly 1 at ``x == x0s`` and 0.5 at ``x == x0s +/- w``.

    Parameters
    ----------
    x : float or np.ndarray
        Position(s) at which to evaluate the function. Arrays are handled
        element-wise and the result has the same shape as ``x``.
    slope : float, optional
        Width factor of the peak (default 0.1).
    coeff : float, optional
        Multiplier applied to ``slope`` (default 1).
    x0s : float, optional
        Position of the maximum (default pi/6).

    Returns
    -------
    float or np.ndarray
        Function value(s), in (0, 1] for a non-zero ``coeff * slope``.

    Notes
    -----
    If ``coeff * slope`` is 0 and ``x == x0s`` the expression is 0/0; this case is
    not guarded.
    """
    # Alternative multi-peak objective, kept for experimentation:
    #return (np.sin(x) + np.cos(2*x)) * mutils.gauss1D(x, 0, np.pi)
    width_sq = (coeff * slope) ** 2
    return width_sq / (width_sq + (x - x0s) ** 2)

# Search interval of the optimisation as (lower, upper); the axis below declares it in 'rad'.
bounds = (-np.pi, np.pi)

# 100 evenly spaced positions spanning the interval, both end points included
# (np.linspace includes the stop value by default).
axis = Axis('x', 'rad', data=np.linspace(*bounds, num=100))

# Objective sampled on the full grid; serves as the reference curve.
dwa = DataRaw('blackbox', data=[black_box(axis.get_data())], axes=[axis])

# plot('qt') is called twice and both results are kept: the optimisation cell passes
# them back through `viewer=` so that the same two windows are refreshed on each iteration.
viewer_gp = dwa.plot('qt')
viewer_ucb = dwa.plot('qt')

In [25]:
# Bayesian optimisation of `black_box`: a GP surrogate is fitted on the probed points,
# the next point is the maximiser of an upper-confidence-bound (UCB) acquisition.

seed = 42                # fixes the initial random draw; set to None for a fresh draw on each run
n_initial = 5            # number of points probed before the first fit
max_iter = 100           # upper bound on the number of iterations
tol = 1e-5               # stop when two successive probed points are closer than this

kappa = 10               # weight of the uncertainty term in the UCB acquisition
kappa_decay_rate = 0.9   # kappa is multiplied by this factor after every iteration

rng = np.random.default_rng(seed)
# replace=False: a duplicated starting point would add no information to the fit.
X = rng.choice(axis.get_data(), size=n_initial, replace=False)
X = np.expand_dims(X, 1)     # column vector, one row per probed point
Y = black_box(X)             # observations, same shape as X

# NOTE(review): the same kernel object is handed to every new model below, so the
# hyper-parameters fitted in one iteration presumably seed the next fit -- confirm with GPy.
kernel = GPy.kern.RBF(input_dim=1)
shape = (len(axis),)

# Loop invariants: the plotting grid as a column vector and the true curve on that grid.
x_grid = np.expand_dims(axis.get_data(), 1)
black_box_grid = black_box(axis.get_data()).reshape(shape)


def ucb(x: np.ndarray, model, exploration: float) -> np.ndarray:
    """Upper-confidence-bound acquisition of ``model`` at the 1D positions ``x``.

    Computes ``mean + exploration * |mean - lower quantile|`` where the lower
    quantile is the first element returned by ``model.predict_quantiles`` called
    with its default quantiles.

    Parameters
    ----------
    x : np.ndarray
        1D array of positions.
    model
        Fitted GP model exposing ``predict`` and ``predict_quantiles``.
    exploration : float
        Weight of the uncertainty term.

    Returns
    -------
    np.ndarray
        Acquisition values, one row per position (shape ``(len(x), 1)``).
    """
    x_col = np.expand_dims(x, 1)
    y_pred, _ = model.predict(x_col)
    lower_quantile = model.predict_quantiles(x_col)[0]
    return y_pred + exploration * np.abs(y_pred - lower_quantile)


def ucb_optimize(x: np.ndarray, model, exploration: float) -> float:
    """Negated UCB at a single position, as a plain float for ``scipy.optimize.minimize``.

    ``x`` is the length-1 array handed over by ``minimize``. The value is taken with
    explicit ``[0, 0]`` indexing, because calling ``float()`` on an array with
    ndim > 0 is deprecated in recent NumPy.
    """
    return -float(ucb(x, model, exploration)[0, 0])


for ind in range(max_iter):
    # Fit the surrogate on the stored observations; the objective is only ever
    # evaluated once per probed point.
    m = GPy.models.GPRegression(X, Y, kernel)
    m.optimize()
    #m.plot(plot_raw=False)

    # Surrogate prediction on the full grid (the second return value is not needed here).
    mean, _ = m.predict(x_grid)
    quantiles = m.predict_quantiles(x_grid)
    mean_likely = m.likelihood.gp_link.transf(mean)

    # Acquisition on the grid; its arg-max is the starting point of the local optimiser.
    ucb_plot = ucb(axis.get_data(), m, kappa)
    ind_ucb_max = np.argmax(ucb_plot)
    x0 = axis.get_data_at(ind_ucb_max)

    dwa_predict = DataRaw('predict',
                          data=[mean_likely.reshape(shape),
                                ucb_plot.reshape(shape),
                                black_box_grid],
                          axes=[axis],
                          errors=[np.abs(mean - quantiles[0]).reshape(shape),
                                  np.zeros(shape),
                                  np.zeros(shape)],
                          labels=['mean', 'ucb', 'black box'])

    dwa_choice = DataRaw('choice', data=[np.squeeze(Y)],
                         axes=[Axis('choices', data=np.squeeze(X))],
                         labels=['Probed points'],
                         symbol_size=12,
                         symbol='d')
    viewer_gp = dwa_predict.plot('qt', scatter_dwa=dwa_choice, viewer=viewer_gp)

    # Refine the grid maximum: minimise the negated acquisition inside the bounds.
    res = minimize(ucb_optimize, x0=x0, args=(m, kappa), bounds=(bounds,))
    #print(res)

    # The quantity that was minimised is simply the negated grid acquisition.
    dwa_ucb = DataRaw('ucb', data=[-ucb_plot.reshape(shape)],
                      axes=[axis], labels=['ucb'])
    dwa_next = DataRaw('next', dim='Data1D',
                       data=[np.array([ucb_optimize(res.x, m, kappa)])],
                       labels=['next'],
                       axes=[Axis('next', data=res.x)])
    viewer_ucb = dwa_ucb.plot('qt', scatter_dwa=dwa_next, viewer=viewer_ucb)

    #print(f'next probed point is: {res.x}')

    # Probe the objective at the proposed point and store the new observation.
    X = np.concatenate((X, np.expand_dims(res.x, 1)), axis=0)
    Y = np.concatenate((Y, np.expand_dims(black_box(res.x), 1)), axis=0)

    kappa *= kappa_decay_rate
    print(f'kappa is {kappa}')

    # Converged when the new point (X[-1]) is within tol of the previously probed one (X[-2]).
    # `ind` is zero-based, so the number reported below is one less than the number of
    # iterations actually run.
    if np.abs(X[-1, 0] - X[-2, 0]) < tol:
        print(f'optimisation done in {ind} steps. Best value reached at {X[np.argmax(Y)]}, should be {np.pi/6}')
        break
    #print(X.T)
else:
    # Only reached when the loop ran out of iterations without hitting the tolerance.
    print(f'no convergence within {max_iter} iterations. Best value so far at {X[np.argmax(Y)]}')


kappa is 9.0
kappa is 8.1
kappa is 7.29
kappa is 6.561
kappa is 5.9049000000000005
kappa is 5.3144100000000005
kappa is 4.7829690000000005
kappa is 4.3046721
kappa is 3.8742048900000006
kappa is 3.4867844010000004
kappa is 3.1381059609000004
kappa is 2.82429536481
kappa is 2.541865828329
kappa is 2.2876792454961
kappa is 2.05891132094649
kappa is 1.853020188851841
kappa is 1.6677181699666568
kappa is 1.5009463529699911
kappa is 1.350851717672992
kappa is 1.2157665459056928
kappa is 1.0941898913151236
kappa is 0.9847709021836112
optimisation done in 21 steps. Best value reached at [0.52359058], should be 0.5235987755982988
