In [1]:
!pip install botorch ax-platform ioh

Collecting botorch
  Downloading botorch-0.12.0-py3-none-any.whl.metadata (11 kB)
Collecting ax-platform
  Downloading ax_platform-0.4.3-py3-none-any.whl.metadata (11 kB)
Collecting ioh
  Downloading ioh-0.3.18-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.4 kB)
Collecting pyro-ppl>=1.8.4 (from botorch)
  Downloading pyro_ppl-1.9.1-py3-none-any.whl.metadata (7.8 kB)
Collecting gpytorch==1.13 (from botorch)
  Downloading gpytorch-1.13-py3-none-any.whl.metadata (8.0 kB)
Collecting linear-operator==0.5.3 (from botorch)
  Downloading linear_operator-0.5.3-py3-none-any.whl.metadata (15 kB)
Collecting jaxtyping==0.2.19 (from gpytorch==1.13->botorch)
  Downloading jaxtyping-0.2.19-py3-none-any.whl.metadata (5.7 kB)
Collecting pyre-extensions (from ax-platform)
  Downloading pyre_extensions-0.0.32-py3-none-any.whl.metadata (4.0 kB)
Collecting pyro-api>=0.1.1 (from pyro-ppl>=1.8.4->botorch)
  Downloading pyro_api-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting typin

In [2]:
import numpy as np
import ioh
import math
import torch

from gpytorch.kernels import RBFKernel
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.priors import LogNormalPrior
from botorch.models.transforms import Normalize, Standardize
from botorch.models import SingleTaskGP
from botorch.acquisition import LogExpectedImprovement

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.service.utils.instantiation import ObjectiveProperties
from ax.service.ax_client import AxClient



