In [47]:
import jax
import jax.numpy as jnp
from gpjax.kernels.base import AbstractKernel
import numpy as np
def dummy_npwarn_decorator_factory():
  def npwarn_decorator(x):
    return x
  return npwarn_decorator
np._no_nep50_warning = getattr(np, '_no_nep50_warning', dummy_npwarn_decorator_factory)

# Enable Float64 for more stable matrix inversions.
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt

config.update("jax_enable_x64", True)

with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx


key = jr.key(123)

cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
def rbf(x, y, amplitude, lengthscale):
    """Plain RBF kernel between scalars x and y."""
    r = (x - y)
    return amplitude**2 * jnp.exp(-0.5 * (r / lengthscale)**2)

def drbf_dx(x, y, amplitude, lengthscale):
    """Derivative of RBF w.r.t x."""
    return (-(x - y) / lengthscale**2) * rbf(x, y, amplitude, lengthscale)

def drbf_dy(x, y, amplitude, lengthscale):
    """Derivative of RBF w.r.t y (which is just -drbf_dx w.r.t x)."""
    return ((x - y) / lengthscale**2) * rbf(x, y, amplitude, lengthscale)

def d2rbf_dxdy(x, y, amplitude, lengthscale):
    """Second partial derivative: d^2/dx dy of the RBF."""
    sq_term = ((x - y)**2 / lengthscale**4) - (1.0 / lengthscale**2)
    return amplitude**2 * sq_term * jnp.exp(-0.5 * ((x - y)/lengthscale)**2)

class DerivativeKernel(AbstractKernel):
    """
    A multi-output kernel that returns
      k00, k01, k10, k11
    for function + derivative.
    """

    def __init__(self, amplitude=1.0, lengthscale=1.0):
        super().__init__()
        self.amplitude = amplitude
        self.lengthscale = lengthscale
        
    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """
        x, y shape: [n, 2], 
          where x[:, 0] is the location, x[:, 1] in {0,1} is the output-index (f or f').
        produce an [n, m] kernel matrix for all pairs.
        """
        # x.shape = (n, 2), y.shape = (m, 2)
        # We'll compute a (n, m) kernel matrix for each pair in x, y.
        
        # Let’s separate out the spatial locations and derivative indicators:
        x_loc = x[:, 0]
        x_out = x[:, 1].astype(int)  # 0 or 1
        y_loc = y[:, 0]
        y_out = y[:, 1].astype(int)
        
        # We'll want to build a big (n, m) matrix:
        def k_ij(i, j):
            xi = x_loc[i]
            yi = y_loc[j]
            di = x_out[i]
            dj = y_out[j]
            
            if di == 0 and dj == 0:
                # k00
                return rbf(xi, yi, self.amplitude, self.lengthscale)
            elif di == 0 and dj == 1:
                # k01 = partial wrt y
                return drbf_dy(xi, yi, self.amplitude, self.lengthscale)
            elif di == 1 and dj == 0:
                # k10 = partial wrt x
                return drbf_dx(xi, yi, self.amplitude, self.lengthscale)
            else:
                # k11 = partial^2 wrt x and y
                return d2rbf_dxdy(xi, yi, self.amplitude, self.lengthscale)

        # We can vmap over rows and columns to build up the matrix:
        n = x.shape[0]
        m = y.shape[0]
        # Build function that handles single pair (i, j):
        k_map = jax.vmap(
            lambda i: jax.vmap(lambda j: k_ij(i, j))(jnp.arange(m))
        )(jnp.arange(n))
        
        return k_map

In [151]:
class RegularIrrigationSim:
    def __init__(self, theta0, starts, lam, sigma_theta):
        self.soil_moist = theta0 * np.dot(np.ones_like(starts), 0 >= starts)
        self.theta0 = theta0
        self.lam = lam
        self.sigma_theta = sigma_theta
        self.time = 0.0
        self.starts = starts
        self.data = {
            'latent' : {
                't' : [0.0],
                'y' : [self.soil_moist]
            },
            'observed' : {
                't' : [],
                'y' : []
            }
        }
    
    def update(self, dt):
        self.soil_moist = self.theta0 * np.dot(np.exp(-self.lam * (self.time - self.starts)), self.time >= self.starts)
        self.time += dt
        self.data['latent']['t'].append(self.time)
        self.data['latent']['y'].append(self.soil_moist)

    def observation(self):
        self.data['observed']['t'].append(self.time)
        self.data['observed']['y'].append(self.soil_moist * np.exp(np.random.normal(0, self.sigma_theta)))


theta0 = 50
starts = np.array([0, 100, 200])
lam = 0.05
dt = 0.1
obs_interval = 3 
t = 60 * 24
sigma_theta = 0.3
min_threshold = sigma_theta
sim = RegularIrrigationSim(theta0, starts, lam, sigma_theta)
obs_timer = 0
while (sim.time < t) and (sim.soil_moist > min_threshold):
    sim.update(dt)
    obs_timer += dt
    if obs_timer >= obs_interval:
        sim.observation()
        obs_timer = 0
sim.data['observed']['t'] = sim.data['observed']['t'][:-1]
sim.data['observed']['y'] = sim.data['observed']['y'][:-1]

Xobs = jnp.array(sim.data['observed']['t']).astype(jnp.float64)
X = jnp.arange(0, max(Xobs), 1) 


def add_dims(input_array, n_dims): 
    arrays = [] 
    for col in range(n_dims):
        array = jnp.column_stack([input_array, jnp.array([col] * len(input_array))]) 
        arrays.append(array) 
    return jnp.concat(arrays, axis=0).astype(jnp.float64)

Y = jnp.array(sim.data['observed']['y']).reshape(-1, 1)
Y = jnp.concat([Y, Y], axis=0)
Xobs = add_dims(Xobs, 2)

D = gpx.Dataset(Xobs, Y)


def initialise_gp(kernel, mean, dataset):
    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
    likelihood = gpx.likelihoods.Gaussian(
        num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
    )
    posterior = prior * likelihood
    return posterior

mean = gpx.mean_functions.Zero()
kernel = DerivativeKernel()

posterior = initialise_gp(kernel, mean, D) 

def optimise_mll(posterior, dataset, NIters=1000, key=key):
    # define the MLL using dataset_train
    objective = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
    # Optimise to minimise the MLL
    opt_posterior, history = gpx.fit_scipy(
        model=posterior,
        objective=objective,
        train_data=dataset,
    )
    return opt_posterior

opt_post = optimise_mll(posterior, D)

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.