In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models.s4.s4 import S4Block as S4  # Can use full version instead of minimal S4D standalone below
from models.s4.s4d import S4D
from tqdm.auto import tqdm

# Dropout broke in PyTorch 1.11
def select_dropout_fn(torch_version):
    """Pick the dropout class to use for a given PyTorch version string.

    Args:
        torch_version: version string such as ``"1.11.0"`` or ``"2.2.1+cpu"``;
            only the leading ``major.minor`` components are inspected.

    Returns:
        ``nn.Dropout`` for 1.11 (with a printed warning), ``nn.Dropout1d`` for
        1.12 and newer, ``nn.Dropout2d`` for anything older.
    """
    major_minor = tuple(map(int, torch_version.split('.')[:2]))
    if major_minor == (1, 11):
        print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
        # Return here: the previous `if ... / if ... else` chain fell through to
        # the final `else` and silently replaced this choice with nn.Dropout2d.
        return nn.Dropout
    if major_minor >= (1, 12):
        return nn.Dropout1d
    return nn.Dropout2d


dropout_fn = select_dropout_fn(torch.__version__)

  from .autonotebook import tqdm as notebook_tqdm
CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.


In [2]:
import math
import os
from dataclasses import dataclass

import botorch
import gpytorch
import matplotlib.pyplot as plt
import numpy as np
import torch
from gpytorch.constraints import Interval
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from torch.quasirandom import SobolEngine

from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.exceptions import ModelFittingError
from botorch.fit import fit_gpytorch_mll
from botorch.generation import MaxPosteriorSampling
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from botorch.test_functions import Branin

# Runtime configuration shared by the cells below.
SMOKE_TEST = os.environ.get("SMOKE_TEST")  # when set, the BO budgets further down are shrunk
dtype = torch.float
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Running on {device}")

Running on cpu


In [3]:
import torch
# Print the installed PyTorch version so it is recorded next to the results.
print(torch.__version__)

2.2.1+cpu


In [4]:
# Branin test function, created with negate=True; the best observed value is
# tracked with .max() in the optimization loop further down.
branin = Branin(negate=True).to(device=device, dtype=dtype)


def branin_emb(x):
    """Evaluate Branin on the first two coordinates of ``x``.

    ``x`` is assumed to be in [-1, 1]^D. Its first two coordinates are mapped
    affinely onto ``branin.bounds``; the remaining coordinates are ignored.
    """
    lower, upper = branin.bounds
    leading_two = x[..., :2]
    rescaled = lower + (upper - lower) * (leading_two + 1) / 2
    return branin(rescaled)

In [5]:
# Problem definition: objective, ambient dimension, and initial design size.
fun, dim, n_init = branin_emb, 500, 40

# An infinite threshold means: always use Cholesky.
max_cholesky_size = float("inf")

In [6]:
def get_initial_points(dim, n_pts, seed=0):
    """Draw ``n_pts`` scrambled Sobol points in ``dim`` dimensions, rescaled to [-1, 1]^d.

    The points are cast to the notebook-level ``dtype`` and ``device``.
    """
    engine = SobolEngine(dimension=dim, scramble=True, seed=seed)
    unit_cube = engine.draw(n=n_pts).to(dtype=dtype, device=device)
    # Sobol samples are mapped via x -> 2x - 1, since points have to be in [-1, 1]^d.
    return 2 * unit_cube - 1

In [7]:
from torch.utils.data import Dataset, DataLoader

class my_dataset(Dataset):
    """Dataset pairing each row of ``train_x`` with the matching row of ``train_y``.

    A trailing singleton axis is appended to ``train_x`` on construction, so an
    item's input has one more dimension than the corresponding input row.
    """

    def __init__(self, train_x, train_y):
        self.num = train_x.shape[0]  # number of samples = length of the first axis
        self.train_x = train_x.unsqueeze(-1)
        self.train_y = train_y

        # Single-entry list holding the full (inputs, targets) pair.
        self.data = [(self.train_x, self.train_y)]

    def __len__(self):
        return self.num

    def __getitem__(self, index):
        sample_x = self.train_x[index]
        sample_y = self.train_y[index]
        return sample_x, sample_y

