In [2]:
%matplotlib inline
import torch

from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

## Set up the model

In [3]:
# Seed the global torch RNG so the training data below, and every later draw
# from the global torch RNG in this notebook, are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

d = 5  # input dimension

# box bounds [-1, 1]^d: row 0 holds the lower bounds, row 1 the upper bounds
bounds = torch.stack([-torch.ones(d), torch.ones(d)])

# 50 training inputs drawn uniformly from the box
train_X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(50, d)
# objective: 1 - ||x||_2 (one value per row), largest (= 1) at the origin
train_Y = 1 - torch.norm(train_X, dim=-1, keepdim=True)

# Build a single-output exact GP on the training data and fit its
# hyperparameters by optimizing the exact marginal log likelihood.
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
fit_gpytorch_model(mll);  # trailing ';' hides the returned object's repr

In [4]:
from botorch.sampling import IIDNormalSampler
from botorch.acquisition import qExpectedImprovement

# Monte-Carlo sampler with 100 IID normal base samples; resample=True re-draws
# them on every evaluation, so the acquisition value is stochastic.
sampler = IIDNormalSampler(num_samples=100, resample=True)

# q-batch Expected Improvement over the best observed training value
qEI = qExpectedImprovement(
    model=model,
    best_f=train_Y.max(),
    sampler=sampler,
)

## Optimizing the acquisition function

In [5]:
# N: number of q-batches (starting points) optimized in parallel
# q: number of candidate points in each q-batch
N, q = 5, 2

In [6]:
from botorch.optim.initializers import initialize_q_batch_nonneg

# generate a large number (100 * N) of random q-batches, uniform in the box;
# shape (100 * N, q, d)
Xraw = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(100 * N, q, d)

# Evaluate the acquisition function on these q-batches. The values are only
# used to pick starting points, so no gradients are needed: disable autograd
# tracking to avoid recording a graph for all 100 * N evaluations.
with torch.no_grad():
    Yraw = qEI(Xraw)

# apply the heuristic for sampling N promising initial conditions
X = initialize_q_batch_nonneg(Xraw, Yraw, N)

# we'll want gradients w.r.t. the candidate inputs during optimization
X.requires_grad_(True);

In [7]:
num_iters = 75  # number of Adam steps
log_every = 15  # print progress every `log_every` iterations

# set up the optimizer, making sure to pass in only the candidate set X
optimizer = torch.optim.Adam([X], lr=0.01)
X_traj = []  # snapshot of X after every step (the optimization trajectory)

# run a basic optimization loop
for i in range(num_iters):
    optimizer.zero_grad()
    # batch evaluation: one qEI value per q-batch, so this is an N-dim tensor;
    # negate because torch.optim minimizes
    losses = -qEI(X)
    # sum so a single backward pass yields gradients for all N q-batches
    loss = losses.sum()

    loss.backward()   # perform backward pass (populates X.grad)
    optimizer.step()  # take an Adam step on X

    # Project back onto the feasible box. Modify X in place under no_grad
    # (rather than through X.data) so autograd's in-place checks stay intact;
    # bounds[0] / bounds[1] broadcast over the trailing (d) dimension.
    with torch.no_grad():
        X.copy_(torch.max(torch.min(X, bounds[1]), bounds[0]))

    # store the optimization trajectory
    X_traj.append(X.detach().clone())

    if (i + 1) % log_every == 0:
        print(f"Iteration {i+1:>3}/{num_iters} - Loss: {loss.item():>4.3f}")
    

Iteration  15/75 - Loss: -0.326
Iteration  30/75 - Loss: -0.529
Iteration  45/75 - Loss: -0.568
Iteration  60/75 - Loss: -0.564
Iteration  75/75 - Loss: -0.541


In [9]:
import numpy as np

# Scratch check: show a fresh (50, d) draw from U[0, 1) as a NumPy array.
# NOTE: this is a new sample, not train_X, and it advances the global torch RNG.
torch.rand(50, d).numpy()

