In [1]:
import random
import time

In [2]:
import jax
import jax.numpy as jnp

In [3]:
import matplotlib
import matplotlib.pyplot as plt

In [4]:
from basics import definitions as defs
from gp_utils import gp

In [5]:
# Compact 7 pt serif typography, applied globally through matplotlib's rc
# settings so the small, high-dpi figures produced below stay readable.
font = dict(family='serif', weight='normal', size=7)
axes = dict(titlesize=7, labelsize=7)
matplotlib.rc('font', **font)
matplotlib.rc('axes', **axes)

# Local aliases for the data types defined in `basics.definitions`.
GPParams, SubDataset = defs.GPParams, defs.SubDataset

In [6]:
def plot_function_samples(
    mean_func,
    cov_func,
    params,
    warp_func=None,
    num_samples=1,
    random_seed=0,
    x_min=0,
    x_max=1,
    num_points=100,
):
  """Plot function samples from a 1-D Gaussian process.

  Draws `num_samples` sample functions from the GP described by `mean_func`,
  `cov_func` and `params` on an evenly spaced grid over [x_min, x_max], and
  draws them as lines on a new small (2x1 inch, 200 dpi) figure.

  Args:
    mean_func: Mean function, forwarded to `gp.sample_from_gp`.
    cov_func: Covariance function, forwarded to `gp.sample_from_gp`.
    params: GP parameters (e.g. a `GPParams` instance), forwarded to
      `gp.sample_from_gp`.
    warp_func: Optional warping function, forwarded to `gp.sample_from_gp`.
    num_samples: Number of sample functions to draw.
    random_seed: Integer seed used to build the `jax.random.PRNGKey` for
      sampling.
    x_min: Left end of the plotted interval.
    x_max: Right end of the plotted interval.
    num_points: Number of evenly spaced grid points in [x_min, x_max].
      Defaults to 100, the previously hard-coded resolution.
  """
  key = jax.random.PRNGKey(random_seed)
  # Use the second key of a 2-way split (as before) so the samples drawn for
  # a given seed are unchanged.
  _, y_key = jax.random.split(key, 2)
  # Inputs as a column: shape (num_points, 1), one row per 1-D input point.
  x = jnp.linspace(x_min, x_max, num_points)[:, None]
  # NOTE(review): 'svd' is presumably chosen for robustness when the
  # covariance is near-singular (e.g. tiny noise_variance) — confirm in gp_utils.
  y = gp.sample_from_gp(
      y_key,
      mean_func,
      cov_func,
      params,
      x,
      warp_func=warp_func,
      num_samples=num_samples,
      method='svd',
  )
  _, ax = plt.subplots(dpi=200, figsize=(2, 1))
  ax.plot(x, y)
  ax.set_xlabel('x')
  ax.set_ylabel('f(x)')

In [7]:
# @title Define a ground truth GP and generate training data
# Parameters of the ground-truth GP from which the training data is generated.
params = GPParams(
    model=dict(
        lengthscale=0.1,
        signal_variance=10.0,
        # Tiny relative to signal_variance=10.0.
        noise_variance=1e-6,
        # NOTE(review): presumably the constant mean value — confirm in gp_utils.
        constant=5.0,
    )
)