# Simulation study: state switching and out-of-sample prediction using Wishart processes

In this notebook, we illustrate the second simulation study of our work (https://arxiv.org/abs/2406.04796). In this simulation study, we construct a covariance process between $d=3$ variables from the prior of a Wishart process, and sample $n=600$ observations from a multivariate normal distribution with a mean of zero and the constructed covariance process. The first 300 observations are used for training, and the remaining 300 observations for out-of-sample prediction. We generate 10 different datasets with the same underlying covariance process, and compare MCMC, variational inference and SMC.

In [None]:
# Imports
import os
SELECTED_DEVICE = None
os.environ['CUDA_VISIBLE_DEVICES'] = f''
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
import sys
jax.config.update("jax_default_device", jax.devices()[0])
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrnd
import distrax as dx
import jaxkern as jk
import pandas as pd
import numpy as np
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
module_path = os.path.abspath(os.path.join('../bayesianmodels/'))
if module_path not in sys.path:
    sys.path.append(module_path)
from uicsmodels.gaussianprocesses.fullwp import FullLatentWishartModel
from uicsmodels.gaussianprocesses.wputil import vec2tril, tril2vec, construct_wishart, construct_wishart_Lvec
from uicsmodels.gaussianprocesses.likelihoods import AbstractLikelihood
import time
module_path = os.path.abspath(os.path.join('../BANNER/'))
if module_path not in sys.path:
    sys.path.append(module_path)
import gpflow
from gpflow.kernels import SharedIndependent, SquaredExponential, Periodic
from gpflow import Parameter
from gpflow.inducing_variables import InducingPoints, SharedIndependentInducingVariables
import tensorflow as tf
from src.likelihoods.WishartProcessLikelihood import WishartLikelihood
from src.models.WishartProcess import WishartProcess
from util.training_util import run_adam

# Data settings:
n = 600
d = 3

# Generate data:
# Inputs are n equally spaced points on [0, 2], stored as a column vector.
x = np.reshape(np.linspace(0, 2, n), (-1, 1))
Y = np.zeros((n, 10, d))
true_Sigma = np.zeros((n, d, d))

# Ground-truth covariance: unit variances with a common correlation r that
# switches every 50 observations, starting at r = 0.0 and alternating with
# r = 0.8. The state is derived from the index (rather than toggled in place)
# so that every dataset shares exactly the same covariance process and
# true_Sigma describes all of them, regardless of n.
for i in range(n):
    r = 0.0 if (i // 50) % 2 == 0 else 0.8
    true_Sigma[i] = np.array([[1., r, r], [r, 1., r], [r, r, 1.]])

# Draw 10 independent datasets from N(0, true_Sigma[i]) at every input location.
for trial in range(10):
    for i in range(n):
        Y[i, trial] = np.random.multivariate_normal(mean=np.zeros((d,)), cov=true_Sigma[i])

### MCMC sampling
We start by modelling the covariance process using a Periodic covariance function:

In [None]:
# Model settings:
gpkernel = jk.Periodic()
num_burn = 5000000
num_samples = 1000000
num_thin = 1000
num_chains = 4
num_thin_samples = num_samples // num_thin  # draws retained per chain after thinning
_, num_trials, d = Y.shape

for trial in range(num_trials):
    # Get training data and testing data:
    x_train, Y_train = np.tile(x[:300], (1, d)), Y[:300, trial]
    x_test, Y_test = np.tile(x[300:], (1, d)), Y[300:, trial]
    n_train, n_test = Y_train.shape[0], Y_test.shape[0]
    len_l = d * (d + 1) // 2  # number of entries in one triangle of a d x d matrix
    triu_rows, triu_cols = np.triu_indices(d)

    # We combine the samples across multiple chains:
    samples_per_chain = num_thin_samples // num_chains
    Sigma_train = np.zeros((num_thin_samples, n_train, d, d)) # array to store samples across chains.
    Sigma_pred = np.zeros((num_thin_samples, n_test, d, d))
    Sigmas_chains = np.zeros((num_thin_samples, num_chains, n_train * len_l))
    for repetition in range(num_chains):
        # Set priors and initialize model:
        priors = dict(kernel = dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp()),
                                    period=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())),
                      likelihood = dict(L_vec=dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))))
        model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)

        # Inference with MCMC (Gibbs sampling), one seed per chain.
        # The chain key is split so that sampling and prediction never consume
        # the same PRNG key (JAX keys must not be reused).
        start = time.time()
        key = jrnd.PRNGKey(repetition)
        key_inference, key_predict = jrnd.split(key)
        states = model.inference(key_inference, mode='gibbs', sampling_parameters=dict(num_burn=num_burn,
                                                                                      num_samples=num_samples,
                                                                                      num_thin=num_thin))
        end = time.time()

        # Compute posterior distribution:
        Sigma_chain = jax.vmap(construct_wishart_Lvec)(model.states.position['f'], model.states.position['likelihood']['L_vec'])
        # Keep the upper-triangular entries of every covariance matrix, flattened
        # per draw, so the chains can be compared with R-hat below.
        Sigmas_chains[:, repetition, :] = np.reshape(Sigma_chain[:, :, triu_rows, triu_cols], (num_thin_samples, -1))
        # Each chain contributes a random subset of its draws to the pooled posterior.
        indices = np.random.choice(np.arange(num_thin_samples), size=(samples_per_chain,), replace=False)
        Sigma_train[repetition * samples_per_chain:(repetition + 1) * samples_per_chain] = Sigma_chain[indices, :, :, :]
        # NOTE(review): this rebinds Sigma_pred on every chain, so only the last
        # chain's predictions survive (unlike Sigma_train, which pools chains).
        # Confirm whether predictions should be pooled as well.
        Sigma_pred = model.predict_Sigma(key_predict, x_test)

    # Compute convergence of covariance processes:
    # potential_scale_reduction expects (draws, chains, ...); report the mean R-hat.
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_chains))
    print(rhat)

