In [1]:
%%time
# Imports: standard library first, then third-party, all in one place so the
# notebook survives Restart & Run All. Several names are not used in the cells
# visible here — NOTE(review): confirm later cells need them before pruning.
import cProfile
import io
import math
import pstats
from pstats import SortKey
from time import time

import matplotlib.pyplot as plt
import torch
from botorch.acquisition import (
    ExpectedImprovement,
    ProbabilityOfImprovement,
    qMaxValueEntropy,
    qNoisyExpectedImprovement,
    qProbabilityOfImprovement,
)
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.test_functions import SixHumpCamel
from botorch.utils.transforms import normalize, standardize
from gpytorch.mlls import ExactMarginalLogLikelihood

# --- Configuration ----------------------------------------------------------
# SMOKE_TEST shrinks problem sizes so the notebook runs quickly.
SMOKE_TEST = True
SEED = 123456
N_INITIAL_POINTS = 5  # size of the uniform random initial design
OBSERVATION_NOISE_STD = 0.05  # std of Gaussian noise added to the observations

torch.manual_seed(SEED)

# --- Problem and training data ----------------------------------------------
# Search-space bounds as a 2 x d tensor: row 0 holds lower bounds, row 1 upper.
bounds = torch.tensor(SixHumpCamel._bounds).T
dim = bounds.size(1)
# Unit-cube bounds of the normalized space that the GP and optimizer work in;
# built from `bounds` so dimension, dtype and device always agree with it.
bounds_norm = torch.stack(
    [
        torch.zeros(dim, device=bounds.device, dtype=bounds.dtype),
        torch.ones(dim, device=bounds.device, dtype=bounds.dtype),
    ]
)

# Uniform random design inside the original bounds. negate=True flips the sign
# of the benchmark so the model sees an objective to be maximized.
train_X_raw = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(N_INITIAL_POINTS, dim)
train_Y_raw = SixHumpCamel(negate=True)(train_X_raw).unsqueeze(-1)

# Model inputs are mapped into [0, 1]^d; targets get observation noise and are
# then standardized. The raw tensors are kept under separate names.
train_X = normalize(train_X_raw, bounds=bounds)
train_Y = standardize(
    train_Y_raw + OBSERVATION_NOISE_STD * torch.randn_like(train_Y_raw)
)

# --- GP surrogate -----------------------------------------------------------
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
# max_attempts is forwarded to BoTorch's fitting routine (maximum fit attempts).
fit_gpytorch_mll(mll, max_attempts=10)

# --- qGIBBON acquisition function ---------------------------------------------
# Random candidate points in the normalized unit cube.
candidate_set_size = 1000 if not SMOKE_TEST else 5
candidate_set = torch.rand(
    candidate_set_size, dim, device=bounds.device, dtype=bounds.dtype
)
qGIBBON = qLowerBoundMaxValueEntropy(model, candidate_set)

CPU times: user 2.94 s, sys: 1.32 s, total: 4.26 s
Wall time: 3.34 s




In [2]:
%%time

# Profile qGIBBON (qLowerBoundMaxValueEntropy): one batched evaluation on a
# grid plus a single-point (q=1) optimization with optimize_acqf.

# Random candidate points in the normalized unit cube.
# NOTE(review): unlike the grid size below, this is not reduced under
# SMOKE_TEST — confirm that profiling at 10000 candidates is intended.
candidate_set = torch.rand(
    10000, bounds.size(1), device=bounds.device, dtype=bounds.dtype
)
acq = qLowerBoundMaxValueEntropy(model, candidate_set)

# n x n grid over the unit square, reshaped to n*n batches of one 2-D point
# (shape n*n x 1 x 2) so the acquisition function scores it in a single call.
# indexing="ij" is torch's default; passing it explicitly silences the
# meshgrid deprecation warning.
n = 100 if not SMOKE_TEST else 2
xv, yv = torch.meshgrid(
    [torch.linspace(0, 1, n), torch.linspace(0, 1, n)], indexing="ij"
)
test_x = torch.stack([xv.reshape(n * n, 1), yv.reshape(n * n, 1)], -1)