In [7]:
class Rembo():
    """Bayesian optimisation of an IOH problem through a random linear embedding.

    Ax proposes points in an ``embedding_dim``-dimensional box; each point is
    mapped to the ``dim``-dimensional problem space with a Gaussian random
    matrix, clipped to the problem's bounds, and evaluated on the problem.
    An IOH ``Analyzer`` logger is attached to the problem at construction time.
    """

    def __init__(self, lc_prior, pb_id, dim, embedding_dim, problem, root):
        """Store the configuration and attach an IOH logger to ``problem``.

        Args:
            lc_prior: First argument given to ``LogNormalPrior`` for the RBF
                kernel lengthscale prior.
            pb_id: Problem id; used for the metric name (``"f<pb_id>"``), the
                experiment name and the logger folder name.
            dim: Dimension of the problem's input space.
            embedding_dim: Dimension of the space Ax searches in.
            problem: Callable problem object exposing ``bounds.lb``,
                ``bounds.ub``, ``optimum.y``, ``reset()`` and ``attach_logger()``.
            root: Root directory handed to ``ioh.logger.Analyzer``.
        """
        self.debug = False  # when True, run() prints every trial index
        self.prior_mean = lc_prior
        self.unwrapped_pb = problem
        self.dim = dim
        self.embedding_dim = embedding_dim
        self.pb_id = pb_id
        # Drawn afresh at the start of every run(); None until the first run.
        self.transform_matrix = None

        self.l = ioh.logger.Analyzer(root=root, folder_name="vanilla_f"+str(pb_id)+"_"+str(dim)+"d", algorithm_name="rembo")
        self.unwrapped_pb.attach_logger(self.l)

    def _metric_name(self):
        """Return the name of the single objective metric, e.g. ``"f8"``."""
        return "f" + str(self.pb_id)

    def reset_problem(self):
        """Reset the wrapped problem (delegates to ``problem.reset()``)."""
        self.unwrapped_pb.reset()

    def apply_embedding(self, x):
        """Evaluate the problem at the embedded, bound-clipped image of ``x``.

        Args:
            x: Mapping with keys ``"x0"`` .. ``"x<embedding_dim - 1>"`` holding
                the low-dimensional coordinates.

        Returns:
            ``{"f<pb_id>": (value, 0)}`` -- the objective value paired with a
            standard error of 0, in the ``raw_data`` format passed to
            ``AxClient.complete_trial``.

        Raises:
            RuntimeError: If no embedding matrix has been drawn yet.
            KeyError: If one of the expected ``"x<i>"`` keys is missing.
        """
        if self.transform_matrix is None:
            raise RuntimeError("transform_matrix is not set; call run() before apply_embedding()")

        # Read coordinates by name so the result does not depend on the
        # insertion order of the incoming dict.
        low_dim_point = np.asarray(
            [x["x" + str(i)] for i in range(self.embedding_dim)], dtype=float
        )

        embedded = self.transform_matrix @ low_dim_point

        # Project back into the problem's box, coordinate by coordinate.
        embedded = np.clip(
            embedded, self.unwrapped_pb.bounds.lb, self.unwrapped_pb.bounds.ub
        )

        return {self._metric_name(): (self.unwrapped_pb(embedded), 0)}

    def calc_objective(self):
        """Return a placeholder optimum ``(x, y)`` pair.

        The x value will be wrong (all zeros) but the y value is the problem's
        known optimum, which should allow the logger to properly log progress.
        """
        return ([0 for _ in range(self.dim)], self.unwrapped_pb.optimum.y)

    def run(self):
        """Run one optimisation of ``10 * dim`` trials and print the best point.

        A new random embedding matrix is drawn from the global NumPy RNG on
        every call. The first ``dim`` trials come from Sobol, the remaining
        ``9 * dim`` from a ``SingleTaskGP`` with an ARD RBF kernel and
        ``LogExpectedImprovement``.
        """
        self.transform_matrix = np.random.normal(0, 1, (self.dim, self.embedding_dim))

        surrogate = Surrogate(
            botorch_model_class=SingleTaskGP,
            covar_module_class=RBFKernel,
            covar_module_options=dict(
                ard_num_dims=self.embedding_dim,
                lengthscale_prior=LogNormalPrior(self.prior_mean, math.sqrt(3)),
                lengthscale_constraint=GreaterThan(1e-4),
            ),
        )

        gs = GenerationStrategy(
            steps=[
                # Quasi-random initialization step
                GenerationStep(
                    model=Models.SOBOL,
                    num_trials=self.dim,  # How many trials should be produced from this generation step
                ),
                # Bayesian optimization step using the custom surrogate and acquisition function
                GenerationStep(
                    model=Models.BOTORCH_MODULAR,
                    num_trials=9*self.dim,
                    # For `BOTORCH_MODULAR`, we pass in kwargs to specify what surrogate or acquisition function to use.
                    model_kwargs={
                        "surrogate": surrogate,
                        "botorch_acqf_class": LogExpectedImprovement,
                    },
                ),
            ]
        )

        ax_client = AxClient(generation_strategy=gs)

        # Every low-dimensional coordinate ranges over [-sqrt(d), sqrt(d)].
        half_width = math.sqrt(self.embedding_dim)
        params = [
            {"name": "x"+str(i), "type": "range", "bounds": [-half_width, half_width]}
            for i in range(self.embedding_dim)
        ]

        ax_client.create_experiment(
            name=self._metric_name()+"_experiment",
            parameters=params,
            objectives={
                self._metric_name(): ObjectiveProperties(minimize=True),
            },
        )

        for _ in range(10*self.dim):
            parameters, trial_index = ax_client.get_next_trial()
            if self.debug:
                print(trial_index)
            ax_client.complete_trial(trial_index=trial_index, raw_data=self.apply_embedding(parameters))

        # get_best_parameters() may return None, and the prediction part of the
        # tuple may itself be None, so guard before unpacking/indexing.
        best = ax_client.get_best_parameters()
        if best is None:
            print("Best parameters: not available")
            return
        parameters, values = best
        print(f"Best parameters: {parameters}")
        if values is not None:
            print(f"Corresponding mean: {values[0]}, covariance: {values[1]}")



In [None]:
# Seed both RNGs: torch.manual_seed only covers torch-side randomness, while
# Rembo.run() draws its embedding matrix from NumPy's global RNG.
torch.manual_seed(0)
np.random.seed(0)

dim=40
pb_id=8
root="./data"
problem = ioh.get_problem(pb_id, dimension=dim, instance=1, problem_class=ioh.ProblemClass.BBOB)

# Search in a space a quarter the size of the problem (integer division keeps it an int).
embedding_dim = dim // 4
model = Rembo(1, pb_id, dim, embedding_dim, problem, root)

# Five independent repetitions; the problem is reset after each one.
for i in range(5):
    print(i)
    model.run()

    model.reset_problem()

[INFO 12-30 05:57:15] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-30 05:57:15] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x0. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-30 05:57:15] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-30 05:57:15] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-30 05:57:15] ax.service.

