# In [None]:
import pylab as plt
from jax import random, numpy as jnp

from bojax.experiment import NewExperimentRequest, TrialUpdate
from bojax.parameter_space import ParameterSpace, Parameter, ContinuousPrior
from bojax.service import BayesianOptimisation

def objective(x):
    """Return the negated Styblinski-Tang function evaluated at ``x``.

    Styblinski-Tang is ``0.5 * sum(x**4 - 16*x**2 + 5*x)``; negating it turns
    its global minimum into a global maximum so the optimiser can maximise.
    In ``ndim`` dimensions that maximum is about ``39.16617 * ndim`` and lies
    at ``x_i = -2.903534`` on every axis.

    Args:
        x: array of coordinates; the sum runs over all of its elements.

    Returns:
        A JAX scalar.
    """
    return -0.5 * jnp.sum(x ** 4 - 16 * x ** 2 + 5 * x)


def params_to_array(values_by_name):
    """Stack parameter values named ``x0 .. x{n-1}`` into an array in index order.

    Sorting the names as strings would put ``'x10'`` before ``'x2'`` once
    there are more than ten parameters, silently permuting the coordinates,
    so the values are looked up by index instead.

    Args:
        values_by_name: mapping from parameter name (``'x0'``, ``'x1'``, ...)
            to its numeric value.

    Returns:
        A 1-D ``jnp`` array of length ``len(values_by_name)`` whose i-th entry
        is the value of ``'x{i}'``.

    Raises:
        KeyError: if the names are not exactly ``x0 .. x{n-1}``.
    """
    return jnp.asarray([values_by_name[f'x{i}'] for i in range(len(values_by_name))])


def example(ndim, num_steps=20):
    """Run ``num_steps`` Bayesian-optimisation trials on ``objective`` in ``ndim`` dimensions.

    Each step requests a new trial from the service and evaluates
    ``objective`` at the suggested point. It then posts the measurement back
    and shows the experiment's visualisation.

    Args:
        ndim: number of parameters, named ``x0 .. x{ndim-1}``.
        num_steps: number of trials to run (default 20).
    """
    lower_bound = 39.16616 * ndim
    upper_bound = 39.16617 * ndim
    print(f"Optimal value in ({lower_bound}, {upper_bound}).")

    # Coordinate of objective's global maximum (identical on every axis).
    x_max = -2.903534
    print(f"Global optimum at {jnp.ones(ndim) * x_max}")

    parameter_space = ParameterSpace(
        parameters=[
            Parameter(
                name=f'x{i}',
                prior=ContinuousPrior(
                    lower=-5.,
                    upper=5.,
                    mode=0.,
                    uncert=10.
                )
            )
            for i in range(ndim)
        ]
    )
    new_experiment_request = NewExperimentRequest(
        parameter_space=parameter_space,
        init_explore_size=10
    )
    bo_experiment = BayesianOptimisation.create_new_experiment(new_experiment=new_experiment_request)

    for step in range(num_steps):
        # Seeding the key with the step index makes the run reproducible.
        trial_id = bo_experiment.create_new_trial(
            key=random.PRNGKey(step),
            random_explore=False,
            beta=1.
        )
        trial = bo_experiment.get_trial(trial_id=trial_id)
        params = params_to_array(
            {name: param.value for name, param in trial.param_values.items()}
        )
        print(params)

        obj_val = float(objective(params))
        bo_experiment.post_measurement(
            trial_id=trial_id,
            trial_update=TrialUpdate(ref_id='a', objective_measurement=obj_val)
        )
        # NOTE(review): visualise() presumably draws onto pyplot figures; show
        # them, then close them all so figures do not accumulate across steps.
        bo_experiment.visualise()
        plt.show()
        plt.close('all')


if __name__ == '__main__':
    example(5)
