In [None]:
import mogpjax as mgpx 
import gpjax as gpx
import matplotlib.pyplot as plt 
import jax.numpy as jnp
import jax 
import jax.random as jr 
import jax.scipy as jsp

SEED = 123  # global seed; every random draw below is derived from this key
key = jr.PRNGKey(SEED)

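We simulate three correlated outputs on $n$ evenly spaced inputs $x \in [0, 1]$. Each output is a noise-free latent function $f_t$ plus independent Gaussian noise $\varepsilon_t \sim \mathcal{N}(0, \sigma^2)$ with $\sigma = 0.05$ (`noise`). $f_2$ and $f_3$ are built from the noise-free $f_1$ rather than the noisy $y_1$, and this shared dependence is what correlates the outputs:

$$\begin{align}
f_1(x) & = -\frac{\sin(10\pi(x+1))}{2x+1} - x^4, & y_1(x) & = f_1(x) + \varepsilon_1 \\
f_2(x) & = \cos^2(f_1(x)) + \sin(3x), & y_2(x) & = f_2(x) + \varepsilon_2 \\
f_3(x) & = f_2(x)\,f_1^2(x) + 3x, & y_3(x) & = f_3(x) + \varepsilon_3\,.
\end{align}$$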

In [None]:
n = 20
noise = 0.05  # standard deviation of the additive Gaussian noise on every output


def f1(x):
    """Noise-free latent function for output 1."""
    return -jnp.sin(10 * jnp.pi * (x + 1)) / (2 * x + 1) - jnp.power(x, 4)


def f2(x):
    """Noise-free latent function for output 2 (built from f1)."""
    return jnp.square(jnp.cos(f1(x))) + jnp.sin(3 * x)


def f3(x):
    """Noise-free latent function for output 3 (built from f1 and f2)."""
    return jnp.square(f1(x)) * f2(x) + 3 * x


x = jnp.linspace(0, 1, n)

# Split once into fresh subkeys so that no PRNG key is used for more than one
# purpose; `key` is advanced and stays available for later random calls.
key, key_1, key_2, key_3 = jr.split(key, 4)
y1 = f1(x) + noise * jr.normal(key_1, shape=(n,))
y2 = f2(x) + noise * jr.normal(key_2, shape=(n,))
y3 = f3(x) + noise * jr.normal(key_3, shape=(n,))

# Plot the three noisy outputs against the shared inputs.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for label, y_obs in ((r'$y_1$', y1), (r'$y_2$', y2), (r'$y_3$', y3)):
    ax.plot(x, y_obs, marker='o', label=label)
ax.set(xlabel=r'$x$', ylabel=r'$y$', title='Simulated multi-output data')
ax.legend(loc='best')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

In [None]:
# Inputs as an (n, 1) column; outputs as an (n, n_tasks) matrix, one column per task.
X = x[:, None]
Y = jnp.column_stack((y1, y2, y3))

n_data, n_tasks = Y.shape

In [None]:
from typing import Dict, List
from chex import dataclass
from gpjax.utils import concat_dictionaries
from gpjax.parameters import ParameterState
from jaxtyping import f64


# Kernel over task indices. One active dimension per task, derived from n_tasks
# instead of being hard-coded.
# NOTE(review): these active dims also serve as the default task index list in
# MultiOutputKernel, while the task inputs it receives have shape (n_tasks, 1);
# confirm this slicing behaves as intended with the installed gpjax/jax versions.
task_kernel = gpx.kernels.Matern32(active_dims=list(range(n_tasks)))
# Kernel over the 1-D inputs X.
data_kernel = gpx.kernels.Matern32()


@dataclass
class MultiOutputKernel:
    """Product kernel over (input, task) pairs.

    Evaluated on two inputs ``x`` and ``y``, it returns the whole block of task
    covariances ``data_kernel(x, y) * B``, where ``B`` is the Gram matrix of
    ``task_kernel`` over ``task_idxs``.

    Attributes:
        task_kernel: kernel used to build the task covariance ``B``.
        data_kernel: kernel over the inputs.
        task_idxs: task indices fed to ``task_kernel``; when left as ``None`` it
            falls back to ``task_kernel.active_dims``.
    """

    task_kernel: gpx.kernels.Kernel
    data_kernel: gpx.kernels.Kernel
    task_idxs: List[int] = None

    def __post_init__(self):
        # Default the task indices to the task kernel's active dimensions.
        if self.task_idxs is None:
            self.task_idxs = self.task_kernel.active_dims

    def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
        """Initialise the parameters of both sub-kernels.

        Args:
            key: PRNG key; it is split so each sub-kernel gets its own key.

        Returns:
            ``{"task_params": ..., "data_params": ...}``.
        """
        task_key, data_key = jr.split(key)
        return {
            "task_params": self.task_kernel._initialise_params(task_key),
            "data_params": self.data_kernel._initialise_params(data_key),
        }

    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["T T"]:
        """Covariance block between inputs ``x`` and ``y`` for every task pair.

        Args:
            x: first input.
            y: second input.
            params: dict with ``"data_params"`` and ``"task_params"`` entries.

        Returns:
            ``data_kernel(x, y) * B``, where ``B`` is the task Gram matrix with
            one row/column per entry of ``task_idxs``.
        """
        # The input correlation must depend on both arguments; evaluating
        # data_kernel(x, x) would ignore y entirely.
        k_xy = self.data_kernel(x, y, params["data_params"])
        task_inputs = jnp.array(self.task_idxs).reshape(-1, 1)
        task_cov = gpx.kernels.gram(self.task_kernel, task_inputs, params["task_params"])
        return k_xy * task_cov


mo_kernel = MultiOutputKernel(task_kernel=task_kernel, data_kernel=data_kernel)

# One Gaussian likelihood term per (input, task) observation.
likelihood = gpx.Gaussian(num_datapoints=n_data * n_tasks)
p = gpx.Prior(kernel=mo_kernel) * likelihood
parameter_state = gpx.initialise(p, key)
parameter_state.params

In [None]:
# Evaluate the kernel on one input pair. It reads "data_params"/"task_params",
# which live under the "kernel" entry of the initialised parameters.
mo_kernel(X[:1, :], X[:1, :], parameter_state.params["kernel"])

In [None]:
# Full multi-output Gram matrix.
# Assumes gram(...) stacks one (n_tasks, n_tasks) block per input pair, giving shape
# (n_data, n_data, n_tasks, n_tasks); verify against the installed gpjax version.
# The task axes are moved next to their input axes before flattening, so that
# row/column i * n_tasks + t lines up with Y.reshape(-1, 1) (input-major, task-minor).
Kxx = (
    gpx.kernels.gram(mo_kernel, X, parameter_state.params["kernel"])
    .transpose(0, 2, 1, 3)
    .reshape(n_data * n_tasks, n_data * n_tasks)
)

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.matshow(Kxx)
fig.colorbar(im, ax=ax)
ax.set_title("Multi-output prior covariance (input-major, task-minor)");

In [None]:


# Observation noise is independent across inputs and tasks, so it belongs on the
# diagonal only. `noise` is the standard deviation used to simulate the data, so
# the variance is its square.
noise_matrix = jnp.eye(n_data * n_tasks) * noise**2

fig, ax = plt.subplots()
ax.matshow(noise_matrix)
ax.set_title("Observation noise covariance");

In [None]:
# Kernel Gram matrix plus the observation-noise matrix.
Kxx_noise = jnp.add(Kxx, noise_matrix)

In [None]:
# One row per (input, task) observation, in the same input-major order as
# Y.reshape(-1, 1): row i * n_tasks + t pairs input X[i] with Y[i, t].
# NOTE(review): the task index itself is not stored in X; add it as a column if a
# downstream model needs to tell the tasks apart.
D = gpx.Dataset(X=jnp.repeat(X, n_tasks, axis=0), y=Y.reshape(-1, 1))

In [None]:
# Display the initialised parameter dictionary; mll reads its "kernel" entry.
parameter_state.params

In [None]:
import distrax as dx

def mll(
    params: dict,
):
    """Negative log marginal likelihood of the flattened multi-output data.

    Builds the (n_data * n_tasks) x (n_data * n_tasks) covariance from
    ``mo_kernel`` and adds jitter plus the fixed observation-noise variance on the
    diagonal. It then evaluates ``-log N(Y.reshape(-1) | 0, Sigma)``.

    Args:
        params: parameter dict as returned by ``gpx.initialise(...).params``;
            only ``params["kernel"]`` is read.

    Returns:
        Scalar negative log marginal likelihood.
    """
    n_total = n_data * n_tasks

    # Assumes gram(...) returns one (n_tasks, n_tasks) block per input pair, i.e.
    # shape (n_data, n_data, n_tasks, n_tasks); verify for the installed gpjax.
    # Reorder to (n_data, n_tasks, n_data, n_tasks) before flattening so index
    # i * n_tasks + t matches Y.reshape(-1) (input-major, task-minor).
    blocks = gpx.kernels.gram(mo_kernel, X, params["kernel"])
    Kxx = blocks.transpose(0, 2, 1, 3).reshape(n_total, n_total)

    # Σ = Kxx + (σ² + jitter) I = LLᵀ, where σ = `noise` is the simulation std.
    # Both terms go on the diagonal only: the noise is independent per observation.
    jitter = 1e-6
    Sigma = Kxx + (noise**2 + jitter) * jnp.eye(n_total)
    L = jnp.linalg.cholesky(Sigma)

    # p(y | x, θ), where θ are the model hyperparameters.
    marginal_likelihood = dx.MultivariateNormalTri(jnp.zeros(n_total), L)

    # Negate so the objective can be minimised.
    return -marginal_likelihood.log_prob(Y.reshape(-1))


In [None]:
# Negative log marginal likelihood at the initial parameters.
mll(parameter_state.params)

In [None]:
parameter_state.