In [None]:
# Automatically reload edited project modules (e.g. value_agent) before each cell runs.
%load_ext autoreload
%autoreload 2
# Use IPython's built-in completer instead of jedi for tab completion.
%config Completer.use_jedi = False

In [None]:
from functools import lru_cache

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from pathlib import Path
import plotly.graph_objects as go
import json
import pandas as pd
from string import ascii_lowercase
from PyAstronomy.pyasl import broadGaussFast
from scipy.interpolate import InterpolatedUnivariateSpline
from scipy.spatial import distance_matrix
import torch
from tqdm import tqdm

In [None]:
# Put this in your path for pretty plotting! https://gist.github.com/x94carbone/f5201b1c44963ff9453b9cc1d5f768ac
from mpl_utils import MPLAdjutant
# `adj` is used throughout the notebook (e.g. adj.set_grids) to style axes.
adj = MPLAdjutant()
adj.set_defaults()
# Default figure size in inches; most figures below pass their own figsize.
plt.rcParams["figure.figsize"] = (3, 2)

In [None]:
from value_agent import experiments, phases, value

In [None]:
@lru_cache
def read_data(path="results/results_sine2phase", pattern="sine2phase*.json"):
    """Load every serialized experiment below ``path`` into a ``{name: Experiment}`` dict.

    Files are found recursively with ``Path(path).rglob(pattern)`` and visited in
    sorted order, so that if two files share an experiment name, which one wins is
    deterministic.

    Parameters
    ----------
    path : str or Path
        Root directory containing the result files.
    pattern : str
        Glob pattern for the result files (default matches the two-phase sine runs).

    Returns
    -------
    dict
        Maps ``str(experiment.name)`` to the deserialized ``experiments.Experiment``.
        The result is memoized by ``lru_cache`` per ``(path, pattern)``: repeated calls
        return the *same* dict object, so do not mutate it.

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing directory. Previously this silently returned
        (and cached) an empty dict, and the failure only surfaced later as a KeyError.
    """
    # Takes a while, it's about 9 GB of data to read into memory
    root = Path(path)
    if not root.is_dir():
        raise FileNotFoundError(f"results directory not found: {root}")

    results = dict()
    for file in sorted(root.rglob(pattern)):
        with open(file, "r") as f:
            # Each file stores a JSON *string* that itself encodes the experiment dict,
            # hence the double decode.
            d = json.loads(json.load(f))
        c = experiments.Experiment.from_dict(d)
        results[str(c.name)] = c
    return results

# Two-phase sine result

## Plot the phases only

In [None]:
# Raster of the two-phase sine system, plus the pointwise magnitude of its numerical gradient.
x, y, Z = phases.get_phase_plot_info(phases.phase_1_sine_on_2d_raster)
X, Y = np.meshgrid(x, y)
dZ = np.array(np.gradient(Z))  # one partial derivative per array axis
gradZ = np.sqrt(np.square(dZ).sum(axis=0))

In [None]:
# Left: the phase map p(x). Right: its gradient magnitude.
# Both overlay the boundary g(x_1) = 1/2 + sin(2 pi x_1) / 4 (yellow dashed).
extent = (x[0], x[-1], y[0], y[-1])
scale = 1
lw = 1

g = np.linspace(0, 1, 100)
g_boundary = 0.5 + np.sin(2.0 * np.pi * g) / 4

fig, axs = plt.subplots(1, 2, figsize=(4*scale, 2*scale), sharex=True, sharey=True)

ax = axs[0]
ax.imshow(Z, interpolation='bilinear', origin='lower', cmap=cm.binary, extent=extent)
ax.set_ylabel("$x_2$~[a.u.]")
adj.set_grids(ax, grid=False)
ax.text(0.1, 0.1, r"$p(\mathbf{x})$", ha="left", va="bottom", transform=ax.transAxes)
ax.plot(g, g_boundary, "y--", linewidth=lw, zorder=0)

ax = axs[1]
ax.imshow(gradZ, interpolation='bilinear', origin='lower', cmap=cm.binary, extent=extent)
ax.plot(g, g_boundary, "y--", linewidth=lw, zorder=0, label="$g(x_1)$")
# Edges of the +/-0.05 band around g(x_1). Raw string: "\p" is not a valid Python escape.
ax.plot(g, 0.55 + 0.25 * np.sin(2.0 * np.pi * g), "k--", linewidth=lw, label=r"$g(x_1) \pm 0.05$")
ax.plot(g, 0.45 + 0.25 * np.sin(2.0 * np.pi * g), "k--", linewidth=lw)
ax.text(0.1, 0.1, r"$||\nabla_\mathbf{x} p(\mathbf{x})||$", ha="left", va="bottom", transform=ax.transAxes)

adj.set_grids(ax, grid=False)
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.legend(frameon=False)

# Invisible full-figure axes whose only job is to carry the shared x-label.
ax = fig.add_subplot(111, frameon=False)
# Matplotlib expects booleans here; the legacy 'on'/'off' strings are not supported.
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("$x_1$~[a.u.]", labelpad=15)

