In [None]:
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from bloptools import test_functions

import os

# Machine-specific directories (Homebrew, Anaconda, MacTeX, ...) to put first on
# PATH, presumably so that external tools such as `latex` (needed for
# matplotlib's text.usetex below) can be found from the kernel. Adjust per machine.
EXTRA_PATH_DIRS = [
    "/opt/homebrew/opt/llvm/bin",
    "/Users/tom/opt/anaconda3/bin",
    "/Users/tom/opt/anaconda3/condabin",
    "/opt/homebrew/bin",
    "/opt/homebrew/sbin",
    "/usr/local/bin",
    "/usr/bin",
    "/bin",
    "/usr/sbin",
    "/sbin",
    "/Library/TeX/texbin",
    "/opt/X11/bin",
]


def prepend_to_path(dirs, path=None):
    """Return a PATH string with ``dirs`` first, followed by the entries of ``path``.

    Parameters
    ----------
    dirs : iterable of str
        Directories to place at the front, in order.
    path : str, optional
        Existing PATH-style string; defaults to the current ``os.environ["PATH"]``.

    Returns
    -------
    str
        ``os.pathsep``-joined directories with duplicates and empty entries removed,
        keeping the first occurrence of each directory.
    """
    if path is None:
        path = os.environ.get("PATH", "")
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # directories in `dirs` win over later copies inherited from `path`.
    merged = dict.fromkeys([*dirs, *path.split(os.pathsep)])
    return os.pathsep.join(entry for entry in merged if entry)


# Prepend rather than overwrite, so the kernel's inherited PATH is preserved.
os.environ["PATH"] = prepend_to_path(EXTRA_PATH_DIRS)

import shutil

# Render text through LaTeX only when a `latex` executable is reachable on PATH;
# with text.usetex=True and no LaTeX installed, matplotlib raises when drawing.
# Without LaTeX, "$x_1$"-style labels still render via matplotlib's mathtext.
plt.rcParams.update(
    {
        "text.usetex": shutil.which("latex") is not None,
    }
)

In [None]:
from bloptools import devices

# Two active inputs, x1 and x2, each searched over [-6, 6] (the domain used for
# the Himmelblau objective below).
dofs = [
    dict(device=devices.DOF(name=dof_name), limits=(-6, 6), kind="active")
    for dof_name in ("x1", "x2")
]


def digestion(db, uid):
    """Evaluate the Himmelblau test function for every point of run ``uid``.

    Parameters
    ----------
    db
        Catalog-like object; ``db[uid].table()`` must return a pandas DataFrame
        with ``x1`` and ``x2`` columns.
    uid
        Key of the run in ``db`` (presumably a run uid).

    Returns
    -------
    pandas.DataFrame
        The run's table with an added ``himmelblau`` column.
    """
    products = db[uid].table()

    # Evaluate point by point (the test function is called with scalars), but
    # assign the whole column at once. This avoids per-row `.loc` enlargement and
    # `iterrows`' per-row dtype upcasting, and it still creates the column when
    # the table is empty.
    products["himmelblau"] = [
        test_functions.himmelblau(row.x1, row.x2) for row in products.itertuples()
    ]

    return products


# Single objective: minimize the "himmelblau" column added by `digestion`.
tasks = [dict(key="himmelblau", kind="minimize")]

In [None]:
from bloptools import devices

# Six active inputs x1..x6, each on [0, 1] (the domain used for Hartmann-6 below).
# NOTE(review): this cell redefines `dofs`, `digestion` and `tasks` from the
# Himmelblau cell above; whichever cell ran last is what the agent gets. The
# figure cell at the end of the notebook assumes the 2-D Himmelblau setup.
# Each entry previously carried a commented-out `"latent_group": 0`; add it back
# to experiment with latent groups (semantics defined by bloptools — confirm).
dofs = [
    dict(device=devices.DOF(name=f"x{index}"), limits=(0, 1), kind="active")
    for index in range(1, 7)
]


def digestion(db, uid):
    """Evaluate the negated Hartmann-6 test function for every point of run ``uid``.

    Parameters
    ----------
    db
        Catalog-like object; ``db[uid].table()`` must return a pandas DataFrame
        with columns ``x1`` through ``x6``.
    uid
        Key of the run in ``db`` (presumably a run uid).

    Returns
    -------
    pandas.DataFrame
        The run's table with an added ``neg_hartmann`` column.
    """
    products = db[uid].table()

    # Negated so that the task can be posed as a maximization. Values are computed
    # point by point but assigned as one column. This avoids per-row `.loc`
    # enlargement and `iterrows` dtype upcasting, and it still creates the column
    # for an empty table.
    products["neg_hartmann"] = [
        -test_functions.hartmann6(row.x1, row.x2, row.x3, row.x4, row.x5, row.x6)
        for row in products.itertuples()
    ]

    return products


