In [8]:
from ax.service.ax_client import AxClient
from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models
import torch

## How to view generation strategy on a given `AxClient` optimization: 

In [2]:
ax_client = AxClient()

[INFO 05-05 07:39:43] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.


In [4]:
ax_client.create_experiment(
    parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}]
)

[INFO 05-05 07:40:22] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to  model-fitting.


In [5]:
ax_client.generation_strategy  # To view the generation strategy summary

GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])

In [6]:
ax_client.generation_strategy._steps  # To view all the settings of the generation strategy steps

[GenerationStep(model=<Models.SOBOL: 'Sobol'>, num_trials=5, min_trials_observed=3, max_parallelism=None, use_update=False, enforce_num_trials=True, model_kwargs={'deduplicate': True, 'seed': None}, model_gen_kwargs=None, index=0),
 GenerationStep(model=<Models.GPEI: 'GPEI'>, num_trials=-1, min_trials_observed=0, max_parallelism=3, use_update=False, enforce_num_trials=True, model_kwargs=None, model_gen_kwargs=None, index=1)]

## Solution 1: Create custom generation strategy with all the same settings as would've been automatically chosen, but also with `model_kwargs` passed to the GPEI step
Kwargs passed as `model_kwargs` are distributed between the underlying `Model` and `ModelBridge` according to their keyword names. To see what the available kwargs are, refer to the `model_class` and `bridge_class` properties of the `ModelSetup` corresponding to a given entry in the [`Models` registry enum](https://github.com/facebook/Ax/blob/master/ax/modelbridge/registry.py#L183). In this case, we will be passing kwargs to `Models.GPEI`, and the [corresponding `ModelSetup`](https://github.com/facebook/Ax/blob/master/ax/modelbridge/registry.py#L137-L142) includes `TorchModelBridge` as its model bridge and `BotorchModel` as its model. `TorchModelBridge` takes as kwargs `torch_dtype` and `torch_device` ([API docs](https://ax.dev/api/modelbridge.html#module-ax.modelbridge.torch), [source code](https://github.com/facebook/Ax/blob/master/ax/modelbridge/torch.py#L56-L57), so we will be passing those as `model_kwargs` to the corresponding generation step. Here is how:

In [10]:
gs = GenerationStrategy(
    steps=[
        # I omit a few settings when copying from the auto-selected generation strategy above, since those
        # settings are just the defaults, but you can include them too if you'd like. Refer to the docs on
        # `GenerationStep` for meanings of these settings: 
        # https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStep
        GenerationStep(
            model=Models.SOBOL, 
            num_trials=5, 
            min_trials_observed=3, 
            model_kwargs={'deduplicate': True, 'seed': None}, 
        ),
        GenerationStep(
            model=Models.GPEI,
            num_trials=-1,
            max_parallelism=3,  # Can set higher parallelism if needed
            model_kwargs = {"torch_dtype": torch.float, "torch_device": torch.device("cuda")}
        )
    ]
)

In [20]:
# Can also view the kwargs and their typing  for a given `Models` registry enum entry like so:
model_kwargs, model_bridge_kwargs = Models.GPEI.view_kwargs()
model_bridge_kwargs  # Showing the bridge kwargs, since those are what we are interested in in this case.

{'experiment': ax.core.experiment.Experiment,
 'search_space': ax.core.search_space.SearchSpace,
 'data': ax.core.data.Data,
 'model': ax.models.torch_base.TorchModel,
 'transforms': typing.List[typing.Type[ax.modelbridge.transforms.base.Transform]],
 'transform_configs': typing.Union[typing.Dict[str, typing.Dict[str, typing.Union[int, float, str, botorch.acquisition.acquisition.AcquisitionFunction]]], NoneType],
 'torch_dtype': typing.Union[torch.dtype, NoneType],
 'torch_device': typing.Union[torch.device, NoneType],
 'status_quo_name': typing.Union[str, NoneType],
 'status_quo_features': typing.Union[ax.core.observation.ObservationFeatures, NoneType],
 'optimization_config': typing.Union[ax.core.optimization_config.OptimizationConfig, NoneType]}

### Now we instantiate `AxClient` with our new generation strategy:

In [11]:
ax_client_with_custom_GS = AxClient(generation_strategy=gs)

[INFO 05-05 07:56:30] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.


In [12]:
ax_client_with_custom_GS.create_experiment(
    parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}]
)

In [14]:
ax_client_with_custom_GS.generation_strategy

GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials])

In [15]:
ax_client_with_custom_GS.generation_strategy._steps  # Device and dtype settings are now propagated!

[GenerationStep(model=<Models.SOBOL: 'Sobol'>, num_trials=5, min_trials_observed=3, max_parallelism=None, use_update=False, enforce_num_trials=True, model_kwargs={'deduplicate': True, 'seed': None}, model_gen_kwargs=None, index=0),
 GenerationStep(model=<Models.GPEI: 'GPEI'>, num_trials=-1, min_trials_observed=0, max_parallelism=3, use_update=False, enforce_num_trials=True, model_kwargs={'torch_dtype': torch.float32, 'torch_device': device(type='cuda')}, model_gen_kwargs=None, index=1)]

## Solution 2: Hack existing generation strategy by replacing model kwargs directly in one of its steps

In [22]:
ax_client = AxClient()

[INFO 05-05 08:09:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 2 decimal points.


In [23]:
ax_client.create_experiment(
    parameters=[{"name": "x", "type": "range", "bounds": [-5.0, 10.0]}]
)

[INFO 05-05 08:10:04] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 trials, GPEI for subsequent trials]). Iterations after 5 will take longer to generate due to  model-fitting.


In [24]:
ax_client.generation_strategy._steps[1]._replace(  # We modify the original auto-selected generation strategy.
    model_kwargs={"torch_dtype": torch.float, "torch_device": torch.device("cuda")}
)

GenerationStep(model=<Models.GPEI: 'GPEI'>, num_trials=-1, min_trials_observed=0, max_parallelism=3, use_update=False, enforce_num_trials=True, model_kwargs={'torch_dtype': torch.float32, 'torch_device': device(type='cuda')}, model_gen_kwargs=None, index=1)

In [25]:
ax_client.generation_strategy._steps  # Steps of the generation strategy are now modified.

[GenerationStep(model=<Models.SOBOL: 'Sobol'>, num_trials=5, min_trials_observed=3, max_parallelism=None, use_update=False, enforce_num_trials=True, model_kwargs={'deduplicate': True, 'seed': None}, model_gen_kwargs=None, index=0),
 GenerationStep(model=<Models.GPEI: 'GPEI'>, num_trials=-1, min_trials_observed=0, max_parallelism=3, use_update=False, enforce_num_trials=True, model_kwargs=None, model_gen_kwargs=None, index=1)]

## Done!
Whether you choose solution 1 or 2, once you've set up your `AxClient` with a custom or modified generation strategy, you can start calling `ax_client.get_next_trial` as usual.