In [1]:
import os
from contextlib import contextmanager

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
import plotly.io as pio

# Ax uses Plotly to produce interactive plots. These are great for viewing and analysis,
# but they also lead to large file sizes, which is not ideal for files living in GH.
# Changing the default renderer to `png` strips the interactive components to get around
# this. This is a process-wide setting: it applies to every Plotly figure rendered after
# this cell runs.
pio.renderers.default = "png"

# Any non-empty value of the SMOKE_TEST environment variable enables the quick mode;
# an unset (None) or empty variable selects the full-length run.
SMOKE_TEST = os.getenv("SMOKE_TEST")
# Evaluation budget for the optimization loop: 10 in smoke-test mode, 30 otherwise.
NUM_EVALS = 30 if not SMOKE_TEST else 10


@contextmanager
def dummy_context_manager():
    """No-op context manager.

    Does nothing on entry or exit and binds ``None`` to the ``as`` target.
    Exceptions raised inside the ``with`` body propagate unchanged.
    """
    yield None


# Pick the context-manager factory used around the optimization loop: the Ax testing
# helper when SMOKE_TEST is set, otherwise the no-op stand-in defined above.
fast_smoke_test = (
    fast_botorch_optimize_context_manager if SMOKE_TEST else dummy_context_manager
)

In [2]:
from typing import Optional

from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor

In [3]:
# NOTE(review): this unconditionally rebinds NUM_EVALS to 10, overriding the
# SMOKE_TEST-dependent value (10 or 30) computed in the first cell, so the
# non-smoke-test budget of 30 is never used. Presumably intentional to keep the
# tutorial fast -- TODO confirm, and drop one of the two definitions.
NUM_EVALS = 10