# Single objective: maximize the "neg_hartmann" column from `digestion`, with a
# "log" transform applied by the agent (transform semantics defined by bloptools).
tasks = [dict(key="neg_hartmann", kind="maximize", transform="log")]

In [None]:
import bloptools
from bloptools.utils import prepare_re_env

# Execute bloptools' prepare_re_env script inside this notebook's namespace
# (`%run -i`), passing `--db-type=temp` (presumably a throwaway database).
# Names used below that are not defined in this notebook, such as `RE` and `db`,
# are presumably created by that script — TODO confirm.
%run -i $prepare_re_env.__file__ --db-type=temp
from bloptools.bayesian import Agent

# Build the Bayesian-optimization agent from the most recently defined problem
# (`dofs`, `tasks`, `digestion`) and the database `db` from the environment setup.
agent = Agent(dofs=dofs, tasks=tasks, digestion=digestion, db=db)

# Manual usage examples, kept for reference:
#   RE(agent.learn("qr", n=128))
#   RE(agent.learn("qei", n=1, iterations=4))

In [None]:
# Repeated benchmark of q-EI on the current problem; results go to `output_dir`
# (read back by the results-plotting cell below).
benchmark_plan = agent.benchmark(
    output_dir="../data/benchmark_hartmann6_no_latent/",
    n_init=4,
    runs=64,
    learning_kwargs_list=[dict(acq_func="qei", n=4, iterations=64)],
)
RE(benchmark_plan)

# Other learning_kwargs_list entries that have been tried:
#   {"acq_func": "qem", "n": 4, "iterations": 16}
#   {"acq_func": "qei", "n": 4, "iterations": 16}
#   {"acq_func": "qem", "n": 4, "iterations": 16}

In [None]:
# Ensure the benchmark output directory exists. `exist_ok=True` makes the cell
# safe to re-run. NOTE(review): this cell sits after the benchmark cell that
# writes into this directory; it belongs before it.
os.makedirs("../data/benchmark_hartmann6_no_latent/", exist_ok=True)

In [None]:
import pandas as pd
import glob

# Results written by `agent.benchmark(...)` above. This pattern must match that
# cell's `output_dir` (note the "6" in "hartmann6"). Sorted for a deterministic
# plotting order.
paths = sorted(glob.glob("../data/benchmark_hartmann6_no_latent/*.h5"))
assert paths, "no benchmark results found; run the benchmark cell first"

fig, ax = plt.subplots()
for path in paths:
    table = pd.read_hdf(path, key="table")
    # Running best of exp(fitness). np.fmax ignores NaNs, so this matches a
    # prefix-wise nanmax without depending on `cummax`, which is defined further down.
    # NOTE(review): exp presumably undoes the "log" task transform and the minus
    # maps neg_hartmann back to the Hartmann-6 value — confirm against the agent.
    running_best = np.fmax.accumulate(np.exp(table.neg_hartmann_fitness.values))
    ax.plot(-running_best, lw=1e-1)

ax.set_xlabel("sample index")
ax.set_ylabel("best objective so far");

In [None]:
# Continue optimizing with q-EI: 16 iterations with n=2 (presumably 2 points per
# iteration), executed by the RunEngine.
learn_plan = agent.learn("qei", n=2, iterations=16)
RE(learn_plan)

In [None]:
# Inspect the first task model's kernel latent transform, rounded to 3 decimals.
# NOTE(review): only meaningful if the kernel has a `latent_transform`; the
# Hartmann-6 dofs above have their latent groups commented out — confirm.
latent_transform = agent.tasks[0]["model"].covar_module.latent_transform
np.round(latent_transform.detach(), 3)

In [None]:
def cummax(iterable):
    """Running maximum of a 1-D sequence, ignoring NaNs.

    Element ``i`` of the result equals ``np.nanmax(iterable[: i + 1])``: NaN
    entries are skipped, and positions before the first non-NaN value are NaN.

    Parameters
    ----------
    iterable : array_like
        1-D sequence of numbers (list, NumPy array, pandas Series, ...).

    Returns
    -------
    numpy.ndarray
        Array of the same length holding the running maximum.
    """
    # np.fmax returns the non-NaN operand when exactly one is NaN, so accumulating
    # it gives a NaN-skipping running max in one O(n) pass, instead of rescanning
    # every prefix with np.nanmax (O(n^2)).
    return np.fmax.accumulate(np.asarray(iterable))

In [None]:
# Quick visual check using the agent's built-in task plotting.
agent.plot_tasks()

In [None]:
# Build the q-expected-improvement ("qei") acquisition function for the agent.
# It is evaluated on a grid below; the second return value is not used.
get_acqf = bloptools.bayesian.acquisition.get_acquisition_function
qei, _ = get_acqf(agent, "qei")

