# Generating prior samples from a GP

In [None]:
import matplotlib.pyplot as plt

# Plot-style setup for the notebook. NOTE(review): per the original comment
# below, the sequence appears to be order-sensitive -- keep the three calls in
# this order unless the rcParams behaviour has been re-checked.
plt.plot()  # Required to reset the rcParams for some reason
# Apply matplotlib's "default" style first, then the local style sheet on top.
# The style sheet is referenced by a relative path ("./araa-gps.mplstyle").
plt.style.use(["default", "./araa-gps.mplstyle"])
# Close the current figure (the one implicitly opened by plt.plot() above).
plt.close()

In [None]:
import jax
import jax.numpy as jnp
import seaborn as sns
from tinygp import GaussianProcess, kernels

from paths import figures

In [None]:
# Shared input grid: 100 evenly spaced points covering [0, 20]. plot_samples
# evaluates the GP on this grid and uses its extent for the x-axis limits.
t = jnp.linspace(0, 20, num=100)


def plot_samples(ax, kernel):
    """Draw three GP prior samples for ``kernel`` onto the axes ``ax``.

    A ``GaussianProcess`` is built from ``kernel`` on the module-level grid
    ``t`` with ``diag=1e-5``, and three samples are drawn from it. The PRNG
    key is fixed (``PRNGKey(10)``), so every call samples with the same key.

    The samples are plotted as thin black lines over a horizontal reference
    line at zero. The x-range is set to the extent of ``t`` and the y-range
    to ``[-4, 4]``. Nothing is returned; ``ax`` is modified in place.
    """
    key = jax.random.PRNGKey(10)
    prior = GaussianProcess(kernel, t, diag=1e-5)
    draws = prior.sample(key, (3,))

    # Transposed so that matplotlib draws one curve per column.
    ax.plot(t, draws.T, "k", lw=0.75, alpha=0.8)
    ax.axhline(0, color="k", lw=1, alpha=0.5)
    ax.set(xlim=(t.min(), t.max()), ylim=(-4.0, 4.0))


def plot_grid(title, kernel):
    """Build a 2x2 figure of prior samples for one kernel family.

    ``kernel`` is a callable taking a single positional argument (presumably
    the kernel's length scale -- confirm against the kernel class) and
    returning a kernel object. The panels are laid out as:

    * top-left: ``kernel(1.0)``; top-right: ``kernel(2.0)``
    * bottom-left: ``2 * kernel(1.0)``; bottom-right: ``2 * kernel(2.0)``

    Each panel is filled by ``plot_samples``. The axes share x and y, the
    left column gets a y-label, the bottom row an x-label, and ``title`` is
    used as the figure's suptitle.

    Returns the created figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(4, 4), sharex=True, sharey=True)

    # Listed in row-major order to line up with ``axes.flat``.
    panel_kernels = (
        kernel(1.0),
        kernel(2.0),
        2 * kernel(1.0),
        2 * kernel(2.0),
    )
    for panel, panel_kernel in zip(axes.flat, panel_kernels):
        plot_samples(panel, panel_kernel)

    fig.subplots_adjust(hspace=0.1, wspace=0.15)
    sns.despine()

    for left_ax in axes[:, 0]:
        left_ax.set_ylabel("f(t)")
    for bottom_ax in axes[-1, :]:
        bottom_ax.set_xlabel("t")

    fig.suptitle(title)
    return fig

In [None]:
# Render one 2x2 sample grid per kernel family and save each as a PDF under
# the `figures` path, in the order listed.
for grid_title, kernel_cls, filename in (
    ("Exponential", kernels.Exp, "samples1.pdf"),
    ("Matérn-3/2", kernels.Matern32, "samples2.pdf"),
    ("Exponential Squared", kernels.ExpSquared, "samples3.pdf"),
):
    plot_grid(grid_title, kernel_cls).savefig(
        figures / filename, bbox_inches="tight"
    )

In [None]:
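# Disabled cell, kept for reference. NOTE(review): the title below says
# $\Gamma=1/2$ while the kernel is built with gamma=1.0 -- confirm which
# value is intended before re-enabling.
# plot_grid(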
#     r"Exponential Sine Squared ($\Gamma=1/2$)",
#     lambda scale: 0.8
#     * kernels.ExpSquared(2.0)
#     * kernels.ExpSineSquared(2 * scale, gamma=1.0)
# ).savefig(figures / "samples4.pdf", bbox_inches="tight")