In [1]:
import os
from pathlib import Path
import sys

from ax.service.ax_client import AxClient
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import seaborn as sns

# Paths default to the original author's machine; set the BOGP_SOURCE / BOGP_ROOT
# environment variables to run this notebook elsewhere without editing the code.
BOGP_SOURCE = Path(
    os.environ.get("BOGP_SOURCE", "/Users/williamjenkins/Research/Projects/BOGP/Source/BOGP")
)
# Make the project-local `aggregate` module importable.
sys.path.insert(0, str(BOGP_SOURCE))
from aggregate import Aggregator

# Experiment directory containing the saved arrays and results.json read in the next cell.
ROOT = Path(
    os.environ.get(
        "BOGP_ROOT",
        "/Users/williamjenkins/Research/Projects/BOGP/Data/range_estimation/simulation/serial_example/rec_r=3.0__src_z=62.0__snr=20/sequential_qei/seed_0292288111",
    )
)
# Output directory for the per-iteration figure frames.
figpath = ROOT / "figures"

In [2]:
# Load the saved run. X/y are the evaluation grid and the "True" curve plotted on it;
# y_test, cov_test and alpha_test are indexed per iteration in the plotting loop
# (mean, band width and acquisition function alpha(x), each evaluated on X).
X = np.load(ROOT / "X.npy").squeeze()
y = np.load(ROOT / "y.npy")
y_test = np.load(ROOT / "y_test.npy")
cov_test = np.load(ROOT / "cov_test.npy")
alpha_test = np.load(ROOT / "alpha_test.npy")
# Sample table; the plotting loop reads its `rec_r` and `bartlett` columns.
df, _ = Aggregator.extract_results(ROOT / "results.json")

print(
    f"X: {X.shape}\n"
    f"y: {y.shape}\n"
    f"y_test: {y_test.shape}\n"
    f"cov_test: {cov_test.shape}\n"
    f"alpha_test: {alpha_test.shape}"
)

# Fail fast here, with a readable message, if the arrays do not line up the way the
# plotting loop assumes (otherwise the error surfaces deep inside matplotlib).
assert X.ndim == 1, f"expected a 1-D grid after squeeze, got X.shape={X.shape}"
assert len(y) == len(X), f"y ({len(y)}) must be evaluated on the X grid ({len(X)})"
assert y_test.shape == cov_test.shape == alpha_test.shape, (
    "y_test, cov_test and alpha_test must have identical (iteration, grid) shapes"
)
assert y_test.shape[-1] == len(X), (
    f"per-iteration arrays have {y_test.shape[-1]} grid points, X has {len(X)}"
)

X: (1000,)
y: (1000,)
y_test: (90, 1000)
cov_test: (90, 1000)
alpha_test: (90, 1000)


In [6]:
# Number of initial rows in `df` drawn before the first per-iteration snapshot.
N_WARMUP = 10

# One snapshot per iteration is stored in y_test/cov_test/alpha_test; derive the count
# from the data rather than hard-coding it, so runs with another budget also work.
n_iterations = len(y_test)

# For i > 0 the loop reads row N_WARMUP + i - 1 of `df`; check up front so a short
# results.json gives a clear message instead of an IndexError partway through.
if n_iterations > 1 and len(df) < N_WARMUP + n_iterations - 1:
    raise ValueError(
        f"results.json has {len(df)} rows; need at least "
        f"{N_WARMUP + n_iterations - 1} for {n_iterations} iterations."
    )

# Create the output directory once, before rendering any frames.
figpath.mkdir(parents=True, exist_ok=True)

for i in range(n_iterations):
    # Rows [0, i_stop) of `df` are drawn as black crosses. For i > 0, row i_stop is
    # highlighted separately as a red star, so it is excluded from the crosses.
    i_stop = N_WARMUP + i - 1 if i > 0 else N_WARMUP
    samples = df.iloc[:i_stop]

    mean = y_test[i]
    # NOTE(review): band is mean ± 2 * cov_test[i]. If cov_test stores variances rather
    # than standard deviations this should be 2 * np.sqrt(cov_test[i]) — confirm.
    ucb = mean + 2 * cov_test[i]
    lcb = mean - 2 * cov_test[i]
    alpha = alpha_test[i]
    # Acquisition function of the previous iteration; its argmax is drawn in red.
    alpha_prev = alpha_test[i - 1] if i > 0 else None

    fig, axs = plt.subplots(figsize=(12, 6), nrows=2, height_ratios=[0.7, 0.3], facecolor="w")

    # Top panel: true curve, mean with band, and samples so far.
    ax = axs[0]
    ax.plot(X, y, c="g", label="True")
    ax.plot(X, mean, label="Mean")
    ax.fill_between(X, lcb, ucb, alpha=0.5, label="Cov")
    # Black dotted line: argmax of the current acquisition function.
    ax.axvline(X[np.argmax(alpha)], color="k", linestyle=":")
    ax.scatter(samples["rec_r"], samples["bartlett"], marker="x", c="k", label="Samples", alpha=0.5, zorder=10)
    if alpha_prev is not None:
        ax.scatter(df["rec_r"].iloc[i_stop], df["bartlett"].iloc[i_stop], marker="*", color="r", linewidth=4, zorder=20)
        ax.axvline(X[np.argmax(alpha_prev)], color="r", linestyle=":")

    # Rug of sample locations drawn just below the x-axis (negative height, no clipping).
    sns.rugplot(data=samples, x="rec_r", ax=ax, height=-.05, clip_on=False, color="k", alpha=0.35)
    ax.set_xlim(-0.1, 10.1)
    ax.set_xticks([])
    ax.set_xlabel(None)
    ax.set_ylim(-0.1, 1.1)
    ax.set_ylabel("$f(x)$", rotation=0, ha="right")
    ax.text(0, 1, f"Iteration {i}")
    ax.legend()

    # Bottom panel: acquisition function for this iteration (and previous argmax).
    ax = axs[1]
    ax.plot(X, alpha)
    ax.axvline(X[np.argmax(alpha)], color="k", linestyle=":")
    if alpha_prev is not None:
        ax.axvline(X[np.argmax(alpha_prev)], color="r", linestyle=":")
    ax.set_xlim(-0.1, 10.1)
    ax.set_xticklabels([])
    ax.set_ylabel("$\\alpha(x)$", rotation=0, ha="right")

    fig.savefig(figpath / f"{i:03d}.png", bbox_inches="tight")
    # Close this specific figure so memory does not accumulate across iterations.
    plt.close(fig)