array([[0.47058266, 0.99794865, 0.49987686, 0.16786104, 0.13061285],
       [0.06820774, 0.5556341 , 0.10346764, 0.13704604, 0.81089324],
       [0.28773904, 0.948784  , 0.501875  , 0.7463471 , 0.60500836],
       [0.29191488, 0.6558614 , 0.8919577 , 0.8561606 , 0.30511737],
       [0.41843045, 0.800798  , 0.71512145, 0.5041714 , 0.6887779 ],
       [0.20482177, 0.91034764, 0.43708825, 0.8872794 , 0.7380854 ],
       [0.70153916, 0.47579175, 0.4672292 , 0.51348484, 0.09339857],
       [0.7364302 , 0.85111195, 0.0241515 , 0.38854498, 0.39775878],
       [0.30882168, 0.6236709 , 0.6571233 , 0.46451408, 0.562189  ],
       [0.5001771 , 0.8483812 , 0.9488413 , 0.87194806, 0.6885145 ],
       [0.68727106, 0.9938194 , 0.51558894, 0.53999126, 0.35124528],
       [0.11529607, 0.84712875, 0.7916093 , 0.3952828 , 0.5187924 ],
       [0.79604214, 0.8535609 , 0.6326108 , 0.06820387, 0.09980428],
       [0.6581847 , 0.11283666, 0.77676135, 0.57432526, 0.76091045],
       [0.09858644, 0.15066391, 0.

In [10]:
# Scratch check: a fresh U[0, 1) draw with the raw q-batch shape (100 * N, q, d);
# this is not Xraw itself and it advances the global torch RNG.
torch.rand((100 * N, q, d))

tensor([[[0.5755, 0.5888, 0.8243, 0.8729, 0.0501],
         [0.1290, 0.8853, 0.9669, 0.5929, 0.4498]],

        [[0.3573, 0.6302, 0.7881, 0.8128, 0.4783],
         [0.3830, 0.0973, 0.2658, 0.7673, 0.1790]],

        [[0.3010, 0.4046, 0.5424, 0.9651, 0.6443],
         [0.2152, 0.1845, 0.9541, 0.8952, 0.1014]],

        ...,

        [[0.0811, 0.9054, 0.3772, 0.2057, 0.9966],
         [0.5297, 0.1481, 0.3191, 0.5470, 0.6809]],

        [[0.6012, 0.5516, 0.5992, 0.8367, 0.7838],
         [0.0425, 0.3714, 0.9717, 0.7650, 0.5864]],

        [[0.0629, 0.8466, 0.8290, 0.1924, 0.3088],
         [0.1591, 0.9698, 0.5169, 0.9902, 0.4205]]])

In [11]:
# Inspect the raw random q-batches scored during initialization;
# shape (100 * N, q, d), values inside the box `bounds`.
Xraw 

tensor([[[ 0.1801,  0.0172, -0.8830,  0.8171,  0.6486],
         [ 0.6180,  0.2402,  0.2326, -0.6120,  0.8851]],

        [[ 0.9838,  0.4096,  0.4929, -0.9934,  0.7679],
         [ 0.8102, -0.3548,  0.1331, -0.6534, -0.6121]],

        [[ 0.3512, -0.1069, -0.2517,  0.0086,  0.5166],
         [-0.7324, -0.8320, -0.2600,  0.4776,  0.0859]],

        ...,

        [[-0.4636, -0.4286,  0.5687, -0.1589,  0.0866],
         [-0.7236, -0.0082,  0.1155, -0.6146,  0.3696]],

        [[-0.5296,  0.2122,  0.8464,  0.7161,  0.3341],
         [-0.0597, -0.9228, -0.3113, -0.7929,  0.4629]],

        [[-0.5756, -0.1258,  0.6335, -0.8887, -0.4236],
         [-0.5542, -0.6055, -0.6820, -0.9451,  0.0429]]])

In [13]:
# one acquisition value per raw q-batch, so this equals 100 * N
Yraw.shape[0]

500