# Hook BoTorch gradient BO with bandgap.

In [None]:
import os

from ax.utils.testing.mock import fast_botorch_optimize_context_manager
import plotly.io as pio

# Number of BO trials requested by each NUM_EVALS-driven optimization loop below.
NUM_EVALS = 20

# Plotly (used by Ax for plots) produces interactive figures, which bloat saved
# notebooks and are awkward to keep on GitHub. Rendering them as static PNGs
# drops the interactive parts and keeps the files small.
pio.renderers.default = "png"

## GP Model

In [None]:
from typing import Optional

from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from torch import Tensor


class SimpleCustomGP(ExactGP, GPyTorchModel):
    """Single-output exact GP: constant mean plus a scaled ARD RBF kernel.

    Observation noise is learned through a ``GaussianLikelihood``; a
    ``train_Yvar`` argument, if supplied, is accepted but ignored.
    """

    _num_outputs = 1  # tells the GPyTorchModel API this model has one output

    def __init__(self, train_X, train_Y, train_Yvar: Optional[Tensor] = None):
        # ExactGP wants 1-D targets, so drop the trailing output dimension.
        targets = train_Y.squeeze(-1)
        likelihood = GaussianLikelihood()
        super().__init__(train_X, targets, likelihood)

        n_dims = train_X.shape[-1]  # one lengthscale per input dimension (ARD)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(base_kernel=RBFKernel(ard_num_dims=n_dims))

        # Put all parameters on the same device/dtype as the training inputs.
        self.to(train_X)

    def forward(self, x):
        """Return the multivariate normal given by the mean/covariance modules at ``x``."""
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))

In [None]:
from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.surrogate import Surrogate


# Modular BoTorch model wrapping SimpleCustomGP as the surrogate. Surrogate also
# accepts optional `mll_class` / `model_options`, and BoTorchModel an optional
# `botorch_acqf_class`; the defaults are used here.
# NOTE(review): `ax_model` is not referenced by the generation strategy, which
# constructs its own `Surrogate(SimpleCustomGP)` — this object appears unused.
ax_model = BoTorchModel(surrogate=Surrogate(botorch_model_class=SimpleCustomGP))

## Create the Ax experiment

In [None]:
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models


# Generation strategy: 5 quasi-random Sobol trials to seed the model, then
# Bayesian optimization with the SimpleCustomGP surrogate for every later trial.
gs = GenerationStrategy(
    steps=[
        # Quasi-random (Sobol) initialization step
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,  # How many trials should be produced from this generation step
        ),
        # Bayesian optimization step using the custom GP surrogate (SimpleCustomGP).
        # No acquisition function is passed, so BOTORCH_MODULAR's default is used.
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,  # No limitation on how many trials should be produced from this step
            # For `BOTORCH_MODULAR`, we pass in kwargs to specify what surrogate or acquisition function to use.
            model_kwargs={
                "surrogate": Surrogate(SimpleCustomGP),
            },
        ),
    ]
)

In [None]:
import torch
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from botorch.test_functions import Branin


# AxClient drives the ask/tell optimization loop using the strategy `gs`.
ax_client = AxClient(generation_strategy=gs)

# Two continuous parameters x1 and x2, both ranging over [-0.05, 0.05].
# The bounds must be floats (0.0 rather than 0); integer bounds would make Ax
# infer an integer range parameter.
strain_parameters = [
    {"name": f"x{k}", "type": "range", "bounds": [-0.05, 0.05]}
    for k in (1, 2)
]

# NOTE(review): the experiment name appears to be a leftover from the Branin
# tutorial this notebook was adapted from.
ax_client.create_experiment(
    name="branin_test_experiment",
    parameters=strain_parameters,
    objectives={"indirect_band_gap": ObjectiveProperties(minimize=True)},
)

### Define the function to optimize.

In [None]:
import numpy as np
from julia import Julia
jl = Julia(runtime="/home/cedric/.juliaup/bin/julia", compiled_modules=False)

# Load DFTK
from julia import Pkg, Main
DFTK_path = "/home/cedric/PostPhD/Dev/JuliaMolSim/DFTK.jl/"
Pkg.activate(DFTK_path)
Pkg.resolve()
Pkg.instantiate()

import os
Main.include(os.path.join(DFTK_path, "examples/strain/silicon_strain_engineering.jl"))

%load_ext julia.magic

In [None]:
def evaluate(parameters):
    # Only vary stress along x and y.
    x = np.array(np.squeeze([[parameters.get(f"x{i+1}") for i in range(2)] + [0, 0, 0, 0]]))
    # The GaussianLikelihood used by our model infers an observation noise level,
    # so we pass an sem value of NaN to indicate that observation noise is unknown
    indirect_band_gap = %julia austrip(strain_indirect_band_gap($x))
    indirect_band_gap = np.abs(indirect_band_gap - 0.05) # optimize to target value.
    return {"indirect_band_gap": (indirect_band_gap, float("nan"))}

## Evaluate the indirect band gap at zero strain

In [None]:
# Baseline: objective at x1 = x2 = 0 (unstrained); displayed as the cell output.
baseline = evaluate({"x1": 0.0, "x2": 0.0})
baseline

### Running the BO loop

In [None]:
# Ask/tell loop: request NUM_EVALS candidates from Ax, evaluate each one locally
# and report the result back. The local evaluation could be replaced by a
# submission to an external system.
for _ in range(NUM_EVALS):
    parameters, trial_index = ax_client.get_next_trial()
    raw_data = evaluate(parameters)
    ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

### Viewing the evaluated trials

In [None]:
# Tabulate every trial run so far (displayed as the cell output).
trials_df = ax_client.get_trials_data_frame()
trials_df

In [None]:
# Best point so far, returned as (parameters, (means, covariances)).
parameters, values = ax_client.get_best_parameters()
means, covariances = values[0], values[1]
print(f"Best parameters: {parameters}")
print(f"Corresponding mean: {means}, covariance: {covariances}")

### Plotting the response surface and optimization progress

In [None]:
from ax.utils.notebook.plotting import render

# Model-predicted response surface over the two parameters.
contour_plot = ax_client.get_contour_plot()
render(contour_plot)

In [None]:
# Show the best parameters together with their predicted means.
best = ax_client.get_best_parameters()
best_parameters, values = best
best_parameters, values[0]

In [None]:
# Plot the best objective value found versus trial number.
# (The original line contained this call twice, pasted back-to-back, which is a
# SyntaxError.)
render(ax_client.get_optimization_trace(objective_optimum=0.02))

## Run some more

In [None]:
# Continue the ask/tell loop for another NUM_EVALS trials; the local evaluation
# could be replaced by a submission to an external system.
for _ in range(NUM_EVALS):
    parameters, trial_index = ax_client.get_next_trial()
    raw_data = evaluate(parameters)
    ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)

In [None]:
# Optimization progress after the additional trials.
trace_plot = ax_client.get_optimization_trace(objective_optimum=0.02)
render(trace_plot)

In [None]:
# Updated response surface after the additional trials.
contour_plot = ax_client.get_contour_plot()
render(contour_plot)

In [None]:
# And some more: ten extra ask/tell iterations.
# (Fixes a corrupted identifier: a pasted screenshot filename had been inserted
# into `ax_client`, making this cell a SyntaxError.)
for _ in range(10):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))

In [None]:
# Final optimization progress; the objective |gap - target| cannot go below 0.0.
trace_plot = ax_client.get_optimization_trace(objective_optimum=0.0)
render(trace_plot)