plt.subplots_adjust(wspace=0.1)
# plt.savefig("figures/01_sine/sine_phase_plot.svg", dpi=300, bbox_inches="tight")
plt.show()

## HFSVF convergence

In [None]:
# Regular 50 x 50 grid over the unit square. Rows are ordered with the first
# coordinate varying slowest (same order as a nested `for x1: for x2:` loop).
axis_pts = np.linspace(0, 1, 50)
X = np.stack(np.meshgrid(axis_pts, axis_pts, indexing="ij"), axis=-1).reshape(-1, 2)

Y = phases.sine_on_2d_raster_observations(X)
V = value.value_function(X, Y)

In [None]:
# Each grid point is drawn as a small square, coloured by its value-function score V.
fig, ax = plt.subplots(figsize=(3, 3))
x1_coords, x2_coords = X.T
ax.scatter(x1_coords, x2_coords, c=V, s=6, marker="s")
plt.show()

## Plot everything (possibly for the appendix)

Load the data. The runs were generated with seeds `range(125, 155)` for the 3 random initial points and seeds `range(225, 255)` for seeding PyTorch during the Bayesian optimization.

In [None]:
# read_data is wrapped in lru_cache, so re-running this cell does not re-read the files.
results = read_data(path="results/results_sine2phase_sd_0_15")

In [None]:
# Number of experiments loaded.
len(results)

In [None]:
# Acquisition-function keys (as used in result names) and their display labels:
# "UCB<beta>" is displayed as "UCB(<beta>)", and every other key is shown unchanged.
plotting = ["Random", "MaxVar", "EI", "UCB1", "UCB10", "UCB20", "UCB100"]
aq_strings = {
    key: (f"UCB({key[len('UCB'):]})" if key.startswith("UCB") else key)
    for key in plotting
}

In [None]:
# Grid of model predictions for one run per acquisition function.
# Rows are acquisition functions; columns are recorded checkpoints 1..n_cols.
extent = (0, 1, 0, 1)
scale = 2
n_cols = 6
fontsize = 22
cseed = 125
eseed = 225

# The boundary g(x_1) is the same in every panel, so compute it once.
boundary_x = np.linspace(0, 1, 100)
boundary_y = 0.5 + np.sin(2.0 * np.pi * boundary_x) / 4

fig, axs = plt.subplots(len(plotting), n_cols, figsize=(n_cols * scale, len(plotting) * scale), sharey=True, sharex=True)

for ii_row, (name, letter_label) in enumerate(zip(plotting, ascii_lowercase)):
    current = results[f"sine2phase-random-{name}-seed-{cseed}-{eseed}"]

    for ii_col in range(n_cols):
        ax = axs[ii_row, ii_col]
        # Index 0 holds the initial design (plotted in blue in the first column),
        # so the panels show checkpoints 1..n_cols.
        checkpoint = ii_col + 1

        t = ax.text(0.1, 0.1, f"({letter_label}{ii_col+1})", ha="left", va="bottom", transform=ax.transAxes, fontsize=fontsize-8)
        t.set_bbox(dict(facecolor='white', alpha=0.8, edgecolor='white'))

        # Mean prediction, stored flat; it is assumed to be a square raster.
        pred = current.recorded_Yhat[checkpoint]["mean"]
        n_reshape = int(np.sqrt(len(pred)))
        ax.imshow(
            pred.reshape(n_reshape, n_reshape).T,
            interpolation='bilinear',
            origin='lower',
            cmap=cm.binary,
            extent=extent
        )

        ax.plot(boundary_x, boundary_y, "y--", linewidth=1, zorder=0)

        sampled = current.recorded_X[checkpoint]
        ax.scatter(sampled[:, 0], sampled[:, 1], s=1.0, color="red")

        if ii_col == n_cols - 1:
            ax.text(1.05, 0.5, aq_strings[name], ha="left", va="center", transform=ax.transAxes, rotation=90, fontsize=fontsize)

        if ii_row == 0:
            ax.set_title(f"$N={len(sampled)}$", fontsize=fontsize)

        if ii_col == 0:
            initial = current.recorded_X[0]
            ax.scatter(initial[:, 0], initial[:, 1], s=2, color="blue")

for ax in axs.flatten():
    adj.set_grids(ax, grid=False)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Invisible full-figure axes for the shared axis labels.
ax = fig.add_subplot(111, frameon=False)
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("$x_1$~[a.u.]", fontsize=fontsize)
ax.set_ylabel("$x_2$~[a.u.]", fontsize=fontsize)

plt.subplots_adjust(wspace=0.05, hspace=0.05)

plt.show()
# plt.savefig("figures/01_sine/sine_all.pdf", dpi=300, bbox_inches="tight")

## Metric - random initial conditions

In [None]:
# Shared styling for the error-bar metric plots.
plot_kwargs = {
    'linewidth': 0.5,
    'marker': 's',
    'ms': 1.0,
    'capthick': 0.3,
    'capsize': 2.0,
    'elinewidth': 0.3
}