0


  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 05:57:16] ax.service.ax_client: Generated new trial 7 with parameters {'x0': -1.508999, 'x1': -1.817596, 'x2': -3.150223, 'x3': 1.24228, 'x4': 1.647539, 'x5': 0.351691, 'x6': -0.876493, 'x7': -0.414609, 'x8': 2.275473, 'x9': -0.932426} using model Sobol.
[INFO 12-30 05:57:16] ax.service.ax_client: Completed trial 7 with data: {'f8': (1787603.278912, 0)}.
  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 05:57:16] ax.service.ax_client: Generated new trial 8 with parameters {'x0': -0.933274, 'x1': 0.789947, 'x2': 0.343894, 'x3': 2.254765, 'x4': -1.515207, 'x5': 0.95381, 'x6': 3.115347, 'x7': -2.075668, 'x8': 0.160132, 'x9': 2.370943} using model Sobol.
[INFO 12-30 05:57:16] ax.service.ax_client: Completed trial 8 with data: {'f8': (2968912.447476, 0)}.
  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 05:57:16] ax.service.ax_client:

Best parameters: {'x0': 0.5582013845124512, 'x1': 0.14052190623170535, 'x2': 0.5065463298983244, 'x3': 0.5199581125660595, 'x4': 0.44931644906084234, 'x5': -0.6187126541409373, 'x6': 0.019364410163081658, 'x7': -0.2537541838333617, 'x8': 0.31690995464161187, 'x9': 0.5433159667751162}
Corresponding mean: {'f8': 114815.18546937173}, covariance: {'f8': {'f8': 531575.785822261}}
1


  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 06:51:03] ax.service.ax_client: Generated new trial 7 with parameters {'x0': -0.919928, 'x1': -2.152617, 'x2': -2.02625, 'x3': -1.007504, 'x4': 2.436834, 'x5': -3.15119, 'x6': -0.142965, 'x7': -2.300177, 'x8': -1.280806, 'x9': 1.193316} using model Sobol.
[INFO 12-30 06:51:03] ax.service.ax_client: Completed trial 7 with data: {'f8': (4253840.221518, 0)}.
  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 06:51:03] ax.service.ax_client: Generated new trial 8 with parameters {'x0': -1.504553, 'x1': 0.901549, 'x2': 1.429186, 'x3': -2.544778, 'x4': -2.57344, 'x5': -2.256949, 'x6': 1.308186, 'x7': -1.145461, 'x8': -1.604423, 'x9': -1.845562} using model Sobol.
[INFO 12-30 06:51:03] ax.service.ax_client: Completed trial 8 with data: {'f8': (3068893.762788, 0)}.
  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 06:51:03] ax.service.ax_cl

Best parameters: {'x0': 0.12826651453735538, 'x1': 0.16929263615042212, 'x2': -0.14982599783092088, 'x3': -0.2547699616717467, 'x4': -0.1609145388649722, 'x5': 0.7190262393171527, 'x6': 0.16416353622803692, 'x7': 0.4254191508562233, 'x8': 0.450887737853209, 'x9': 0.2720298016512568}
Corresponding mean: {'f8': 106928.37143443478}, covariance: {'f8': {'f8': 891214.7841662357}}
2


  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 07:43:23] ax.service.ax_client: Generated new trial 7 with parameters {'x0': -0.126539, 'x1': 2.262527, 'x2': -1.501852, 'x3': 0.113125, 'x4': -0.283155, 'x5': -0.164174, 'x6': -2.921378, 'x7': 2.665239, 'x8': 0.879551, 'x9': -0.96317} using model Sobol.
[INFO 12-30 07:43:23] ax.service.ax_client: Completed trial 7 with data: {'f8': (1926056.712936, 0)}.
  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 07:43:23] ax.service.ax_client: Generated new trial 8 with parameters {'x0': -0.742023, 'x1': -1.16256, 'x2': 0.028416, 'x3': 1.705475, 'x4': 1.785758, 'x5': -1.477688, 'x6': 0.23048, 'x7': 1.314406, 'x8': 2.289756, 'x9': 0.859907} using model Sobol.
[INFO 12-30 07:43:23] ax.service.ax_client: Completed trial 8 with data: {'f8': (1328376.224008, 0)}.
  warn("Encountered exception in computing model fit quality: " + str(e))
[INFO 12-30 07:43:23] ax.service.ax_client: 