In [None]:
# Ask the agent for 8 candidates under q-EI. Only the first return value is kept;
# it is overlaid on the acquisition panel of the figure below.
candidates = agent.ask("qei", n=8)
X, _ = candidates

In [None]:
import torch

# 63 x 63 grid over [-6, 6]^2 (the Himmelblau domain), indexed [x1, x2].
x1, x2 = (torch.linspace(-6, 6, 63) for _ in range(2))
X1, X2 = torch.meshgrid(x1, x2, indexing="ij")

# (63, 63, 2) tensor holding the (x1, x2) coordinates of each grid point.
xg = torch.stack((X1, X2), dim=-1)

# Score each grid point as its own batch of one point, shape (N, 1, 2), then fold
# the scores back onto the 63 x 63 grid.
flat_batches = xg.reshape(-1, 1, 2)
obj_grid = qei(flat_batches).reshape(xg.shape[:2]).detach()

In [None]:
import scipy as sp

In [None]:
# Explicit submodule import: `import scipy` alone does not load `scipy.ndimage`
# on older SciPy releases.
from scipy import ndimage

# Approximate locations of the four minima of the Himmelblau function; drawn on
# every panel for reference.
optima = [(3, 2), (-2.805, 3.1313), (-3.779, -3.283), (3.584, -1.848)]

# Plotting window, matching the [-6, 6]^2 grid `xg`.
extent = [-6, 6, -6, 6]

fig, axes = plt.subplots(1, 4, figsize=(12, 4), dpi=256)
axes = axes.ravel()

# Fitness of every sampled point for task 0, colored on the same scale as the
# posterior mean panel.
fitness = agent._get_task_fitness(0)

# GP posterior on the grid.
post = agent.model.posterior(xg)
pred = -post.mean[..., 0].detach()
post_var = post.variance[..., 0].detach()

norm = mpl.colors.Normalize(-500, 0)
cmap = "magma"
zoom = 8

# Upsample the grid-valued arrays `zoom`-fold by spline interpolation for smoother
# images. The std. dev. is interpolated in log space so that it stays positive,
# which the LogNorm on its panel requires.
umu = ndimage.zoom(pred, zoom=zoom)
usig = np.exp(ndimage.zoom(np.log(post_var.sqrt()), zoom=zoom))
uobj = ndimage.zoom(obj_grid, zoom=zoom)

samples_im = axes[0].scatter(*agent.active_inputs.values.T, c=fitness, s=16, norm=norm, cmap=cmap)

# Arrays are indexed [x1, x2]: transpose and flip rows so that x1 runs along the
# horizontal axis and x2 increases upward with imshow's default top-left origin.
mean_im = axes[1].imshow(-umu.T[::-1], norm=norm, extent=extent, cmap=cmap, aspect="auto")
std_im = axes[2].imshow(usig.T[::-1], norm=mpl.colors.LogNorm(), extent=extent, cmap=cmap, aspect="auto")
acq_im = axes[3].imshow(uobj.T[::-1], extent=extent, cmap=cmap, aspect="auto")

for ax in axes:
    for opt_x1, opt_x2 in optima:
        ax.scatter(opt_x1, opt_x2, facecolor="w", edgecolor="k", marker="o", lw=5e-1, s=16)

    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    ax.set_xticks(np.arange(-6, 7, 2))
    ax.set_yticks(np.arange(-6, 7, 2))
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")

# Candidates proposed by `agent.ask`, overlaid on the acquisition panel.
axes[3].scatter(*X.T, marker="d", lw=5e-1, color="k", s=16)

# Colorbar labels double as panel titles.
for mappable, ax, label in [
    (samples_im, axes[0], "sampled points"),
    (mean_im, axes[1], "posterior mean"),
    (std_im, axes[2], "posterior std. dev."),
    (acq_im, axes[3], "$q$-expected improvement"),
]:
    colorbar = fig.colorbar(mappable, ax=ax, location="bottom", shrink=0.8, aspect=32)
    colorbar.set_label(label)

plt.tight_layout()
os.makedirs("../plots", exist_ok=True)
plt.savefig("../plots/bayesian_himmelblau.pdf", bbox_inches="tight", transparent=True, pad_inches=0, dpi=256)

In [None]:
# Inspect the upsampled posterior variance. `ulv` (presumably the upsampled log
# variance) is never defined in this notebook, so the old expression raised
# NameError. `usig` is the upsampled std. dev. from the figure cell, interpolated
# in log space, so squaring it equals exp(upsampled log variance) — TODO confirm intent.
usig ** 2

In [None]:
# Quick look at the q-EI values on the grid (`obj` was never defined). The grid is
# indexed [x1, x2], so transpose and use a lower origin to put x1 horizontal and
# x2 increasing upward, matching the main figure.
plt.imshow(obj_grid.T, origin="lower", extent=[-6, 6, -6, 6]);