Next, we use a Locally Periodic covariance function:

In [None]:
# Model settings:
gpkernel = jk.RBF() * jk.Periodic()
num_burn = 5000000
num_samples = 1000000
num_thin = 1000
num_chains = 4
num_thin_samples = num_samples // num_thin  # draws retained per chain after thinning

_, num_trials, d = Y.shape
for trial in range(num_trials):
    # Get training data and testing data:
    x_train, Y_train = np.tile(x[:300], (1, d)), Y[:300, trial]
    x_test, Y_test = np.tile(x[300:], (1, d)), Y[300:, trial]
    n_train, n_test = Y_train.shape[0], Y_test.shape[0]
    len_l = d * (d + 1) // 2  # number of entries in one triangle of a d x d matrix
    triu_rows, triu_cols = np.triu_indices(d)

    # We combine the samples across multiple chains:
    samples_per_chain = num_thin_samples // num_chains
    Sigma_train = np.zeros((num_thin_samples, n_train, d, d)) # array to store samples across chains.
    Sigma_pred = np.zeros((num_thin_samples, n_test, d, d))
    Sigmas_chains = np.zeros((num_thin_samples, num_chains, n_train * len_l))
    for repetition in range(num_chains):
        # Set priors and initialize model (one prior dict per kernel factor:
        # first the RBF factor, then the Periodic factor):
        priors = dict(kernel = [dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())),
                                dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp()),
                                     period=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp()))],
                      likelihood = dict(L_vec=dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))))
        model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)

        # Inference with MCMC (Gibbs sampling), one seed per chain.
        # The chain key is split so that sampling and prediction never consume
        # the same PRNG key (JAX keys must not be reused).
        start = time.time()
        key = jrnd.PRNGKey(repetition)
        key_inference, key_predict = jrnd.split(key)
        states = model.inference(key_inference, mode='gibbs', sampling_parameters=dict(num_burn=num_burn,
                                                                                      num_samples=num_samples,
                                                                                      num_thin=num_thin))
        end = time.time()

        # Compute posterior distribution:
        Sigma_chain = jax.vmap(construct_wishart_Lvec)(model.states.position['f'], model.states.position['likelihood']['L_vec'])
        # Keep the upper-triangular entries of every covariance matrix, flattened
        # per draw, so the chains can be compared with R-hat below.
        Sigmas_chains[:, repetition, :] = np.reshape(Sigma_chain[:, :, triu_rows, triu_cols], (num_thin_samples, -1))
        # Each chain contributes a random subset of its draws to the pooled posterior.
        indices = np.random.choice(np.arange(num_thin_samples), size=(samples_per_chain,), replace=False)
        Sigma_train[repetition * samples_per_chain:(repetition + 1) * samples_per_chain] = Sigma_chain[indices, :, :, :]
        # NOTE(review): this rebinds Sigma_pred on every chain, so only the last
        # chain's predictions survive (unlike Sigma_train, which pools chains).
        # Confirm whether predictions should be pooled as well.
        Sigma_pred = model.predict_Sigma(key_predict, x_test)

    # Compute convergence of covariance processes:
    # potential_scale_reduction expects (draws, chains, ...); report the mean R-hat.
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_chains))
    print(rhat)

