### Discontinuity analysis and spectral mixture kernels


In [None]:
!pip install git+ssh://git@github.com/UncertaintyInComplexSystems/bayesianmodels.git
!pip install numpy==1.23.5

In [7]:
import matplotlib.pyplot as plt

import jax
import jax.random as jrnd
import jax.numpy as jnp
import distrax as dx
import jaxkern as jk

# NOTE(review): importing `config` from `jax.config` is version-dependent; newer
# JAX releases are expected to use `jax.config.update(...)` directly — confirm
# against the JAX version installed for this notebook.
from jax.config import config
# Turn on the "jax_enable_x64" flag before any arrays are created.
config.update("jax_enable_x64", True)  # crucial for Gaussian processes
# Use the first device reported by jax.devices() as the default device.
config.update("jax_default_device", jax.devices()[0])

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

from uicsmodels.gpmodels.fullgpmodels import FullLatentGPModel, FullMarginalGPModel
from uicsmodels.gpmodels.kernels import Discontinuous

In [None]:
def disc_plot(ax, x, y, x0, *args, **kwargs):
    """Plot a curve as two separate segments split at ``x0``.

    Points with ``x < x0`` and points with ``x >= x0`` are drawn by two
    independent ``ax.plot`` calls, so no line is drawn across the split.

    Args:
        ax: Object exposing a matplotlib-style ``plot`` method.
        x: Input locations, 1-D or 2-D; only column 0 is used.
        y: Values at ``x``, 1-D or 2-D; only column 0 is used.
        x0: Location of the split; ``x0`` itself belongs to the second segment.
        *args: Positional arguments forwarded to both ``ax.plot`` calls.
        **kwargs: Keyword arguments forwarded to both ``ax.plot`` calls. If a
            ``label`` is given it is attached to the first segment only, so the
            curve appears once in a legend rather than twice.
    """
    # Promote 1-D inputs to column vectors so the indexing below is uniform.
    if jnp.ndim(y) == 1:
        y = y[:, jnp.newaxis]
    if jnp.ndim(x) == 1:
        x = x[:, jnp.newaxis]
    ix_pre = x[:, 0] < x0
    ix_post = x[:, 0] >= x0
    ax.plot(x[ix_pre, 0], y[ix_pre, 0], *args, **kwargs)

    # Matplotlib legends skip artists whose label starts with an underscore;
    # use that for the second segment to avoid a duplicated legend entry.
    post_kwargs = dict(kwargs)
    if 'label' in post_kwargs:
        post_kwargs['label'] = '_nolegend_'
    ax.plot(x[ix_post, 0], y[ix_post, 0], *args, **post_kwargs)

#
def disc_fill_between(ax, x, y1, y2, x0, *args, **kwargs):
    """Shade the band between ``y1`` and ``y2`` as two regions split at ``x0``.

    Points with ``x < x0`` and points with ``x >= x0`` are filled by two
    independent ``ax.fill_between`` calls, so the band is not joined across
    the split.

    Args:
        ax: Object exposing a matplotlib-style ``fill_between`` method.
        x: Input locations, 1-D or 2-D; only column 0 is used.
        y1: First boundary of the band, 1-D or 2-D; only column 0 is used.
        y2: Second boundary of the band, 1-D or 2-D; only column 0 is used.
        x0: Location of the split; ``x0`` itself belongs to the second region.
        *args: Accepted for signature compatibility but not forwarded.
        **kwargs: Keyword arguments forwarded to both ``fill_between`` calls.
            If a ``label`` is given it is attached to the first region only,
            so the band appears once in a legend rather than twice.
    """
    # Promote 1-D inputs to column vectors so the indexing below is uniform.
    if jnp.ndim(y1) == 1:
        y1 = y1[:, jnp.newaxis]
    if jnp.ndim(y2) == 1:
        y2 = y2[:, jnp.newaxis]
    if jnp.ndim(x) == 1:
        x = x[:, jnp.newaxis]
    ix_pre = x[:, 0] < x0
    ix_post = x[:, 0] >= x0
    ax.fill_between(x[ix_pre, 0], y1[ix_pre, 0], y2[ix_pre, 0], **kwargs)

    # Matplotlib legends skip artists whose label starts with an underscore;
    # use that for the second region to avoid a duplicated legend entry.
    post_kwargs = dict(kwargs)
    if 'label' in post_kwargs:
        post_kwargs['label'] = '_nolegend_'
    ax.fill_between(x[ix_post, 0], y1[ix_post, 0], y2[ix_post, 0], **post_kwargs)

#

Simulate some noisy observations of a known function with a discontinuity (a jump of size $d$) at $x_0$:

In [None]:
# One master PRNG key, split into a key for the inputs and a key for the noise.
key = jrnd.PRNGKey(56)
key, key_x, key_y = jrnd.split(key, 3)

d = 2.0        # height of the jump added for x >= x0
sigma = 0.3    # scale of the Gaussian observation noise
n = 100        # number of observations
x0 = jnp.pi    # location of the discontinuity


def f(x, x0):
    """Ground-truth function: a smooth sin/cos part plus a step of height ``d`` at ``x0``."""
    smooth = jnp.sin(jnp.pi*x) + jnp.cos(4/7*x)
    jump = d*(x>=x0)
    return smooth + jump


# Dense grid for drawing the true function.
x = jnp.linspace(0, 2*jnp.pi, num=200)

# Sorted uniform draws on [0, 2*pi], observed with additive Gaussian noise.
x_obs = 2*jnp.pi*jnp.sort(jrnd.uniform(key_x, shape=(n,)))
y = f(x_obs, x0) + sigma*jrnd.normal(key_y, shape=(n,))