In [26]:
# def create_candidate(
#     state,
#     model,  # GP model
#     X,  # Evaluated points on the domain [-1, 1]^d
#     Y,  # Function values
#     n_candidates=None,  # Number of candidates for Thompson sampling
#     num_restarts=10,
#     raw_samples=512,
#     acqf="ts",  # "ei" or "ts"
# ):
#     assert acqf in ("ts", "ei")
#     assert X.min() >= -1.0 and X.max() <= 1.0 and torch.all(torch.isfinite(Y))
#     if n_candidates is None:
#         n_candidates = min(5000, max(2000, 200 * X.shape[-1]))

#     # Scale the TR to be proportional to the lengthscales
#     x_center = X[Y.argmax(), :].clone()
#     weights = model.covar_module.base_kernel.lengthscale.detach().view(-1)
#     weights = weights / weights.mean()
#     weights = weights / torch.prod(weights.pow(1.0 / len(weights)))
#     tr_lb = torch.clamp(x_center - weights * state.length, -1.0, 1.0)
#     tr_ub = torch.clamp(x_center + weights * state.length, -1.0, 1.0)

#     if acqf == "ts":
#         dim = X.shape[-1]
#         sobol = SobolEngine(dim, scramble=True)
#         pert = sobol.draw(n_candidates).to(dtype=dtype, device=device)
#         pert = tr_lb + (tr_ub - tr_lb) * pert

#         # Create a perturbation mask
#         prob_perturb = min(20.0 / dim, 1.0)
#         mask = torch.rand(n_candidates, dim, dtype=dtype, device=device) <= prob_perturb
#         ind = torch.where(mask.sum(dim=1) == 0)[0]
#         mask[ind, torch.randint(0, dim, size=(len(ind),), device=device)] = 1

#         # Create candidate points from the perturbations and the mask
#         X_cand = x_center.expand(n_candidates, dim).clone()
#         X_cand[mask] = pert[mask]

#         # Sample on the candidate points
#         thompson_sampling = MaxPosteriorSampling(model=model, replacement=False)
#         with torch.no_grad():  # We don't need gradients when using TS
#             X_next = thompson_sampling(X_cand, num_samples=1)

#     elif acqf == "ei":
#         ei = LogExpectedImprovement(model, train_Y.max())
#         X_next, acq_value = optimize_acqf(
#             ei,
#             bounds=torch.stack([tr_lb, tr_ub]),
#             q=1,
#             num_restarts=num_restarts,
#             raw_samples=raw_samples,
#         )

#     return X_next

In [8]:
class S4Model(nn.Module):
    """Residual stack of S4D layers between a linear encoder and a linear decoder.

    The forward pass encodes every sequence element to ``d_model`` features,
    runs ``n_layers`` residual S4D blocks (each with LayerNorm and dropout),
    averages over the sequence axis, and decodes to ``d_output`` values.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
        lr=0.005
    ):
        super().__init__()

        self.prenorm = prenorm
        self.lr = lr

        # Per-element linear encoder: d_input -> d_model
        self.encoder = nn.Linear(d_input, d_model)

        # Residual blocks: one S4D layer, one LayerNorm and one dropout per block.
        # The lr value forwarded to each S4D layer is capped at 0.001.
        s4_lr = min(0.001, self.lr)
        self.s4_layers = nn.ModuleList(
            [
                S4D(d_model, dropout=dropout, transposed=True, lr=s4_lr)
                for _ in range(n_layers)
            ]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
        self.dropouts = nn.ModuleList([dropout_fn(dropout) for _ in range(n_layers)])

        # Linear decoder: d_model -> d_output
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input); the result is shape (B, d_output).
        """
        # (B, L, d_input) -> (B, L, d_model) -> (B, d_model, L)
        hidden = self.encoder(x).transpose(-1, -2)

        for s4_block, layer_norm, drop in zip(self.s4_layers, self.norms, self.dropouts):
            # Every block maps (B, d_model, L) -> (B, d_model, L)
            branch = hidden
            if self.prenorm:
                # Prenorm: LayerNorm acts on the last axis, so swap axes around it
                branch = layer_norm(branch.transpose(-1, -2)).transpose(-1, -2)

            # The second value returned by the S4 block is discarded
            branch, _ = s4_block(branch)

            # Dropout on the block output, then the residual connection
            hidden = drop(branch) + hidden

            if not self.prenorm:
                # Postnorm
                hidden = layer_norm(hidden.transpose(-1, -2)).transpose(-1, -2)

        # (B, d_model, L) -> (B, L, d_model), then mean-pool over L -> (B, d_model)
        pooled = hidden.transpose(-1, -2).mean(dim=1)

        # (B, d_model) -> (B, d_output)
        return self.decoder(pooled)