def plot_metric(ax, metrics, N_vals=(3, 40, 80, 120, 160, 200, 240), use_axhline=True, plot_kwargs=plot_kwargs, sd_scale=3, baseline=0.1):
    """Plot mean +/- (std / sd_scale) of a per-checkpoint metric for each acquisition function.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    metrics : dict
        Maps an acquisition-function key (a key of ``aq_strings``) to a nested list of
        metric values. It is flattened and reshaped to ``(-1, len(N_vals))``, so each
        row is one run and each column one checkpoint.
    N_vals : sequence of int
        x positions (number of sampled points) of the checkpoints.
    use_axhline : bool
        Draw a dashed reference line at ``baseline``, labelled "Uniform".
    plot_kwargs : dict
        Extra keyword arguments forwarded to ``ax.errorbar``.
    sd_scale : float
        Error bars are ``std / sd_scale``. The default of 3 matches the
        "mu +/- sigma/3" annotation used in the figures.
    baseline : float
        y-value of the reference line. Presumably the in-band fraction expected for
        uniformly random points (the band is 0.1 high on the unit square).
    """
    n_checkpoints = len(N_vals)
    # Do not call the loop variable `value`: that name is the imported value_agent module.
    for key, metric_values in metrics.items():
        table = np.array(metric_values).reshape(-1, n_checkpoints)
        mu = table.mean(axis=0)
        sd = table.std(axis=0)

        ax.errorbar(N_vals, mu, yerr=sd / sd_scale, label=aq_strings[key], **plot_kwargs)

    if use_axhline:
        ax.axhline(baseline, linestyle="--", linewidth=0.5, color="black", label="Uniform")

    adj.set_grids(ax, grid=False)
    ax.set_xticks(list(N_vals))
    ax.tick_params(axis="x", which="minor", bottom=False, top=False)

In [None]:
L = len(range(125, 155))  # number of seeds in each of the two seed ranges
# metrics_random[af][i][j] collects one in-band ratio per checkpoint for
# initial-condition seed index i and BO seed index j.
metrics_random = {}
for key in plotting:
    metrics_random[key] = [[list() for _ in range(L)] for _ in range(L)]

In [None]:
# For every run with random initial conditions, record at each checkpoint the
# fraction of sampled points strictly inside the +/-0.05 band around
# g(x_1) = 1/2 + sin(2 pi x_1) / 4.
for af in plotting:
    for ii, cseed in enumerate(range(125, 155)):
        for jj, eseed in enumerate(range(225, 255)):
            result = results[f"sine2phase-random-{af}-seed-{cseed}-{eseed}"]

            for p in result.recorded_X:
                x = p[:, 0]
                y = p[:, 1]
                y_upper = 0.55 + 0.25 * np.sin(2.0 * np.pi * x)
                y_lower = 0.45 + 0.25 * np.sin(2.0 * np.pi * x)
                # A dedicated name, so the global `L` used to size the metric
                # containers is not overwritten.
                n_in = np.count_nonzero((y < y_upper) & (y > y_lower))
                metrics_random[af][ii][jj].append(n_in / p.shape[0])

## Metric - grid initial conditions

In [None]:
L = len(range(125, 155))  # number of BO seeds (225..254)
# metrics_grid[af][i] collects one in-band ratio per checkpoint for BO seed index i.
metrics_grid = dict()
for key in plotting:
    metrics_grid[key] = [list() for _ in range(L)]

In [None]:
# Same in-band fraction as for the random starts, for runs started from the
# grid initial design (one run per BO seed).
for af in plotting:
    for ii, eseed in enumerate(range(225, 255)):
        result = results[f"sine2phase-grid-{af}-seed-x-{eseed}"]

        for p in result.recorded_X:
            x = p[:, 0]
            y = p[:, 1]
            y_upper = 0.55 + 0.25 * np.sin(2.0 * np.pi * x)
            y_lower = 0.45 + 0.25 * np.sin(2.0 * np.pi * x)
            # A dedicated name, so the global `L` used to size the metric
            # containers is not overwritten.
            n_in = np.count_nonzero((y < y_upper) & (y > y_lower))
            metrics_grid[af][ii].append(n_in / p.shape[0])

## Plot all metrics together

In [None]:
# Side-by-side in-band fraction N_in/N versus N for random and grid initial designs.
fig, axs = plt.subplots(1, 2, figsize=(4, 2), sharey=True)

ax = axs[0]
plot_metric(ax, metrics_random, N_vals=[3, 40, 80, 120, 160, 200, 240])
# Raw strings: "\m", "\p", "\s" are not valid Python escapes.
ax.text(0.1, 0.9, r"$\mu \pm \sigma/3$", ha="left", va="top", transform=ax.transAxes)
ax.set_ylabel(r"$N_\mathrm{in}/N$")
ax.set_title("Random: 3-points")
ax.set_xticks([3, 80, 160, 240])