In [4]:
class SimpleCustomGP(ExactGP, GPyTorchModel):
    """Minimal single-output exact GP that plugs into the BoTorch model API.

    Uses a constant mean, a scaled ARD RBF kernel (one lengthscale per input
    dimension) and a ``GaussianLikelihood``.
    """

    # Number of outputs, declared for the GPyTorchModel API.
    _num_outputs = 1

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        """Build the GP from training data.

        Args:
            train_X: Training inputs; the size of the last dimension sets the
                number of ARD lengthscales.
            train_Y: Training targets with a trailing output dimension, which is
                squeezed away before being handed to ``ExactGP``.
            train_Yvar: Accepted for interface compatibility but ignored; the
                noise level is inferred by the likelihood instead.
        """
        # ExactGP expects targets without the trailing output dimension.
        targets = train_Y.squeeze(-1)
        likelihood = GaussianLikelihood()
        super().__init__(train_X, targets, likelihood)

        input_dim = train_X.shape[-1]
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(base_kernel=RBFKernel(ard_num_dims=input_dim))

        # Move the model to the same device/dtype as the training inputs.
        self.to(train_X)

    def forward(self, x):
        """Return the GP prior at ``x`` as a ``MultivariateNormal``."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

In [5]:
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate

In [6]:
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models


# Step 1: quasi-random (Sobol) initialization, limited to the first 5 trials.
sobol_init_step = GenerationStep(model=Models.SOBOL, num_trials=5)

# Step 2: Bayesian optimization with the custom surrogate model.
# `num_trials=-1` places no limit on how many trials this step may produce.
# For `BOTORCH_MODULAR`, `model_kwargs` selects the surrogate and/or acquisition
# function; here only the surrogate is customized.
custom_gp_bo_step = GenerationStep(
    model=Models.BOTORCH_MODULAR,
    num_trials=-1,
    model_kwargs={"surrogate": Surrogate(SimpleCustomGP)},
)

gs = GenerationStrategy(steps=[sobol_init_step, custom_gp_bo_step])

In [7]:
import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin


# AxClient offers a convenient API to control the experiment; it is driven by the
# generation strategy defined earlier.
ax_client = AxClient(generation_strategy=gs)

# Set up the experiment. Both inputs are continuous ranges: it is crucial to write the
# bounds as floats (0.0 rather than 0), otherwise the parameter would be inferred as an
# integer range.
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=[
        {"name": param_name, "type": "range", "bounds": param_bounds}
        for param_name, param_bounds in (
            ("x1", [-5.0, 10.0]),
            ("x2", [0.0, 15.0]),
        )
    ],
    objectives={"branin": ObjectiveProperties(minimize=True)},
)

# Test function used by `evaluate` to score each trial.
branin = Branin()


def evaluate(parameters):
    """Evaluate the Branin function for one trial.

    Args:
        parameters: Mapping that holds the trial's values under "x1" and "x2".

    Returns:
        A dict mapping "branin" to a ``(mean, sem)`` tuple. The sem is NaN to
        signal that the observation noise is unknown; the model's
        GaussianLikelihood infers a noise level instead.
    """
    coords = [parameters.get(key) for key in ("x1", "x2")]
    point = torch.tensor([coords])  # shape (1, 2): a single 2-d input point
    value = branin(point).item()
    return {"branin": (value, float("nan"))}

[INFO 12-12 20:13:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 12-12 20:13:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 20:13:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 12-12 20:13:57] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).


In [8]:
# Run NUM_EVALS ask/tell rounds inside the context manager selected in the first cell.
with fast_smoke_test():
    for i in range(NUM_EVALS):
        # Ask: obtain the next parameterization to try and its trial index.
        parameters, trial_index = ax_client.get_next_trial()
        # Tell: report the result. The local evaluation here can be replaced with
        # deployment to an external system.
        ax_client.complete_trial(
            trial_index=trial_index,
            raw_data=evaluate(parameters),
        )

[INFO 12-12 20:14:04] ax.service.ax_client: Generated new trial 0 with parameters {'x1': -2.415736, 'x2': 5.918705}.
[INFO 12-12 20:14:04] ax.service.ax_client: Completed trial 0 with data: {'branin': (24.720219, nan)}.
[INFO 12-12 20:14:04] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 5.900465, 'x2': 11.740656}.
[INFO 12-12 20:14:04] ax.service.ax_client: Completed trial 1 with data: {'branin': (131.987564, nan)}.
[INFO 12-12 20:14:04] ax.service.ax_client: Generated new trial 2 with parameters {'x1': -0.048725, 'x2': 0.426945}.
[INFO 12-12 20:14:04] ax.service.ax_client: Completed trial 2 with data: {'branin': (51.523506, nan)}.
[INFO 12-12 20:14:04] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 9.444746, 'x2': 12.474402}.
[INFO 12-12 20:14:04] ax.service.ax_client: Completed trial 3 with data: {'branin': (100.050247, nan)}.
[INFO 12-12 20:14:04] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 5.22983, 'x2': 9.105386}.
[IN

In [10]:
# Show the recorded trials as a table; the bare expression is rendered as the cell output.
ax_client.get_trials_data_frame()



Unnamed: 0,trial_index,arm_name,trial_status,generation_method,branin,x1,x2
0,0,0_0,COMPLETED,Sobol,24.720219,-2.415736,5.918705
1,1,1_0,COMPLETED,Sobol,131.987564,5.900465,11.740656
2,2,2_0,COMPLETED,Sobol,51.523506,-0.048725,0.426945
3,3,3_0,COMPLETED,Sobol,100.050247,9.444746,12.474402
4,4,4_0,COMPLETED,Sobol,77.089905,5.22983,9.105386
5,5,5_0,COMPLETED,BoTorch,1.88196,9.675115,3.782731
6,6,6_0,COMPLETED,BoTorch,25.033379,-1.185895,4.688542
7,7,7_0,COMPLETED,BoTorch,59.078041,-2.323631,2.936244
8,8,8_0,COMPLETED,BoTorch,36.611038,6.004406,5.270057
9,9,9_0,COMPLETED,BoTorch,6.765411,10.0,5.198924


In [11]:
from botorch.models.fully_bayesian import PyroModel