In [9]:
# Hyperparameters of the S4 regressor built in this cell.
d_input = 1
d_output = 1
d_model = 128
n_layers = 4
dropout = 0.1
prenorm = True

model = S4Model(
    d_input=d_input,
    d_output=d_output,
    d_model=d_model,
    n_layers=n_layers,
    dropout=dropout,
    prenorm=prenorm,
)

model = model.to(device)
# `device` is a torch.device here, and comparing that object with the string
# 'cuda' never matches, so the benchmark flag was never enabled. Compare the
# device *type* instead; torch.device(...) also accepts a plain string.
if torch.device(device).type == 'cuda':
    cudnn.benchmark = True

In [10]:
# Build small training and validation sets for the S4 regressor.
n_init = 10
dtype = torch.float

train_x = get_initial_points(dim, n_init)
train_y = torch.tensor(
    [branin_emb(x) for x in train_x], dtype=dtype, device=device
).unsqueeze(-1)

# Use a different Sobol seed for validation. With the default seed the
# validation points are just the first n_init//2 training points, so the
# validation loss would not measure generalization at all.
val_x = get_initial_points(dim, n_init//2, seed=1)
val_y = torch.tensor(
    [branin_emb(x) for x in val_x], dtype=dtype, device=device
).unsqueeze(-1)

train_data = my_dataset(train_x=train_x, train_y=train_y)
trainloader = DataLoader(train_data, batch_size=4, shuffle=True)

# No shuffling for validation: the reported loss averages per-batch means, so a
# fixed batch composition keeps it comparable between epochs.
val_data = my_dataset(train_x=val_x, train_y=val_y)
valloader = DataLoader(val_data, batch_size=4, shuffle=False)

In [None]:
# Train the S4 regressor with MSE loss; validate every 10th epoch.
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# Each progress bar gets its own name: previously the inner bars reassigned
# `pbar` while the outer loop was still iterating over it.
epoch_bar = tqdm(range(start_epoch, 100))
for epoch in epoch_bar:
    model.train()
    train_loss = 0
    # total= gives the bar a known length; leave=False removes it when the
    # epoch ends so the output does not accumulate one bar per epoch.
    train_bar = tqdm(enumerate(trainloader), total=len(trainloader), leave=False)
    for batch_idx, (inputs, targets) in train_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        train_bar.set_description(
            'Training | Batch Idx: (%d/%d) | Loss: %.3f' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1))
        )

    # Keep the epoch-level training loss visible on the outer bar.
    epoch_bar.set_postfix(train_loss=train_loss / max(len(trainloader), 1))

    if epoch % 10 == 9:
        model.eval()
        eval_loss = 0
        with torch.no_grad():
            val_bar = tqdm(enumerate(valloader), total=len(valloader), leave=False)
            for batch_idx, (inputs, targets) in val_bar:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                eval_loss += loss.item()

                val_bar.set_description(
                    'Validating | Batch Idx: (%d/%d) | Loss: %.3f' %
                    (batch_idx, len(valloader), eval_loss/(batch_idx+1))
                )
        epoch_bar.set_postfix(
            train_loss=train_loss / max(len(trainloader), 1),
            val_loss=eval_loss / max(len(valloader), 1),
        )

    # scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")