fig, ax = plt.subplots(figsize=(10, 6))
disc_plot(ax, x, f(x, x0), x0, 'k', label=r'$f$')
ax.plot(x_obs, y, 'rx', label='obs')
ax.set(xlabel='x', ylabel='y', xlim=[0., 2*jnp.pi])
ax.axvline(x=x0, c='k')
ax.legend(loc='best');

Set up the alternative model (with a discontinuity at $x_0$) $\mathcal{M}_1$:

In [71]:
%%time
priors = dict(kernel=dict(lengthscale=dx.Transformed(dx.Normal(loc=0.,
                                                               scale=1.),
                                                     tfb.Exp()),
                          variance=dx.Transformed(dx.Normal(loc=0.,
                                                            scale=1.),
                                                  tfb.Exp())),
              likelihood=dict(obs_noise=dx.Transformed(dx.Normal(loc=0.,
                                                                 scale=1.),
                                                       tfb.Exp())))

# We could set up a similar dict with step sizes and other MCMC parameters
gpdisc = FullMarginalGPModel(x_obs, y,
                         cov_fn=Discontinuous(base_kernel=jk.RBF(),
                                              x0=x0),
                         priors=priors)

key, key_smc = jrnd.split(key)
particles, num_iter, log_marginal_likelihood_disc = gpdisc.inference(key_smc,
                                                                     mode='smc',
                                                                     sampling_parameters=dict(num_particles=1_000, num_mcmc_steps=100))

CPU times: user 16 s, sys: 69.5 ms, total: 16.1 s
Wall time: 16.6 s


Set up the null model (without a discontinuity) $\mathcal{M}_0$:

In [None]:
%%time
priors = dict(kernel=dict(lengthscale=dx.Transformed(dx.Normal(loc=0.,
                                                               scale=1.),
                                                     tfb.Exp()),
                          variance=dx.Transformed(dx.Normal(loc=0.,
                                                            scale=1.),
                                                  tfb.Exp())),
              likelihood=dict(obs_noise=dx.Transformed(dx.Normal(loc=0.,
                                                                 scale=1.),
                                                       tfb.Exp())))

# We could set up a similar dict with step sizes

gpcont = FullMarginalGPModel(x_obs, y,
                         cov_fn=jk.RBF(),
                         priors=priors)

key, key_smc = jrnd.split(key)
particles_cont, _, log_marginal_likelihood_cont = gpcont.inference(key_smc,
                                                                   mode='smc',
                                                                   sampling_parameters=dict(num_particles=1_000, num_mcmc_steps=50))

Plot the model fits and report the log Bayes factor
$$
    \log \text{BF}_{10} = \log \frac{p(D\mid \mathcal{M}_1)}{p(D\mid \mathcal{M}_0)} = \log p(D\mid \mathcal{M}_1) - \log p(D\mid \mathcal{M}_0) \enspace.
$$

In [None]:
colors = ['#025464', '#E57C23']
# Raw strings: '\m' is not a valid escape sequence in an ordinary string literal.
labels = [r'$\mathcal{M}_0$', r'$\mathcal{M}_1$']

# log BF_10 = log p(D | M1) - log p(D | M0)
log_BF = log_marginal_likelihood_disc - log_marginal_likelihood_cont

key, key_cont, key_disc = jrnd.split(key, 3)
epsilon = 1e-6

# Regular prediction grid plus two points immediately either side of x0, so the
# discontinuous fit is evaluated right up to the split from both directions.
x_pred = jnp.sort(jnp.append(jnp.linspace(0, 2*jnp.pi, 200), jnp.array([x0 - epsilon, x0 + epsilon])))

plt.figure(figsize=(10, 5))
ax = plt.gca()

# --- Null model M0 (no discontinuity) ---
# Mean and 2.5% / 97.5% quantiles are taken along axis 0 of the predictions
# (presumably the sample axis — confirm against predict_f).
post_pred = gpcont.predict_f(key_cont, x_pred=x_pred)
post_pred_mu = jnp.mean(post_pred, axis=0)
post_pred_hdi_lb, post_pred_hdi_ub = jnp.quantile(post_pred, axis=0,
                                                  q=jnp.array([0.025, 0.975]))
ax.plot(x_pred, post_pred_mu, color=colors[0], lw=2, label=labels[0])
ax.fill_between(jnp.squeeze(x_pred), post_pred_hdi_lb, post_pred_hdi_ub,
                    color=colors[0], alpha=0.3)

# --- Alternative model M1 (discontinuity at x0) ---
post_pred = gpdisc.predict_f(key_disc, x_pred=x_pred)
post_pred_mu = jnp.mean(post_pred, axis=0)
post_pred_hdi_lb, post_pred_hdi_ub = jnp.quantile(post_pred, axis=0,
                                                  q=jnp.array([0.025, 0.975]))

# Drawn as two pieces so nothing is connected across x0.
disc_plot(ax, x_pred, post_pred_mu, x0, color=colors[1], lw=2, label=labels[1])
disc_fill_between(ax, x_pred, post_pred_hdi_lb, post_pred_hdi_ub, x0, 
                  color=colors[1], alpha=0.3)

ax.plot(x_obs, y, 'rx')
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$y$')
ax.set_xlim([0, 2*jnp.pi])
ax.axvline(x=x0, c='k', ls=':')

# Show the model labels. The M1 curve is drawn as two line segments, so keep a
# single handle per label to avoid listing the same model twice.
legend_handles, legend_labels = ax.get_legend_handles_labels()
unique_entries = dict(zip(legend_labels, legend_handles))
ax.legend(list(unique_entries.values()), list(unique_entries.keys()), loc='best')

ax.set_title('Log $BF_{{10}}={:0.2f}$'.format(log_BF));