In [73]:
import os

os.environ["AWS_ACCESS_KEY_ID"] = "foo"
os.environ["AWS_SECRET_ACCESS_KEY"] = "bar"
os.environ["AWS_DEFAULT_REGION"] = "ap-southeast-2"

In [74]:
from pathlib import Path

import torchx

from torchx import specs
from torchx.components import utils

src_log_dir = f"{os.getcwd()}/output"
dst_log_dir = "/output"

def trainer(
    log_path: str,
    hidden_size_1: int,
    hidden_size_2: int,
    learning_rate: float,
    # epochs: int,
    dropout: float,
    batch_size: int,
    trial_idx: int = -1,
) -> specs.AppDef:

    # define the log path so we can pass it to the TorchX ``AppDef``
    if trial_idx >= 0:
        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()

    return utils.sh(
        "python",
        "mnist.py",
        "--log_path",
        log_path,
        "--hidden_size_1",
        str(hidden_size_1),
        "--hidden_size_2",
        str(hidden_size_2),
        "--learning_rate",
        str(learning_rate),
        # "--epochs",
        # str(epochs),
        "--dropout",
        str(dropout),
        "--batch_size",
        str(batch_size),
        # other config options
        # name="trainer",
        # script="mnist.py",
        image="ghcr.io/jbris/torchx-aws-test:1.0.0",
        mounts=[
            "type=bind",
            f"src={src_log_dir}",
            f"dst={dst_log_dir}",
            "perm=rwm"
        ]
    )

import tempfile
from ax.runners.torchx import TorchXRunner

# Make a temporary dir to log our results into

scheduler = "aws_batch"
scheduler="local_cwd"
scheduler="local_docker"

ax_runner = TorchXRunner(
    tracker_base="/tmp/",
    component=trainer,
    scheduler=scheduler,
    component_const_params={"log_path": dst_log_dir},
    cfg={"queue": "torchx_queue"},
)

In [75]:
from ax.core import (
    ChoiceParameter,
    ParameterType,
    RangeParameter,
    SearchSpace,
)

parameters = [
    # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2
    # should probably be powers of 2, but in our simple example this
    # would mean that ``num_params`` can't take on that many values, which
    # in turn makes the Pareto frontier look pretty weird.
    RangeParameter(
        name="hidden_size_1",
        lower=4,
        upper=8,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="hidden_size_2",
        lower=4,
        upper=8,
        parameter_type=ParameterType.INT,
        log_scale=True,
    ),
    RangeParameter(
        name="learning_rate",
        lower=1e-2,
        upper=1e-1,
        parameter_type=ParameterType.FLOAT,
        log_scale=True,
    ),
    # RangeParameter(
    #     name="epochs",
    #     lower=1,
    #     upper=2,
    #     parameter_type=ParameterType.INT,
    # ),
    RangeParameter(
        name="dropout",
        lower=0.0,
        upper=0.5,
        parameter_type=ParameterType.FLOAT,
    ),
    ChoiceParameter(  # NOTE: ``ChoiceParameters`` don't require log-scale
        name="batch_size",
        values=[16, 32],
        parameter_type=ParameterType.INT,
        is_ordered=True,
        sort_values=True,
    ),
]

search_space = SearchSpace(
    parameters=parameters,
    # NOTE: In practice, it may make sense to add a constraint
    # hidden_size_2 <= hidden_size_1
    parameter_constraints=[],
)

from ax.metrics.tensorboard import TensorboardMetric
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer

class MyTensorboardMetric(TensorboardMetric):

    # NOTE: We need to tell the new TensorBoard metric how to get the id /
    # file handle for the TensorBoard logs from a trial. In this case
    # our convention is to just save a separate file per trial in
    # the prespecified log dir.
    def _get_event_multiplexer_for_trial(self, trial):
        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
        mul.AddRunsFromDirectory(Path(src_log_dir).joinpath(str(trial.index)).as_posix(), None)
        mul.Reload()
    
        return mul

    # This indicates whether the metric is queryable while the trial is
    # still running. We don't use this in the current tutorial, but Ax
    # utilizes this to implement trial-level early-stopping functionality.
    @classmethod
    def is_available_while_running(cls):
        return True

val_acc = MyTensorboardMetric(
    name="val_acc",
    tag="val_acc",
    lower_is_better=False,
)
model_num_params = MyTensorboardMetric(
    name="num_params",
    tag="num_params",
    lower_is_better=True,
)


In [76]:
from ax.core import MultiObjective, Objective, ObjectiveThreshold
from ax.core.optimization_config import MultiObjectiveOptimizationConfig


opt_config = MultiObjectiveOptimizationConfig(
    objective=MultiObjective(
        objectives=[
            Objective(metric=val_acc, minimize=False),
            Objective(metric=model_num_params, minimize=True),
        ],
    ),
    objective_thresholds=[
        ObjectiveThreshold(metric=val_acc, bound=0.94, relative=False),
        ObjectiveThreshold(metric=model_num_params, bound=80_000, relative=False),
    ],
)

from ax.core import Experiment

experiment = Experiment(
    name="torchx_mnist",
    search_space=search_space,
    optimization_config=opt_config,
    runner=ax_runner,
)

total_trials = 48  # total evaluation budget

from ax.modelbridge.dispatch_utils import choose_generation_strategy

gs = choose_generation_strategy(
    search_space=experiment.search_space,
    optimization_config=experiment.optimization_config,
    num_trials=total_trials,
  )

from ax.service.scheduler import Scheduler, SchedulerOptions

scheduler = Scheduler(
    experiment=experiment,
    generation_strategy=gs,
    options=SchedulerOptions(
        total_trials=total_trials, max_pending_trials=4,
        init_seconds_between_polls=5, seconds_between_polls_backoff_factor=1
    ),
)

scheduler.run_n_trials(1)

[INFO 01-07 10:31:24] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 01-07 10:31:24] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=5 num_trials=48 use_batch_trials=False
[INFO 01-07 10:31:24] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=9
[INFO 01-07 10:31:24] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=9
[INFO 01-07 10:31:24] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 01-07 10:31:24] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch'

OptimizationResult()

In [77]:

from ax.service.utils.report_utils import exp_to_df

df = exp_to_df(experiment)
df.head(10)



  df = pd.concat(


Unnamed: 0,trial_index,arm_name,trial_status,generation_method,is_feasible,num_params,val_acc,hidden_size_1,hidden_size_2,learning_rate,dropout,batch_size
0,0,0_0,COMPLETED,Sobol,False,3192.0,0.09776,4,4,0.089042,0.452861,32


In [78]:
from ax.service.utils.report_utils import _pareto_frontier_scatter_2d_plotly

_pareto_frontier_scatter_2d_plotly(experiment)

  df = pd.concat(


ValueError: `upper` should be greater than `lower`, got: 0.09776025265455246 (<= 0.09776025265455246).

In [79]:
from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.utils.notebook.plotting import init_notebook_plotting, render

cv = cross_validate(model=gs.model)  # The surrogate model is stored on the ``GenerationStrategy``
compute_diagnostics(cv)

interact_cross_validation_plotly(cv)

ValueError: RandomModelBridge has no training data.  Either it has been incorrectly initialized or should not be cross validated.

In [None]:
from ax.plot.contour import interact_contour_plotly

interact_contour_plotly(model=gs.model, metric_name="val_acc")

In [None]:
interact_contour_plotly(model=gs.model, metric_name="num_params")