In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import norm
import GPy
from matplotlib.backends.backend_pdf import PdfPages

%matplotlib inline
sns.set_style("white")
import matplotlib

matplotlib.rcParams["text.usetex"] = True
plt.rcParams["font.size"] = 15


def f(x):
    """Synthetic 1-D objective evaluated element-wise on ``x``.

    The value is the sum of a linear trend ``x``, a sine wave
    ``5 * sin(2x)``, and three Gaussian bumps: a broad one of height 8
    centred at 3, a narrow one of height 8 centred at 3.8 that is
    *subtracted*, and a small one of height 1 centred at 0.8.

    Parameters
    ----------
    x : scalar or numpy array
        Input location(s).

    Returns
    -------
    Same shape as ``x``; the function value at each location.
    """

    def bump(height, centre, width):
        # Unnormalised Gaussian: height * exp(-(x - centre)^2 / (2 * width^2)).
        return height * np.exp(-0.5 * np.square(x - centre) / np.square(width))

    broad_peak = bump(8, 3, 1)
    wave = 5 * np.sin(2 * x)
    narrow_dip = bump(8, 3.8, 0.25)
    small_peak = bump(1, 0.8, 0.4)

    # Terms are combined in the same left-to-right order as the formula above.
    return x + broad_peak + wave - narrow_dip + small_peak


def f2(x):
    """Negated quartic polynomial, evaluated element-wise on ``x``.

    Returns ``-(3 - 40x + 38x^2 - 11x^3 + x^4)``.

    Parameters
    ----------
    x : scalar or numpy array
        Input location(s).

    Returns
    -------
    The negated polynomial value at each location.
    """
    linear = 40 * x
    quadratic = 38 * np.power(x, 2)
    cubic = 11 * np.power(x, 3)
    quartic = np.power(x, 4)
    return -(3 - linear + quadratic - cubic + quartic)


def straddle(mean, std, th):
    """Straddle score ``1.96 * std - |mean - th|``.

    The score is positive exactly where the interval
    ``mean +/- 1.96 * std`` contains the threshold ``th``, and grows with
    the predictive spread.

    Parameters
    ----------
    mean : scalar or numpy array
        Predictive mean(s).
    std : scalar or numpy array
        Predictive standard deviation(s), broadcastable against ``mean``.
    th : scalar
        Threshold level.

    Returns
    -------
    The score for each entry.
    """
    half_width = 1.96 * std
    distance_to_threshold = np.abs(mean - th)
    return half_width - distance_to_threshold

In [None]:
def _contiguous_ranges(mask):
    """Return ``(start, end)`` grid-index pairs for each run of True in ``mask``.

    ``start`` is the first index of the run. ``end`` is the index of the
    first False entry after the run, or ``len(mask) - 1`` when the run
    extends to the end of ``mask``.
    """
    ranges = []
    start = None
    for idx, inside in enumerate(mask):
        if inside and start is None:
            start = idx
        elif not inside and start is not None:
            ranges.append((start, idx))
            start = None
    if start is not None:
        ranges.append((start, len(mask) - 1))
    return ranges


np.random.seed(0)
n_split = 200
x = np.linspace(-0.5, 5.5, n_split).reshape(-1, 1)
y = f(x)
x_flat = x.flatten()
y_flat = y.flatten()

# Observations gathered so far: locations, values and their grid indices.
obs_x = []
obs_y = []
obs_i = []

# Grid indices of the initial observations.
q = [20, 190]
# q = [10, 85]
# Threshold level defining the level set.
th = 7

for item in q:
    obs_x.append(x[item])
    obs_y.append(y[item])
    obs_i.append(item)

# Constant offset subtracted from the targets before fitting and added back
# to the prediction. Loop-invariant, so computed once.
m = np.mean(y)

# Stride used to subsample the grid when scoring candidates.
margin = 1

# The context manager closes the PDF even if an iteration raises.
with PdfPages("lse.pdf") as pdf:
    for i in range(15):
        # Refit on all observations with fixed hyperparameters
        # (no gp.optimize() call).
        gp = GPy.models.GPRegression(np.array(obs_x), np.array(obs_y) - m)
        gp.Gaussian_noise.constrain_fixed(1e-2)
        gp.rbf.variance.constrain_fixed(16)
        gp.rbf.lengthscale.constrain_fixed(0.5)

        mean, var = gp.predict(x)
        mean = mean.flatten() + m
        var = var.flatten()
        std = np.sqrt(var)
        ci = 1.96 * std

        # Grid regions whose whole interval lies above / below the threshold.
        upper_range = _contiguous_ranges((mean - ci) > th)
        lower_range = _contiguous_ranges((mean + ci) < th)

        fig, ax = plt.subplots()
        ax.plot(x_flat, mean, label=r"$\mu(x)$", zorder=2, lw=2)
        ax.fill_between(
            x_flat,
            mean - ci,
            mean + ci,
            label=r"$\mu(x)\pm 1.96\sigma(x)$",
            alpha=0.3,
            zorder=2,
        )
        ax.axhline(
            th, label=r"$\theta$", ls="--", c="tab:red", alpha=1, lw=2, zorder=1
        )
        ax.plot(x_flat, y_flat, c="black", ls="--", label=r"$f(x)$", zorder=1, lw=2.5)

        # All observations except the most recent one are drawn in white.
        ax.scatter(
            np.array(obs_x)[:-1],
            np.array(obs_y)[:-1],
            marker="s",
            s=45,
            color="white",
            edgecolor="black",
            lw=1.5,
            zorder=3,
        )

        # The most recent observation is highlighted, except on the first
        # iteration where only the initial points exist.
        color = "white" if i == 0 else "gold"
        ax.scatter(
            np.array(obs_x)[-1],
            np.array(obs_y)[-1],
            marker="s",
            s=45,
            color=color,
            edgecolor="black",
            lw=1.5,
            zorder=3,
        )

        for r in upper_range:
            ax.axvspan(x_flat[r[0]], x_flat[r[1]], alpha=0.2, color="tab:red")

        for r in lower_range:
            ax.axvspan(x_flat[r[0]], x_flat[r[1]], alpha=0.2, color="tab:green")

        ax.set_xlabel("$x$", fontsize=18)
        ax.set_ylabel("$f(x)$")
        ax.set_xlim(-0.7, 5.7)
        ax.set_ylim(-6, 14)
        ax.legend(
            borderaxespad=0, ncol=2, framealpha=0.7, fontsize=13, loc="lower right"
        )
        ax.set_title("iteration {}".format(i + 1))

        # straddle() takes a standard deviation, so pass sqrt(var) rather
        # than the raw variance.
        ac = straddle(mean[::margin], std[::margin], th)
        # Exclude already-observed points with -inf: scores can be negative,
        # so masking with 0 could let an observed point win the argmax.
        ac[np.array(obs_i, dtype=int) // margin] = -np.inf
        next_i = np.argmax(ac) * margin

        fig.tight_layout()
        pdf.savefig(fig)

        obs_x.append(x[next_i])
        obs_y.append(y[next_i])
        obs_i.append(next_i)

print(gp)