### Variational inference
We again first try to model the covariance process using a Periodic covariance function, and then a Locally Periodic covariance function:

In [None]:
# Model settings:
num_iterations = 1000000 # until convergence of the ELBO.
num_samples = 1000
num_initializations = 4
_, num_trials, d = Y.shape
nu = d + 1
latent_dim = int(nu * d)
# One random (log-normal) starting value per initialization, shared across trials.
lengthscale = [np.exp(np.random.normal()) for rep in range(num_initializations)]
period = [np.exp(np.random.normal()) for rep in range(num_initializations)]

for trial in range(num_trials):
    # Get training data and testing data:
    x_train, Y_train = np.tile(x[:300], (1, d)), Y[:300, trial]
    x_test, Y_test = np.tile(x[300:], (1, d)), Y[300:, trial]
    n_train, n_test = Y_train.shape[0], Y_test.shape[0]
    len_l = d * (d + 1) // 2  # number of entries in one triangle of a d x d matrix
    triu_rows, triu_cols = np.triu_indices(d)

    best_elbo = -np.inf
    Sigmas_initializations = np.zeros((num_samples, num_initializations, n_train * len_l))
    for repetition in range(num_initializations):
        # Set GP kernel function, likelihood, and initialize model:
        kernel = SharedIndependent(Periodic(SquaredExponential(lengthscales=lengthscale[repetition], variance=1.0), period=period[repetition]), output_dim=latent_dim)
        V = np.random.normal(size=int(d * nu // 2)) # initialize scale matrix.
        V = Parameter(V)
        inducing_points = tf.identity(x_train)
        inducing_variable = SharedIndependentInducingVariables(InducingPoints(tf.identity(inducing_points)))
        likelihood = WishartLikelihood(d, nu, A=V, N=n_train, R=3, additive_noise=True, model_inverse=False)
        model = WishartProcess(kernel, likelihood, D=d, nu=nu, inducing_variable=inducing_variable, num_data=n_train)
        gpflow.set_trainable(model.inducing_variable, False) # we do not want to use inducing variables.

        # Inference until convergence of the ELBO:
        start = time.time()
        logf = run_adam(model, (x_train, Y_train), num_iterations, minibatch_size=n_train, learning_rate=0.001)
        end = time.time()

        # Record the samples of EVERY initialization for the convergence check.
        # Doing this only for runs that improve the ELBO would leave the other
        # columns at zero and distort the R-hat computed below.
        Sigma_chain = np.array(model.predict_mc(x_train, num_samples))
        Sigmas_initializations[:, repetition, :] = np.reshape(Sigma_chain[:, :, triu_rows, triu_cols], (num_samples, -1))

        # Store estimates with the best ELBO:
        if logf[-1] > best_elbo:
            best_elbo = logf[-1]
            Sigma_train = Sigma_chain
            Sigma_pred = np.array(model.predict_mc(x_test, num_samples))

    # Compute convergence of covariance processes:
    # potential_scale_reduction expects (draws, chains, ...); report the mean R-hat.
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_initializations))
    print(rhat)