ax = axs[1]
plot_metric(ax, metrics_grid, N_vals=[9, 40, 80, 120, 160, 200, 240])
ax.legend(frameon=False, bbox_to_anchor=(1.05, 0.5), loc="center left")
ax.set_title("Grid: 9-points")
ax.set_xticks([9, 80, 160, 240])

# Big invisible axis for the shared x-label.
ax = fig.add_subplot(111, frameon=False)
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("$N$", labelpad=15)

plt.subplots_adjust(wspace=0.1)
# plt.show()
# Make sure the output directory exists so savefig does not fail on a fresh checkout.
Path("figures/01_sine").mkdir(parents=True, exist_ok=True)
plt.savefig("figures/01_sine/sine_metric.pdf", dpi=300, bbox_inches="tight")

## Metrics - analyze fixed $\sigma$ values

In [None]:
# Runs without a fixed sigma (the same path is loaded later as the "Unconstrained" column).
tmp_results = read_data(path="results/results_sine2phase")

In [None]:
L = len(range(125, 155))  # number of seeds in each of the two seed ranges
# metrics_random_sd_val[af][i][j]: mean nearest-neighbour distance at each checkpoint
# for initial-condition seed index i and BO seed index j.
metrics_random_sd_val = {}
for key in plotting:
    metrics_random_sd_val[key] = [[list() for _ in range(L)] for _ in range(L)]

In [None]:
# For every run and checkpoint, record the mean distance from each sampled point
# to its nearest other sampled point.
for af in plotting:
    # tqdm over a range (not an enumerate object), so the bar knows its total length.
    for ii, cseed in enumerate(tqdm(range(125, 155), desc=af)):
        for jj, eseed in enumerate(range(225, 255)):
            result = tmp_results[f"sine2phase-random-{af}-seed-{cseed}-{eseed}"]

            for p in result.recorded_X:
                distance = distance_matrix(p, p)
                # Exclude zero distances (the diagonal, and also any exactly duplicated
                # points) from the nearest-neighbour search.
                distance[distance == 0.0] = np.inf
                metrics_random_sd_val[af][ii][jj].append(distance.min(axis=1).mean())


In [None]:
# Mean nearest-neighbour spacing sigma(X) versus |X| for each acquisition function,
# with the fixed sigma values (0.05, 0.15, 0.25) shown as dashed reference lines.
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

N_vals = [3, 40, 80, 120, 160, 200, 240]

# Do NOT name the loop variable `value`: at notebook top level that would overwrite
# the imported `value` module and break every later `value.value_function(...)` call.
for key, spacing in metrics_random_sd_val.items():
    spacing = np.array(spacing).reshape(-1, len(N_vals))
    mu = spacing.mean(axis=0)
    sd = spacing.std(axis=0)

    ax.errorbar(N_vals, mu, yerr=sd, label=aq_strings[key], **plot_kwargs)

adj.set_grids(ax, grid=False)
ax.tick_params(axis="x", which="minor", bottom=False, top=False)

# Raw strings: "\m", "\p", "\s" are not valid Python escapes.
ax.text(0.1, 0.1, r"$\mu \pm \sigma$", ha="left", va="bottom", transform=ax.transAxes)
ax.set_ylabel(r"$\sigma(X)$")
ax.set_title("Random: 3-points")
ax.set_xlabel("$|X|$")
ax.set_xticks([3, 80, 160, 240])

ax.legend(frameon=False, bbox_to_anchor=(1.05, 0.5), loc="center left")
ax.set_yscale("log")

for sigma_ref in (0.05, 0.15, 0.25):
    ax.axhline(sigma_ref, color="black", linestyle="--", linewidth=0.5, zorder=-1)

# plt.show()
Path("figures/01_sine").mkdir(parents=True, exist_ok=True)
plt.savefig("figures/01_sine/sine_metric_sigmas.pdf", dpi=300, bbox_inches="tight")

Plot examples.

In [None]:
# One result set per sigma setting. The directory suffix encodes sigma
# (sd_0_05 -> 0.05, ...); the plain directory is the unconstrained run.
# read_data is lru_cached, so paths already loaded above are not re-read.
tmp_results_adaptive = read_data(path="results/results_sine2phase")
tmp_results_005 = read_data(path="results/results_sine2phase_sd_0_05")
tmp_results_015 = read_data(path="results/results_sine2phase_sd_0_15")
tmp_results_025 = read_data(path="results/results_sine2phase_sd_0_25")

In [None]:
# Final-checkpoint predictions for one run per acquisition function (rows) under
# each sigma setting (columns).
results_list = [tmp_results_005, tmp_results_015, tmp_results_025, tmp_results_adaptive]
titles = [r"$\sigma=0.05$", r"$\sigma=0.15$", r"$\sigma=0.25$", r"Unconstrained"]

extent = (0, 1, 0, 1)
scale = 2
n_cols = len(results_list)
fontsize = 22
cseed = 125
eseed = 225

# The boundary g(x_1) is the same in every panel, so compute it once.
boundary_x = np.linspace(0, 1, 100)
boundary_y = 0.5 + np.sin(2.0 * np.pi * boundary_x) / 4