## Bayesian optimization (BO)

In [10]:
# Optimizer budgets; a truthy SMOKE_TEST selects tiny values for a quick run.
if SMOKE_TEST:
    NUM_RESTARTS, RAW_SAMPLES, N_CANDIDATES = 2, 4, 4
else:
    NUM_RESTARTS = 10
    RAW_SAMPLES = 256
    N_CANDIDATES = min(5000, max(2000, 200 * dim))
print(NUM_RESTARTS, RAW_SAMPLES, N_CANDIDATES)

10 256 5000


In [11]:
# GP + LogEI loop over the full [-1, 1]^dim box, one new point per iteration.
dtype = torch.double
n_init = 10

X_ei = get_initial_points(dim, n_init)
Y_ei = torch.tensor(
    [branin_emb(x) for x in X_ei], dtype=dtype, device=device
).unsqueeze(-1)

# The search box is the same on every iteration, so build it once.
upper_corner = torch.ones(dim, dtype=dtype, device=device)
search_box = torch.stack([-upper_corner, upper_corner])

# Disable input scaling checks as we normalize to [-1, 1]
with botorch.settings.validate_input_scaling(False):
    while len(Y_ei) < 100:
        # Standardize the observations seen so far
        train_Y = (Y_ei - Y_ei.mean()) / Y_ei.std()

        # Fresh GP each round, fitted with 50 Adam steps on the negative MLL
        likelihood = GaussianLikelihood(noise_constraint=Interval(1e-8, 1e-3))
        model = SingleTaskGP(X_ei, train_Y, likelihood=likelihood)
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
        optimizer = torch.optim.Adam([{"params": model.parameters()}], lr=0.1)
        model.train()
        model.likelihood.train()
        for _ in range(50):
            optimizer.zero_grad()
            output = model(X_ei)
            loss = -mll(output, train_Y.squeeze())
            loss.backward()
            optimizer.step()

        # Propose the next point by maximizing LogEI against the best standardized value
        ei = LogExpectedImprovement(model, train_Y.max())
        candidate, acq_value = optimize_acqf(
            ei,
            bounds=search_box,
            q=1,
            num_restarts=NUM_RESTARTS,
            raw_samples=RAW_SAMPLES,
        )
        Y_next = torch.tensor(
            [branin_emb(x) for x in candidate], dtype=dtype, device=device
        ).unsqueeze(-1)

        # Grow the data set with the new observation
        X_ei = torch.cat([X_ei, candidate], dim=0)
        Y_ei = torch.cat([Y_ei, Y_next], dim=0)

        # Print current status
        print(f"{len(X_ei)}) Best value: {Y_ei.max().item():.2e}")

11) Best value: -9.02e+00
12) Best value: -4.70e+00




13) Best value: -4.70e+00
14) Best value: -4.70e+00
15) Best value: -4.70e+00
16) Best value: -2.15e+00




17) Best value: -2.15e+00
18) Best value: -2.15e+00
19) Best value: -2.15e+00




20) Best value: -2.15e+00
21) Best value: -2.15e+00




22) Best value: -2.15e+00




