Bayesian Optimisation on 2-dimensional input data, and how it can converge to a local optimum depending on the initial samples

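We use a two-dimensional bell-curve function with additive Gaussian noise as the objective: $f(x_0, x_1) = 4e^{-x_0^2} + 6e^{-4x_1^2}$. Its noise-free maximum is $f(0, 0) = 10$. Far from $x_1 = 0$ the second term vanishes, leaving a ridge of height about 4 along $x_0 = 0$ on which the search can get stuck.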


In [2]:
import numpy as np
from matplotlib import pyplot as plt
import simple_GP_temperature.helper_sklearn as helper
from simple_GP_temperature.BO_GP_temperature import BOTemperatureGP
from simple_GP_temperature.random_initial import RandomInit


In [3]:
class Bellcurve:
    """Noisy two-dimensional "bell curve" objective for the optimiser.

    f(x0, x1) = 4 * exp(-x0**2) + 6 * exp(-4 * x1**2) + noise

    The noise-free surface peaks at f(0, 0) = 10. When |x1| is large the second
    term vanishes, leaving a ridge of height ~4 along x0 = 0.
    """

    def get_evaluation_score(self, X, noise_scale=0.01, rng=None):
        """
        Evaluate the objective at one 2-D point, with additive Gaussian noise.

        :param X: indexable point; only X[0] and X[1] are used.
        :param noise_scale: standard deviation of the zero-mean Gaussian noise.
            Pass 0 for a noise-free evaluation.
        :param rng: optional ``numpy.random.Generator`` used to draw the noise,
            so that evaluations can be made reproducible. When None (default)
            the global ``np.random`` state is used, exactly as before.
        :return: float, the noisy objective value (NOT rounded).
        """
        # Draw the noise first, matching the original draw order.
        sampler = np.random if rng is None else rng
        noise = sampler.normal(loc=0, scale=noise_scale)

        x0, x1 = X[0], X[1]
        bellcurve = 4 * np.exp(-x0 ** 2) + 6 * np.exp(-4 * x1 ** 2)

        return bellcurve + noise


In [4]:
class RandomInit2D:
    """Draws random initial sample(s) from a 2-D box for the optimiser."""

    def generate_initial_samples(self, low=None, high=None, n=1):
        """
        Draw ``n`` random point(s) between ``low`` and ``high`` via
        ``helper.get_random_X``.

        :param low: per-dimension lower bounds; defaults to [-5, -5].
        :param high: per-dimension upper bounds; defaults to [5, 5].
        :param n: number of samples requested from the helper.
        :return: the helper's result flattened to a 1-D array.
            NOTE(review): if the helper returns one row per sample (unverified),
            n > 1 concatenates all samples into one flat array and loses the
            sample boundaries. Confirm that callers only use n=1.
        """
        # None sentinels instead of list defaults: a list default is created once
        # and shared by every call, so any downstream mutation would leak across calls.
        if low is None:
            low = [-5, -5]
        if high is None:
            high = [5, 5]
        X = helper.get_random_X(low=low, high=high, samples=n).flatten()
        return X

Random Initial Sampling

In [5]:
# Seed NumPy's global RNG so that re-running the notebook reproduces the same
# observation noise (Bellcurve draws its noise from np.random).
# NOTE(review): the random initial samples are only reproducible too if
# helper.get_random_X draws from np.random. Confirm.
SEED = 0
np.random.seed(SEED)
random2d = RandomInit2D()

In [6]:
objective = Bellcurve()  # noisy bell-curve objective; noise-free maximum is 10 at (0, 0)

In [7]:
# GP-based Bayesian optimiser: evaluates `objective`, starts from a point drawn by
# `random2d`, and searches between the lower bounds [-5, -5] and the upper bounds [5, 5].
bayes_opt = BOTemperatureGP(
    evaluation_component=objective,
    initial_method=random2d,
    lower_bound=[-5, -5],
    upper_bound=[5, 5],
)

In [8]:
# Run the optimisation loop. The result appears to be a dict mapping each evaluated
# (x0, x1) point to its observed noisy score; the printed samples match the entries.
mappings = bayes_opt.optimise()

next sample to be evaluated: [ 0.54 -4.24], score: 2.9761389162641105
next sample to be evaluated: [ 0.22 -3.72], score: 3.8211517341159973
next sample to be evaluated: [-0.01 -3.49], score: 4.006593296056119
next sample to be evaluated: [-0.19 -3.5 ], score: 3.8670024160780714
next sample to be evaluated: [ 0.22 -3.33], score: 3.82067291074696
next sample to be evaluated: [ 0.04 -3.38], score: 4.0000013812605495
next sample to be evaluated: [-1.0658141e-13 -3.5400000e+00], score: 3.992608853896108
next sample to be evaluated: [-0.04 -3.41], score: 4.003053144201008
next sample to be evaluated: [-2.6  -4.43], score: -0.006027209611480284
next sample to be evaluated: [3.97 4.72], score: 0.0008801025102975063
next sample to be evaluated: [4.91 2.24], score: -0.0025309830095910126
next sample to be evaluated: [1.35 4.9 ], score: 0.6482191443009664
next sample to be evaluated: [-0.11  4.43], score: 3.9483594931097
next sample to be evaluated: [-0.22  4.56], score: 3.825450682161425
next sa

In [10]:
type(mappings)  # check the container type returned by optimise()

dict

In [12]:
# Evaluated (x0, x1) points in insertion (evaluation) order. The coordinates carry
# floating-point drift (e.g. 0.6899999999998787); round them if needed for display.
list(mappings.keys())

[(0.6899999999998787, -4.470000000000011),
 (-4.800000000000004, -4.000000000000021),
 (0.5399999999998819, -4.240000000000016),
 (0.21999999999988873, -3.7200000000000273),
 (-0.010000000000106368, -3.490000000000032),
 (-0.19000000000010253, -3.500000000000032),
 (0.21999999999988873, -3.3300000000000356),
 (0.039999999999892566, -3.3800000000000345),
 (-1.0658141036401503e-13, -3.540000000000031),
 (-0.04000000000010573, -3.410000000000034),
 (-2.600000000000051, -4.430000000000012),
 (3.969999999999809, 4.719999999999793),
 (4.909999999999789, 2.2399999999998457),
 (1.3499999999998646, 4.899999999999789),
 (-0.11000000000010424, 4.429999999999799),
 (-0.2200000000001019, 4.559999999999796),
 (-0.18000000000010274, 4.239999999999803),
 (0.2499999999998881, 4.239999999999803),
 (0.11999999999989086, 4.569999999999796),
 (-0.020000000000106155, 3.43999999999982),
 (-0.020000000000106155, 3.389999999999821),
 (0.14999999999989022, 3.4199999999998205),
 (-0.050000000000105516, 4.8599999

In [13]:
# Observed noisy scores, aligned index-by-index with the points listed above.
# Early scores sit near 4 (the ridge); later scores near 9-10 approach the global maximum.
list(mappings.values())

[2.4784298289926303,
 -0.011281642899068533,
 2.9761389162641105,
 3.8211517341159973,
 4.006593296056119,
 3.8670024160780714,
 3.82067291074696,
 4.0000013812605495,
 3.992608853896108,
 4.003053144201008,
 -0.006027209611480284,
 0.0008801025102975063,
 -0.0025309830095910126,
 0.6482191443009664,
 3.9483594931097,
 3.825450682161425,
 3.8687601636547515,
 3.7602426745592914,
 3.9466479520402227,
 3.9942798646936852,
 3.9957649636151733,
 3.9180391991300425,
 3.979034482503642,
 3.8351022483317525,
 3.989408522041606,
 3.99196680649536,
 3.757617006174098,
 3.972648856086763,
 4.0085462547532815,
 3.922249334444661,
 4.566507743451557,
 5.289059416428578,
 9.38345092567275,
 3.849494285614245,
 8.763376958385155,
 4.730837291430368,
 7.59355824081499,
 5.310408869204007,
 0.19142818161339456,
 8.888725655160462,
 0.4157216707171129,
 0.11487902246326415,
 9.92352488222885,
 9.364142996869212,
 5.554957735545747,
 0.11193219472703447,
 0.009039727067151638,
 0.039161091869635456,
 0.