fig, axs = plt.subplots(len(plotting), n_cols, figsize=(n_cols * scale, len(plotting) * scale), sharey=True, sharex=True)

for ii_row, (name, letter_label) in enumerate(zip(plotting, ascii_lowercase)):
    for ii_col, res in enumerate(results_list):
        ax = axs[ii_row, ii_col]

        t = ax.text(0.1, 0.1, f"({letter_label}{ii_col+1})", ha="left", va="bottom", transform=ax.transAxes, fontsize=fontsize-8)
        t.set_bbox(dict(facecolor='white', alpha=0.8, edgecolor='white'))

        current = res[f"sine2phase-random-{name}-seed-{cseed}-{eseed}"]

        # Last checkpoint (N = 240); the prediction is assumed to be a flat square raster.
        pred = current.recorded_Yhat[-1]["mean"]
        n_reshape = int(np.sqrt(len(pred)))
        ax.imshow(
            pred.reshape(n_reshape, n_reshape).T,
            interpolation='bilinear',
            origin='lower',
            cmap=cm.binary,
            extent=extent
        )

        ax.plot(boundary_x, boundary_y, "y--", linewidth=1, zorder=0)

        final = current.recorded_X[-1]
        ax.scatter(final[:, 0], final[:, 1], s=1.0, color="red")

        if ii_col == n_cols - 1:
            ax.text(1.05, 0.5, aq_strings[name], ha="left", va="center", transform=ax.transAxes, rotation=90, fontsize=fontsize)

        # Initial design, drawn in every panel.
        initial = current.recorded_X[0]
        ax.scatter(initial[:, 0], initial[:, 1], s=2, color="blue")

for ax in axs.flatten():
    adj.set_grids(ax, grid=False)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

for title, ax in zip(titles, axs[0, :]):
    ax.set_title(title, fontsize=fontsize)

# Invisible full-figure axes for the shared axis labels.
ax = fig.add_subplot(111, frameon=False)
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("$x_1$~[a.u.]", fontsize=fontsize)
ax.set_ylabel("$x_2$~[a.u.]", fontsize=fontsize)

plt.subplots_adjust(wspace=0.05, hspace=0.05)

# plt.show()
Path("figures/01_sine").mkdir(parents=True, exist_ok=True)
plt.savefig("figures/01_sine/sine_all_different_sigma.pdf", dpi=300, bbox_inches="tight")

## Analyze the distribution of where the points actually are

In [None]:
# Number of recorded checkpoints per run (index 0 is the initial design).
# Presumably matches the 7 entries of N_vals used in the metric plots; TODO confirm.
N_N_vals = 7

In [None]:
# Pool the sampled point sets across runs: dist_*[af][k] is a list holding one
# recorded_X[k] array per run, for checkpoint k of acquisition function af.
dist_random = {key: [list() for _ in range(N_N_vals)] for key in plotting}
for af in plotting:
    for cseed in range(125, 155):
        for eseed in range(225, 255):
            run = results[f"sine2phase-random-{af}-seed-{cseed}-{eseed}"]
            for kk, p in enumerate(run.recorded_X):
                dist_random[af][kk].append(p)

dist_grid = {key: [list() for _ in range(N_N_vals)] for key in plotting}
for af in plotting:
    for eseed in range(225, 255):
        run = results[f"sine2phase-grid-{af}-seed-x-{eseed}"]
        for kk, p in enumerate(run.recorded_X):
            dist_grid[af][kk].append(p)

In [None]:
# Sanity check: one entry per acquisition function.
dist_grid.keys()

In [None]:
# Inspect the container type. `dist` itself is only assigned in the next cell, so
# referencing it here raised NameError on a fresh kernel; check dist_random instead.
type(dist_random)

In [None]:
# Histograms of the x_1 coordinate of sampled points, pooled over all runs.
# Rows are acquisition functions; columns are checkpoints after the initial design.
scale = 2
fontsize = 22
n_cols = N_N_vals - 1  # checkpoints 1 .. N_N_vals - 1

dist = dist_random
name = "random"

n_rows = len(dist)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * scale, n_rows * scale), sharex=True, sharey=True, squeeze=False)

# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; pyplot.get_cmap(name, lut) remains.
cmap = plt.get_cmap("plasma", n_cols + 1)

# Do not call the loop variable `value`: that would overwrite the imported `value` module.
for row, ((key, per_checkpoint), letter_label) in enumerate(zip(dist.items(), ascii_lowercase)):
    for n in range(1, N_N_vals):
        ax = axs[row, n - 1]

        t = ax.text(0.1, 0.9, f"({letter_label}{n})", ha="left", va="top", transform=ax.transAxes, fontsize=fontsize-8)
        t.set_bbox(dict(facecolor='white', alpha=0.8, edgecolor='white'))

        # Stack all runs at this checkpoint. Assumes every run has the same number of
        # points here, giving shape (n_runs, N, n_dims).
        samples = np.array(per_checkpoint[n])
        ax.hist(samples[:, :, 0].flatten(), density=True, bins=20, color=cmap(n - 1))

        if row == 0:
            ax.set_title(f"$N={samples.shape[1]}$", fontsize=fontsize)

    # Row label to the right of the last column.
    axs[row, -1].text(1.05, 0.5, aq_strings[key], ha="left", va="center", transform=axs[row, -1].transAxes, rotation=90, fontsize=fontsize)