In [None]:
# Model settings:
num_iterations = 1000000 # until convergence of the ELBO.
num_samples = 1000
num_initializations = 4
_, num_trials, d = Y.shape
nu = d + 1
latent_dim = int(nu * d)
# One random (log-normal) starting value per initialization, shared across trials.
lengthscale_rbf = [np.exp(np.random.normal()) for rep in range(num_initializations)]
lengthscale_periodic = [np.exp(np.random.normal()) for rep in range(num_initializations)]
period = [np.exp(np.random.normal()) for rep in range(num_initializations)]

for trial in range(num_trials):
    # Get training data and testing data:
    x_train, Y_train = np.tile(x[:300], (1, d)), Y[:300, trial]
    x_test, Y_test = np.tile(x[300:], (1, d)), Y[300:, trial]
    n_train, n_test = Y_train.shape[0], Y_test.shape[0]
    len_l = d * (d + 1) // 2  # number of entries in one triangle of a d x d matrix
    triu_rows, triu_cols = np.triu_indices(d)

    best_elbo = -np.inf
    Sigmas_initializations = np.zeros((num_samples, num_initializations, n_train * len_l))
    for repetition in range(num_initializations):
        # Set GP kernel function (RBF times Periodic), likelihood, and initialize model:
        kernel = SharedIndependent(SquaredExponential(lengthscales=lengthscale_rbf[repetition], variance=1.0) *
                                   Periodic(SquaredExponential(lengthscales=lengthscale_periodic[repetition], variance=1.0), period=period[repetition]), output_dim=latent_dim)
        V = np.random.normal(size=int(d * nu // 2)) # initialize scale matrix.
        V = Parameter(V)
        inducing_points = tf.identity(x_train)
        inducing_variable = SharedIndependentInducingVariables(InducingPoints(tf.identity(inducing_points)))
        likelihood = WishartLikelihood(d, nu, A=V, N=n_train, R=3, additive_noise=True, model_inverse=False)
        model = WishartProcess(kernel, likelihood, D=d, nu=nu, inducing_variable=inducing_variable, num_data=n_train)
        gpflow.set_trainable(model.inducing_variable, False) # we do not want to use inducing variables.

        # Inference until convergence of the ELBO:
        start = time.time()
        logf = run_adam(model, (x_train, Y_train), num_iterations, minibatch_size=n_train, learning_rate=0.001)
        end = time.time()

        # Record the samples of EVERY initialization for the convergence check.
        # Doing this only for runs that improve the ELBO would leave the other
        # columns at zero and distort the R-hat computed below.
        Sigma_chain = np.array(model.predict_mc(x_train, num_samples))
        Sigmas_initializations[:, repetition, :] = np.reshape(Sigma_chain[:, :, triu_rows, triu_cols], (num_samples, -1))

        # Store estimates with the best ELBO:
        if logf[-1] > best_elbo:
            best_elbo = logf[-1]
            Sigma_train = Sigma_chain
            Sigma_pred = np.array(model.predict_mc(x_test, num_samples))

    # Compute convergence of covariance processes:
    # potential_scale_reduction expects (draws, chains, ...); report the mean R-hat.
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_initializations))
    print(rhat)
    # NOTE(review): this stops after the first dataset, unlike the other cells,
    # which loop over all trials. Presumably a leftover from a quick test run;
    # confirm before removing.
    break

### SMC sampling
First we will use a Periodic covariance function:

In [None]:
 # Model settings:
gpkernel = jk.Periodic()
num_particles = 1000
num_mcmc_steps = 3000
_, num_trials, d = Y.shape


def _lognormal_prior():
    """Return a standard normal pushed through exp (a log-normal prior)."""
    return dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())