23) Best value: -2.15e+00
24) Best value: -2.15e+00
25) Best value: -2.15e+00
26) Best value: -2.15e+00
27) Best value: -2.15e+00
28) Best value: -2.15e+00
29) Best value: -2.15e+00
30) Best value: -2.15e+00
31) Best value: -2.15e+00
32) Best value: -2.15e+00
33) Best value: -2.15e+00
34) Best value: -2.15e+00
35) Best value: -1.54e+00
36) Best value: -1.54e+00
37) Best value: -1.54e+00
38) Best value: -1.54e+00
39) Best value: -1.54e+00
40) Best value: -1.54e+00
41) Best value: -1.54e+00
42) Best value: -4.48e-01
43) Best value: -4.48e-01
44) Best value: -4.48e-01
45) Best value: -4.48e-01
46) Best value: -4.48e-01
47) Best value: -4.48e-01
48) Best value: -4.48e-01
49) Best value: -4.48e-01
50) Best value: -4.48e-01
51) Best value: -4.48e-01
52) Best value: -4.48e-01
53) Best value: -4.48e-01
54) Best value: -4.48e-01
55) Best value: -4.48e-01
56) Best value: -4.48e-01
57) Best value: -4.48e-01
58) Best value: -4.48e-01
59) Best value: -4.48e-01
60) Best value: -4.48e-01
61) Best val

In [12]:
# Sobol baseline: as many quasi-random points as the EI run evaluated,
# rescaled from [0, 1]^dim to [-1, 1]^dim.
sobol_engine = SobolEngine(dim, scramble=True, seed=0)
sobol_unit = sobol_engine.draw(len(X_ei)).to(dtype=dtype, device=device)
X_Sobol = sobol_unit * 2 - 1

sobol_values = [branin_emb(x) for x in X_Sobol]
Y_Sobol = torch.tensor(sobol_values, dtype=dtype, device=device).unsqueeze(-1)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc

%matplotlib inline

# Plot the running best value of each method against the number of evaluations.
names = ["GP-qEI", "Sobol"] # , "EI", "Sobol"
runs = [Y_ei, Y_Sobol] # , Y_ei, Y_Sobol
fig, ax = plt.subplots(figsize=(8, 6))

for name, run in zip(names, runs):
    # Running maximum of the observed values, flattened from (n, 1) to (n,)
    fx = np.maximum.accumulate(run.cpu().numpy().ravel())
    ax.plot(fx, marker="", lw=3)

# `fun` is the plain function branin_emb and has no `optimal_value` attribute;
# the optimum is exposed by the underlying `branin` test-function object.
optimum = branin.optimal_value
ax.plot([0, len(Y_ei)], [optimum, optimum], "k--", lw=3)
# The function-value label belongs on the y axis (it used to be set with
# xlabel and was immediately overwritten by the next call).
ax.set_ylabel("Function value", fontsize=18)
ax.set_xlabel("Number of evaluations", fontsize=18)
ax.set_title(f"{dim}D Embedded Branin", fontsize=24)
ax.set_xlim([0, len(Y_ei)])
ax.set_ylim([-15, 1])

ax.grid(True)
fig.tight_layout()
ax.legend(
    names + ["Global optimal value"],
    loc="lower center",
    bbox_to_anchor=(0, -0.08, 1, 1),
    bbox_transform=fig.transFigure,
    ncol=4,
    fontsize=16,
)
plt.show()

In [12]:
# %matplotlib inline

# names = ["EI"] # , "Sobol"
# runs = [Y_ei] #, Y_Sobol
# fig, ax = plt.subplots(figsize=(8, 6))

# for name, run in zip(names, runs):
#     fx = np.maximum.accumulate(run.cpu())
#     plt.plot(-fx + branin.optimal_value, marker="", lw=3)

# plt.ylabel("Regret", fontsize=18)
# plt.xlabel("Number of evaluations", fontsize=18)
# plt.title(f"{dim}D Embedded Branin", fontsize=24)
# plt.xlim([0, len(Y_ei)])
# plt.yscale("log")

# plt.grid(True)
# plt.tight_layout()
# plt.legend(
#     names + ["Global optimal value"],
#     loc="lower center",
#     bbox_to_anchor=(0, -0.08, 1, 1),
#     bbox_transform=plt.gcf().transFigure,
#     ncol=4,
#     fontsize=16,
# )
# plt.show()

: 