for ax in axs.flatten():
    adj.set_grids(ax, grid=False)
    ax.set_yticks([])

# Invisible full-figure axes for the shared axis labels.
ax = fig.add_subplot(111, frameon=False)
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel(r"$x_1$~[a.u.]", labelpad=15, fontsize=fontsize)
ax.set_ylabel(r"$\rho_\mathrm{Counts}(x_1)$", labelpad=10, fontsize=fontsize)

plt.subplots_adjust(wspace=0.05, hspace=0.05)
Path("figures/01_sine").mkdir(parents=True, exist_ok=True)
plt.savefig(f"figures/01_sine/sine_hist_{name}.pdf", bbox_inches="tight", dpi=300)
# plt.show()

# Four-phase result

In [None]:
# results = dict()
# for file in Path("results").rglob("xrd4phase*.json"):
#     if "gridstart" in str(file):
#         continue
#     with open(file, "r") as f:
#         d = json.loads(json.load(f))
#     c = experiments.Experiment.from_dict(d)
#     results[str(c.name)] = c

In [None]:
# Load the four-phase runs that start from a grid. Each file holds a JSON string
# encoding the experiment dict, so it is decoded twice.
results = {}
for json_path in Path("results").rglob("xrd4phase-gridstart*.json"):
    with open(json_path, "r") as handle:
        payload = json.loads(json.load(handle))
    experiment = experiments.Experiment.from_dict(payload)
    results[str(experiment.name)] = experiment

In [None]:
# Acquisition-function names exactly as they appear in the four-phase result names
# (e.g. "xrd4phase-gridstart-UCB(1)"). Note: this overwrites the sine-section `plotting`.
plotting = ["Random", "MaxVar", "EI", "UCB(1)", "UCB(10)", "UCB(20)", "UCB(100)"]

In [None]:
# Model predictions for the four-phase grid-start runs.
# Rows are acquisition functions; columns are recorded checkpoints 1..n_cols.
extent = (0, 1, 0, 1)
scale = 2
n_cols = 6
fontsize = 22
seed = 125  # only used by the commented-out per-seed lookup below

fig, axs = plt.subplots(len(plotting), n_cols, figsize=(n_cols * scale, len(plotting) * scale), sharey=True, sharex=True)

for ii_row, (name, letter_label) in enumerate(zip(plotting, ascii_lowercase)):
    # current = results[f"xrd4phase-{name}-seed={seed}"]
    current = results[f"xrd4phase-gridstart-{name}"]

    for ii_col in range(n_cols):
        ax = axs[ii_row, ii_col]
        checkpoint = ii_col + 1  # index 0 is the initial design (blue, first column)

        t = ax.text(0.5, 0.1, f"({letter_label}{ii_col+1})", ha="center", va="bottom", transform=ax.transAxes, fontsize=fontsize-8)
        t.set_bbox(dict(facecolor='white', alpha=0.8, edgecolor='white'))

        # The prediction is assumed to be a flat square raster.
        pred = current.recorded_Yhat[checkpoint]["mean"]
        n_reshape = int(np.sqrt(len(pred)))
        ax.imshow(
            pred.reshape(n_reshape, n_reshape).T,
            interpolation='bilinear',
            origin='lower',
            cmap=cm.binary,
            extent=extent
        )

        sampled = current.recorded_X[checkpoint]
        ax.scatter(sampled[:, 0], sampled[:, 1], s=1.0, color="red")

        if ii_col == n_cols - 1:
            ax.text(1.05, 0.5, name, ha="left", va="center", transform=ax.transAxes, rotation=90, fontsize=fontsize)

        if ii_row == 0:
            ax.set_title(f"$N={len(sampled)}$", fontsize=fontsize)

        if ii_col == 0:
            initial = current.recorded_X[0]
            ax.scatter(initial[:, 0], initial[:, 1], s=2, color="blue")

for ax in axs.flatten():
    adj.set_grids(ax, grid=False)
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Invisible full-figure axes for the shared axis labels.
ax = fig.add_subplot(111, frameon=False)
ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel("$x_1$~[a.u.]", fontsize=fontsize)
ax.set_ylabel("$x_2$~[a.u.]", fontsize=fontsize)

plt.subplots_adjust(wspace=0.05, hspace=0.05)

plt.show()
# plt.savefig("figures/02_four_phase/four_phase_all.pdf", dpi=300, bbox_inches="tight")

# UV

In [None]:
df = pd.read_csv("value_agent/uv_data.csv")
# Inputs: the three control variables, in the column order NCit, pH, HA.
X = df[["NCit", "pH", "HA"]].to_numpy()
# Shift the "pH" column by +16. Later cells subtract 16.0 again for display, and the
# hand-picked points below have pH values around -11 to -15, so this presumably makes
# the inputs positive; TODO confirm.
X[:, 1] += 16.0
# Outputs: every column from the fifth onward (presumably the measured spectrum).
Y = df.iloc[:, 4:].to_numpy()