for trial in range(num_trials):
    # Split the current dataset: the first 300 observations are used for
    # fitting, the remaining ones are held out. The inputs are tiled d times
    # along the second axis.
    x_train, x_test = (np.tile(part, (1, d)) for part in (x[:300], x[300:]))
    Y_train, Y_test = Y[:300, trial], Y[300:, trial]
    n_train = Y_train.shape[0]
    n_test = Y_test.shape[0]

    # Priors: log-normal on the kernel hyperparameters, standard normal on
    # the len_l = d(d+1)/2 entries of L_vec.
    len_l = d * (d + 1) // 2
    kernel_priors = {'lengthscale': _lognormal_prior(),
                     'period': _lognormal_prior()}
    likelihood_priors = {'L_vec': dx.Normal(loc=jnp.zeros((len_l, )),
                                            scale=jnp.ones((len_l, )))}
    priors = {'kernel': kernel_priors, 'likelihood': likelihood_priors}
    model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)

    # Run SMC (with Gibbs moves) and time it.
    smc_settings = {'num_particles': num_particles,
                    'num_mcmc_steps': num_mcmc_steps}
    start = time.time()
    key = jrnd.PRNGKey(10)
    particles, num_iter, lml = model.inference(key, mode='gibbs-in-smc',
                                               sampling_parameters=smc_settings)
    end = time.time()

    # Turn every particle into a covariance process on the training inputs,
    # then predict the covariance process on the held-out inputs.
    key = jrnd.PRNGKey(5)
    f_particles = particles.particles['f']
    L_vec_particles = particles.particles['likelihood']['L_vec']
    Sigma_train = jax.vmap(construct_wishart_Lvec)(f_particles, L_vec_particles)
    Sigma_pred = model.predict_Sigma(key, x_test)


Next, we multiply a Periodic and RBF covariance function to obtain a Locally Periodic covariance function:

In [None]:
 # Model settings:
gpkernel = jk.RBF() * jk.Periodic()
num_particles = 1000
num_mcmc_steps = 3000
_, num_trials, d = Y.shape


def _lognormal_prior():
    """Return a standard normal pushed through exp (a log-normal prior)."""
    return dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())


for trial in range(num_trials):
    # Split the current dataset: the first 300 observations are used for
    # fitting, the remaining ones are held out. The inputs are tiled d times
    # along the second axis.
    x_train, x_test = (np.tile(part, (1, d)) for part in (x[:300], x[300:]))
    Y_train, Y_test = Y[:300, trial], Y[300:, trial]
    n_train = Y_train.shape[0]
    n_test = Y_test.shape[0]

    # Priors: one dict per kernel factor (RBF first, then Periodic) with
    # log-normal hyperpriors, and a standard normal on the len_l = d(d+1)/2
    # entries of L_vec.
    len_l = d * (d + 1) // 2
    rbf_priors = {'lengthscale': _lognormal_prior()}
    periodic_priors = {'lengthscale': _lognormal_prior(),
                       'period': _lognormal_prior()}
    likelihood_priors = {'L_vec': dx.Normal(loc=jnp.zeros((len_l, )),
                                            scale=jnp.ones((len_l, )))}
    priors = {'kernel': [rbf_priors, periodic_priors],
              'likelihood': likelihood_priors}
    model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)

    # Run SMC (with Gibbs moves) and time it.
    smc_settings = {'num_particles': num_particles,
                    'num_mcmc_steps': num_mcmc_steps}
    start = time.time()
    key = jrnd.PRNGKey(10)
    particles, num_iter, lml = model.inference(key, mode='gibbs-in-smc',
                                               sampling_parameters=smc_settings)
    end = time.time()

    # Turn every particle into a covariance process on the training inputs,
    # then predict the covariance process on the held-out inputs.
    key = jrnd.PRNGKey(5)
    f_particles = particles.particles['f']
    L_vec_particles = particles.particles['likelihood']['L_vec']
    Sigma_train = jax.vmap(construct_wishart_Lvec)(f_particles, L_vec_particles)
    Sigma_pred = model.predict_Sigma(key, x_test)