In [1]:
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior


class SimpleCustomGP(ExactGP, GPyTorchModel):
    """Single-output exact GP with a constant mean and a scaled ARD RBF kernel.

    Inherits from both ``ExactGP`` and ``GPyTorchModel`` so the model can be
    used through the ``GPyTorchModel`` API.
    """

    # Number of outputs, read by the GPyTorchModel API.
    _num_outputs = 1

    def __init__(self, train_X, train_Y):
        """Set up the likelihood, mean and covariance modules.

        Args:
            train_X: Training inputs. The size of its last dimension sets
                the number of ARD dimensions of the RBF kernel.
            train_Y: Training targets. ``squeeze(-1)`` is applied before
                they are handed to ``ExactGP`` (presumably they arrive with
                a trailing size-1 output dimension -- confirm with callers).
        """
        targets = train_Y.squeeze(-1)
        noise_model = GaussianLikelihood()
        super().__init__(train_X, targets, noise_model)

        self.mean_module = ConstantMean()
        rbf = RBFKernel(ard_num_dims=train_X.shape[-1])
        self.covar_module = ScaleKernel(base_kernel=rbf)

        # Match the device/dtype of the training inputs.
        self.to(train_X)

    def forward(self, x):
        """Return the ``MultivariateNormal`` given by the mean and covariance
        modules evaluated at ``x``."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

In [2]:
from botorch.fit import fit_gpytorch_model

def _get_and_fit_simple_custom_gp(Xs, Ys, **kwargs):
    """Build a ``SimpleCustomGP`` on the first X/Y pair and fit it.

    Args:
        Xs: Indexable collection of training inputs; only ``Xs[0]`` is used.
        Ys: Indexable collection of training targets; only ``Ys[0]`` is used.
        **kwargs: Accepted and ignored.

    Returns:
        The ``SimpleCustomGP`` instance after ``fit_gpytorch_model`` has been
        called on its exact marginal log likelihood.
    """
    train_X, train_Y = Xs[0], Ys[0]
    gp = SimpleCustomGP(train_X, train_Y)
    fit_gpytorch_model(ExactMarginalLogLikelihood(gp.likelihood, gp))
    return gp

In [3]:
import random
import numpy as np

def branin(parameterization, *args, noise_sd=0.1):
    """Evaluate the Branin function with synthetic Gaussian observation noise.

    Args:
        parameterization: Mapping with numeric entries ``"x1"`` and ``"x2"``.
        *args: Extra positional arguments are accepted and ignored.
        noise_sd: Standard deviation of the zero-mean Gaussian noise added to
            the function value. It is also returned as the standard error of
            the observation, so the reported uncertainty matches the noise
            that was actually injected. Defaults to 0.1.

    Returns:
        ``{"branin": (y, noise_sd)}`` where ``y`` is the noisy function value.
    """
    x1, x2 = parameterization["x1"], parameterization["x2"]
    y = (x2 - 5.1 / (4 * np.pi ** 2) * x1 ** 2 + 5 * x1 / np.pi - 6) ** 2
    y += 10 * (1 - 1 / (8 * np.pi)) * np.cos(x1) + 10
    # Synthetic observation noise; a single draw, so its SEM equals noise_sd.
    y += random.normalvariate(0, noise_sd)
    # Report the real noise level instead of claiming a noiseless observation.
    return {"branin": (y, noise_sd)}

In [4]:
from ax import ParameterType, RangeParameter, SearchSpace

# Two float-valued range parameters: x1 in [-5, 10] and x2 in [0, 15].
search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name=param_name,
            parameter_type=ParameterType.FLOAT,
            lower=lower_bound,
            upper=upper_bound,
        )
        for param_name, lower_bound, upper_bound in (("x1", -5, 10), ("x2", 0, 15))
    ]
)

In [5]:
from ax import SimpleExperiment

# Experiment over `search_space` that uses `branin` as its evaluation function
# and minimizes the "branin" objective (the key of the dict `branin` returns).
exp = SimpleExperiment(
    name="test_branin",
    search_space=search_space,
    evaluation_function=branin,
    objective_name="branin",
    minimize=True,
)


In [6]:
from ax.modelbridge import get_sobol

# Initialize the experiment with a batch trial built from a Sobol generator
# run; `gen(5)` requests 5 points from the experiment's search space.
sobol = get_sobol(exp.search_space)
exp.new_batch_trial(generator_run=sobol.gen(5))

BatchTrial(experiment_name='test_branin', index=0, status=TrialStatus.CANDIDATE)

In [7]:
from ax.modelbridge.factory import get_botorch

# Run 5 optimization rounds. Each round builds a fresh model bridge from the
# data returned by `exp.eval()`, using the custom GP constructor defined
# above, then attaches one newly generated candidate as a new trial.
for i in range(5):
    print(f"Running optimization batch {i+1}/5...")
    model = get_botorch(
        experiment=exp,
        data=exp.eval(),
        search_space=exp.search_space,
        model_constructor=_get_and_fit_simple_custom_gp,
    )
    # `gen(1)` requests a single candidate; the resulting trial is stored in
    # `batch` but not used again inside this loop.
    batch = exp.new_trial(generator_run=model.gen(1))
    
print("Done!")

Running optimization batch 1/5...
Running optimization batch 2/5...
Running optimization batch 3/5...
Running optimization batch 4/5...
Running optimization batch 5/5...
Done!


In [8]:
# Print the built-in help for `model`, the object returned by the last
# `get_botorch` call in the optimization loop.
help(model)

Help on TorchModelBridge in module ax.modelbridge.torch object:

class TorchModelBridge(ax.modelbridge.array.ArrayModelBridge)
 |  TorchModelBridge(experiment: ax.core.experiment.Experiment, search_space: ax.core.search_space.SearchSpace, data: ax.core.data.Data, model: ax.models.torch_base.TorchModel, transforms: List[Type[ax.modelbridge.transforms.base.Transform]], transform_configs: Union[Dict[str, Dict[str, Union[int, float, str, botorch.acquisition.acquisition.AcquisitionFunction, Dict[str, Any], NoneType]]], NoneType] = None, torch_dtype: Union[torch.dtype, NoneType] = None, torch_device: Union[torch.device, NoneType] = None, status_quo_name: Union[str, NoneType] = None, status_quo_features: Union[ax.core.observation.ObservationFeatures, NoneType] = None, optimization_config: Union[ax.core.optimization_config.OptimizationConfig, NoneType] = None, fit_out_of_design: bool = False, default_model_gen_options: Union[Dict[str, Union[int, float, str, botorch.acquisition.acquisition.Acqu