# Evaluate on the grid and maximize, under the profiler. The grid values are
# discarded; only the cost of computing them is measured.
with cProfile.Profile() as pr:
    acq(test_x).detach().reshape(n, n)
    optimize_acqf(
        acq_function=acq, bounds=bounds_norm, q=1, num_restarts=5, raw_samples=100
    )

# Report entries whose description matches "scipy", by cumulative time.
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats("scipy")

# Final bare None keeps the Stats object's repr out of the cell output.
None

         69596 function calls (63730 primitive calls) in 0.059 seconds

   Ordered by: cumulative time
   List reduced from 470 to 40 due to restriction <'scipy'>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.053    0.053 /opt/anaconda3/lib/python3.9/site-packages/botorch/generation/gen.py:43(gen_candidates_scipy)
        1    0.000    0.000    0.052    0.052 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/_minimize.py:45(minimize)
        1    0.000    0.000    0.052    0.052 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/lbfgsb.py:210(_minimize_lbfgsb)
       42    0.001    0.000    0.051    0.001 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/optimize.py:65(_compute_if_needed)
       22    0.000    0.000    0.051    0.002 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:231(_update_fun)
       21    0.000    0.000    0.051    0.002 /opt/anaconda3/lib/pyth

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
%%time

# Profile qGIBBON (qLowerBoundMaxValueEntropy) when generating a batch of q=3
# candidates with sequential=True, i.e. the candidates are optimized one after
# another rather than jointly.

# Random candidate points in the normalized unit cube.
# NOTE(review): not reduced under SMOKE_TEST — confirm that profiling at
# 10000 candidates is intended.
candidate_set = torch.rand(
    10000, bounds.size(1), device=bounds.device, dtype=bounds.dtype
)
acq = qLowerBoundMaxValueEntropy(model, candidate_set)

# n x n grid over the unit square, reshaped to n*n batches of one 2-D point
# (shape n*n x 1 x 2). indexing="ij" is torch's default; passing it explicitly
# silences the meshgrid deprecation warning.
n = 100 if not SMOKE_TEST else 2
xv, yv = torch.meshgrid(
    [torch.linspace(0, 1, n), torch.linspace(0, 1, n)], indexing="ij"
)
test_x = torch.stack([xv.reshape(n * n, 1), yv.reshape(n * n, 1)], -1)

# Evaluate on the grid and run the sequential q=3 optimization under the
# profiler. The grid values are discarded; only their cost is measured.
with cProfile.Profile() as pr:
    acq(test_x).detach().reshape(n, n)
    optimize_acqf(
        acq_function=acq,
        bounds=bounds_norm,
        q=3,
        num_restarts=5,
        raw_samples=100,
        sequential=True,
    )

# Report entries whose description matches "scipy", by cumulative time.
ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats("scipy")

# Final bare None keeps the Stats object's repr out of the cell output.
None

         827035 function calls (762007 primitive calls) in 0.826 seconds

   Ordered by: cumulative time
   List reduced from 638 to 73 due to restriction <'scipy'>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        3    0.000    0.000    0.631    0.210 /opt/anaconda3/lib/python3.9/site-packages/botorch/generation/gen.py:43(gen_candidates_scipy)
        3    0.000    0.000    0.621    0.207 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/_minimize.py:45(minimize)
        3    0.002    0.001    0.621    0.207 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/lbfgsb.py:210(_minimize_lbfgsb)
      214    0.010    0.000    0.616    0.003 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/optimize.py:65(_compute_if_needed)
      111    0.000    0.000    0.615    0.006 /opt/anaconda3/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:231(_update_fun)
      107    0.000    0.000    0.615    0.006 /opt/anaconda3/lib/py