In [1]:
%%time
import math
import torch

from botorch.test_functions import SixHumpCamel
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.utils.transforms import standardize, normalize
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import (
    ExpectedImprovement,
    ProbabilityOfImprovement,
    qMaxValueEntropy,
)
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
import matplotlib.pyplot as plt
from botorch.optim import optimize_acqf
from botorch.acquisition import qNoisyExpectedImprovement, qProbabilityOfImprovement
from time import time
import cProfile, io, pstats
from pstats import SortKey
import scipy

SMOKE_TEST = True  # NOTE(review): not referenced in the cells shown here

# Experiment configuration.
SEED = 123456
N_INIT = 5  # number of initial random training points
NOISE_STD = 0.05  # std of Gaussian noise added to the objective values

torch.manual_seed(SEED)
# Double precision throughout (BoTorch's recommended default for GP fitting).
torch.set_default_dtype(torch.double)

# `_bounds` holds one (lower, upper) pair per dimension; transposing gives the
# 2 x d layout used below (row 0 = lower bounds, row 1 = upper bounds).
bounds = torch.tensor(SixHumpCamel._bounds).T
dim = bounds.shape[-1]
# Unit hypercube in the same 2 x d layout; the acquisition is optimized here.
bounds_norm = torch.stack([torch.zeros(dim), torch.ones(dim)])

# Raw design: uniform samples inside the original bounds, then evaluate the
# (negated) Six-Hump Camel function on them.
train_X_raw = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(N_INIT, dim)
objective = SixHumpCamel(negate=True)
train_Y_raw = objective(train_X_raw).unsqueeze(-1)

# Model inputs on the unit cube; outcomes corrupted with noise, then standardized.
train_X = normalize(train_X_raw, bounds=bounds)
train_Y = standardize(train_Y_raw + NOISE_STD * torch.randn_like(train_Y_raw))

model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll, max_attempts=1)  # single fitting attempt, no retries
scipy.__version__

CPU times: user 2.01 s, sys: 2.29 s, total: 4.3 s
Wall time: 3.88 s


'1.7.3'

In [2]:
%%time

# qNEI using the (normalized) training inputs as the baseline points.
acqf = qNoisyExpectedImprovement(model, train_X)

# Profile a single joint (non-sequential) optimization of q=5 candidates over
# the unit cube.
profiler = cProfile.Profile()
with profiler:
    optimize_acqf(
        acq_function=acqf,
        bounds=bounds_norm,
        q=5,
        num_restarts=5,
        raw_samples=5,
        sequential=False,
    )

# Print the top 10% of profile entries, ordered by cumulative time.
pstats.Stats(profiler).sort_stats(SortKey.CUMULATIVE).print_stats(0.1)

scipy.__version__

         196165 function calls (180446 primitive calls) in 0.236 seconds

   Ordered by: cumulative time
   List reduced from 457 to 46 due to restriction <0.1>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.236    0.236 /opt/anaconda3/lib/python3.9/site-packages/botorch/optim/optimize.py:415(optimize_acqf)
        1    0.000    0.000    0.236    0.236 /opt/anaconda3/lib/python3.9/site-packages/botorch/optim/optimize.py:547(_optimize_acqf)
        1    0.000    0.000    0.236    0.236 /opt/anaconda3/lib/python3.9/site-packages/botorch/optim/optimize.py:258(_optimize_acqf_batch)
        1    0.000    0.000    0.233    0.233 /opt/anaconda3/lib/python3.9/site-packages/botorch/optim/optimize.py:294(_optimize_batch_candidates)
        1    0.000    0.000    0.233    0.233 /opt/anaconda3/lib/python3.9/site-packages/botorch/generation/gen.py:43(gen_candidates_scipy)
        1    0.000    0.000    0.231    0.231 /opt/anaconda3/lib/pyth

'1.7.3'

In [5]:
%%time
# Micro-benchmark: repeated forward + backward passes of the acquisition
# function on random candidate batches.
N_REPEATS = 100  # number of forward/backward passes to profile
N_BATCH = 20  # number of candidate sets evaluated per call
Q = 5  # candidates per set; kept equal to q in the optimize_acqf cell
dim = train_X.shape[-1]  # input dimensionality, taken from the training data

with cProfile.Profile() as pr:
    for _ in range(N_REPEATS):
        # Fresh leaf tensor each pass so gradients never accumulate across passes.
        test_X = torch.rand(N_BATCH, Q, dim, requires_grad=True)
        acqf(test_X).sum().backward()

ps = pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE)
ps.print_stats(0.1)  # only the top 10% of entries

scipy.__version__

         138702 function calls (126802 primitive calls) in 0.286 seconds

   Ordered by: cumulative time
   List reduced from 240 to 24 due to restriction <0.1>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  500/100    0.001    0.000    0.169    0.002 /opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1494(_call_impl)
      100    0.000    0.000    0.168    0.002 /opt/anaconda3/lib/python3.9/site-packages/botorch/utils/transforms.py:326(decorated)
      100    0.001    0.000    0.168    0.002 /opt/anaconda3/lib/python3.9/site-packages/botorch/utils/transforms.py:266(decorated)
      100    0.001    0.000    0.165    0.002 /opt/anaconda3/lib/python3.9/site-packages/botorch/acquisition/monte_carlo.py:325(forward)
      100    0.000    0.000    0.116    0.001 /opt/anaconda3/lib/python3.9/site-packages/torch/_tensor.py:428(backward)
      100    0.000    0.000    0.115    0.001 /opt/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__

'1.7.3'