In [None]:
# Preview only the first rows instead of dumping the whole (wide) table.
df.head()

In [None]:
# Regular 20 x 20 x 20 grid over the bounding box of the measured inputs. Rows are
# ordered with the last coordinate varying fastest (like nested loops x0 > x1 > x2).
x0, x1, x2 = (np.linspace(X[:, k].min(), X[:, k].max(), 20) for k in range(3))
coordinates = np.stack(np.meshgrid(x0, x1, x2, indexing="ij"), axis=-1).reshape(-1, 3)

In [None]:
# Ground-truth observations on the grid (reshaped to (20, 20, 20, 200) further down)
# and their value-function scores.
obs = phases.truth_uv(coordinates)
value_truth = value.value_function(coordinates, obs)

In [None]:
# Starting dataset for the UV experiment, bounded by the range of the measured inputs.
# Presumably random initial points given the constructor name; seed 125.
dat = experiments.UVData.from_random(truth=phases.truth_uv, xmin=X.min(axis=0), xmax=X.max(axis=0), seed=125)

We can run Bayesian optimization under a linear inequality constraint. From the BoTorch docs for `optimize_acqf`:

> inequality_constraints (Optional[List[Tuple[Tensor, Tensor, float]]]) – A list of tuples (indices, coefficients, rhs), with each tuple encoding an inequality constraint of the form sum_i (X[indices[i]] * coefficients[i]) >= rhs

For example, the simple constraint $x_1 + x_2 \geq -1$ would be encoded as `[(torch.tensor([0, 1]), torch.tensor([1, 1]), -1)]`. Here we want

$x_1 + x_2 + x_3 \leq 56 \Rightarrow -x_1 - x_2 - x_3 \geq -56$

In [None]:
# BoTorch form: sum_i coef[i] * X[idx[i]] >= rhs, i.e. -(x_1 + x_2 + x_3) >= -56
# (in the shifted input coordinates used by X).
inequality_constraints = [
    (torch.arange(3), -torch.ones(3, dtype=torch.float32), -56.0),
]

In [None]:
# Bayesian-optimization experiment with the EI acquisition function.
# The kwargs are presumably forwarded to BoTorch's optimize_acqf:
# q=1 candidate per step, subject to the linear constraint defined above.
exp = experiments.Experiment(
    dat,
    aqf="EI",
    # aqf_kwargs={"beta": 100.0},
    optimize_acqf_kwargs={"q": 1, "num_restarts": 5, "raw_samples": 20, "inequality_constraints": inequality_constraints}
)

In [None]:
# Run the optimization loop with a progress bar (n_experiments=320; presumably the number of acquisitions).
exp.run(pbar=True, n_experiments=320)

In [None]:
# Points from the last recorded checkpoint, and their ground-truth value-function scores.
x = exp.recorded_X[-1]
v = value.value_function(x, phases.truth_uv(x))

In [None]:
# Interactive 3-D view of the sampled points, coloured by value-function score.
# The pH axis is shifted back by -16.0 to the original units of the CSV.
point_cloud = go.Scatter3d(
    x=x[:, 0],
    y=x[:, 1] - 16.0,
    z=x[:, 2],
    mode='markers',
    marker=dict(color=v, size=5, colorscale="viridis"),
)

fig = go.Figure(data=[point_cloud])
scat = point_cloud

fig.update_layout(
    autosize=False,
    width=500,
    height=500,
    margin=dict(l=0, r=0, b=0, t=0),
    scene=dict(xaxis_title="NCit", yaxis_title='"pH"', zaxis_title="HA"),
)

fig.show()

Close and far points:

In [None]:
# Ground-truth observations at hand-picked (NCit, pH, HA) points. The +16.0 on pH
# matches the shift applied to X. Pairs: (obs_1, obs_2) and (obs_3, obs_4).
obs_1 = phases.truth_uv(np.array([[7.2, -11.4 + 16.0, 12.87]]))
obs_2 = phases.truth_uv(np.array([[7.5, -12.68 + 16.0, 12.34]]))

obs_3 = phases.truth_uv(np.array([[7.04, -13.12 + 16.0, 2.61]]))
obs_4 = phases.truth_uv(np.array([[6.95, -14.89 + 16.0, 4.55]]))

In [None]:
# obs_1 and obs_2 in red, obs_3 in blue, obs_4 in cyan.
fig, ax = plt.subplots(figsize=(3, 2))
for spectrum, fmt in ((obs_1, "r-"), (obs_2, "r-"), (obs_3, "b-"), (obs_4, "c-")):
    ax.plot(spectrum.squeeze(), fmt)
plt.show()

In [None]:
# Value-function scores of the original measured data.
value_truth_original_data = value.value_function(X, Y)

In [None]:
# Back to the 20 x 20 x 20 grid layout used to build `coordinates`.
value_truth_3d = value_truth.reshape(20, 20, 20)

In [None]:
# Grid layout plus a trailing axis of 200 values per point (presumably spectral
# channels, matching the 200-point xgrid used below).
obs_3d = obs.reshape(20, 20, 20, 200)

In [None]:
# Gradient of the observations along the three input axes only (axis 3, the trailing
# per-point axis, is not needed, so it is no longer computed and discarded),
# averaged over that trailing axis, then combined into an L2 magnitude.
obs_3d_gradient = np.sqrt((np.array(np.gradient(obs_3d, axis=(0, 1, 2))).mean(axis=-1)**2).sum(axis=0))

In [None]:
# Final sampled points, and the final model mean. The mean is assumed to lie on a
# 100 x 100 x 100 grid; TODO confirm against the experiment's prediction grid.
x_recorded = exp.recorded_X[-1]
yhat_recorded = exp.recorded_Yhat[-1]["mean"].reshape(100, 100, 100)
v = value.value_function(x_recorded, phases.truth_uv(x_recorded))

In [None]:
# Axis-label bounds in the original units (undo the +16 shift on the pH column).
x_recorded_mins, x_recorded_maxs = (
    bound - np.array([0, 16, 0]) for bound in (X.min(axis=0), X.max(axis=0))
)

In [None]:
# Rescale each input dimension of the recorded points to [0, 1] using the points'
# own min/max, so they overlay images drawn with extent [0, 1, 0, 1].
# NOTE(review): the tick labels use X's bounds, not x_recorded's; TODO confirm they coincide.
shift_min = x_recorded.min(axis=0)
shift_max = np.ptp(x_recorded, axis=0)
x_recorded = (x_recorded - shift_min) / shift_max

In [None]:
# Three projections onto the (NCit, pH) plane, each averaged over the third input:
# final model mean with the sampled points, ground-truth value, and observation gradient.
fig, axs = plt.subplots(1, 3, figsize=(6, 2), sharey=True)

dim0 = 0
dim1 = 1
dim2 = 2

ax = axs[0]
# 2-D scatter: the third positional argument of Axes.scatter is the marker *size*, so the
# original call passed normalized x3 values as sizes. Use an explicit size instead.
ax.scatter(x_recorded[:, dim0], x_recorded[:, dim1], s=0.5, c=v)
ax.imshow(yhat_recorded.mean(axis=-1).T, interpolation='bilinear', origin='lower', cmap=cm.binary, extent=[0, 1, 0, 1])

axs[1].imshow(value_truth_3d.mean(axis=-1).T, interpolation='bilinear', origin='lower', cmap=cm.binary, extent=[0, 1, 0, 1])
axs[2].imshow(obs_3d_gradient.mean(axis=-1).T, interpolation='bilinear', origin='lower', cmap=cm.binary, extent=[0, 1, 0, 1])

# All panels share the same [0, 1] extent; label the ends with the physical bounds.
x_tick_labels = [f"{x_recorded_mins[dim0]:.3g}", f"{x_recorded_maxs[dim0]:.3g}"]
y_tick_labels = [f"{x_recorded_mins[dim1]:.3g}", f"{x_recorded_maxs[dim1]:.3g}"]
for ax in axs:
    ax.set_yticks([0, 1.0])
    ax.set_yticklabels(y_tick_labels)
    ax.set_xticks([0, 1.0])
    ax.set_xticklabels(x_tick_labels)
    adj.set_grids(ax, grid=False)

axs[1].set_xlabel("[NCit]")
axs[0].set_ylabel("pH")

plt.show()


In [None]:
# Rows of X used as the two endpoints of the interpolation path, and the number of steps.
ii = 10
jj = 200
N = 50

In [None]:
# N evenly spaced values from X[ii] to X[jj], one linspace per input coordinate.
x0, x1, x2 = (np.linspace(X[ii, k], X[jj, k], N) for k in range(3))

In [None]:
# (N, 3) array of points along the path.
X_test = np.vstack((x0, x1, x2)).T

In [None]:
# Ground-truth observations at each point along the interpolation path.
Y_test = phases.truth_uv(X_test)

In [None]:
# One colour per interpolation step. `matplotlib.cm.get_cmap` was removed in
# Matplotlib 3.9; pyplot.get_cmap(name, lut) takes the same arguments.
cmap = plt.get_cmap("viridis", N)

In [None]:
# Horizontal axis for the 200 values per observation (presumably wavelength in nm; TODO confirm).
xgrid = np.linspace(450, 750, 200)

In [None]:
fig, ax = plt.subplots(figsize=(3, 2))

# Observations along the path from X[ii] to X[jj], coloured by step.
for step, spectrum in enumerate(Y_test):
    ax.plot(xgrid, spectrum, color=cmap(step))

# Dashed black: the measured observations at the two endpoints.
for endpoint in (ii, jj):
    ax.plot(xgrid, Y[endpoint, :], "k--")

plt.show()

In [None]:
# The variable is `xgrid`; `x_grid` was never defined and raised NameError.
xgrid