# Simulation study: learning the model parameters

In this notebook, we illustrate the first simulation study of our work (https://arxiv.org/abs/2406.04796). In this simulation study, we construct a covariance process between $d=3$ variables from the prior of a Wishart process, and sample $n=300$ observations from a multivariate normal distribution with a mean of zero and the constructed covariance process. We generate 10 different datasets with the same underlying covariance process.

In [None]:
# Imports
import os
# Placeholder for a device choice; it is not referenced in the visible part of the notebook.
SELECTED_DEVICE = None
# Hide all CUDA devices and request the CPU platform. Both variables are set
# before `import jax` below -- presumably so that JAX sees them when it
# initialises its backends (NOTE(review): keep this ordering).
os.environ['CUDA_VISIBLE_DEVICES'] = f''
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
import sys
# Use the first device JAX reports as the default device and switch on the
# `jax_enable_x64` flag (64-bit precision).
jax.config.update("jax_default_device", jax.devices()[0])
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrnd
import distrax as dx
import jaxkern as jk
import pandas as pd
import numpy as np
from factor_analyzer import FactorAnalyzer
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from tensorflow_probability.substrates import jax as tfp
#import tensorflow_probability as tfp
tfb = tfp.bijectors
module_path = os.path.abspath(os.path.join('../bayesianmodels/'))
if module_path not in sys.path:
    sys.path.append(module_path)
from uicsmodels.gaussianprocesses.fullwp import FullLatentWishartModel
from uicsmodels.gaussianprocesses.wputil import vec2tril, tril2vec, construct_wishart, construct_wishart_Lvec
from uicsmodels.gaussianprocesses.likelihoods import AbstractLikelihood
from uicsmodels.gaussianprocesses.gputil import sample_predictive
import time
module_path = os.path.abspath(os.path.join('../BANNER/'))
if module_path not in sys.path:
    sys.path.append(module_path)
import gpflow
from gpflow.kernels import SharedIndependent, SquaredExponential, Matern12
from gpflow import Parameter
from gpflow.inducing_variables import InducingPoints, SharedIndependentInducingVariables
import tensorflow as tf
from src.likelihoods.WishartProcessLikelihood import WishartLikelihood
from src.models.WishartProcess import WishartProcess
from util.training_util import run_adam
from scipy.stats import multivariate_normal

# Data/model settings:
gpkernel = jk.RBF()
n = 300
d = 3
num_datasets = 10  # number of simulated datasets; also the number of prior draws requested below
data_seed = 0  # seed of the NumPy generator used to draw the observations (reproducible on re-run)

# Generate data:
x = np.reshape(np.linspace(0, 1, n), (-1, 1))
Y = np.zeros((n, num_datasets, d))
len_l = int(d * (d+1) / 2)
priors = dict(kernel=dict(lengthscale=dx.Uniform(0.34999, 0.35001)), # set RBF lengthscale to 0.35.
                likelihood=dict(L_vec=dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))))
# Y is still all zeros here; the model is only used to draw latent processes from the prior.
model = FullLatentWishartModel(x, Y[:, 0, :], cov_fn=gpkernel, priors=priors)
particles = model.init_fn(jrnd.PRNGKey(10), num_datasets)
# The fixed L_vec has len_l = 6 entries, so it is tied to d = 3
# (presumably the vectorised identity scale factor -- TODO confirm against construct_wishart_Lvec).
true_Sigma = jax.vmap(construct_wishart_Lvec, in_axes=(0, None))(particles.position['f'], jnp.array([1., 0., 1., 0., 0., 1.]))

# NOTE(review): dataset `rep` is sampled from its own prior draw true_Sigma[rep],
# whereas the notebook introduction states that all datasets share one covariance
# process -- confirm which of the two is intended.
rng = np.random.default_rng(data_seed)
true_Sigma_np = np.asarray(true_Sigma)  # convert once instead of once per (rep, i) draw
zero_mean = np.zeros((d,))
for rep in range(num_datasets):
    for i in range(n):
        Y[i, rep, :] = rng.multivariate_normal(mean=zero_mean, cov=true_Sigma_np[rep, i, :, :])

### MCMC sampling

In [None]:
# Model settings:
gpkernel = jk.RBF()
num_burn = 4000000
num_samples = 1000000
num_thin = 1000
num_chains = 4
num_thin_samples = num_samples // num_thin
# Number of thinned draws every chain contributes to the pooled posterior sample.
num_kept_per_chain = num_thin_samples // num_chains
# Seeded generator so that the per-chain subsampling is reproducible on re-run.
subsample_rng = np.random.default_rng(0)

n, num_trials, d = Y.shape
# Row/column indices of the upper triangle (including the diagonal) of a d x d matrix.
triu_rows, triu_cols = np.triu_indices(d)
for trial in range(num_trials):
    # Get training data:
    x_train, Y_train = x, Y[:, trial, :]
    len_l = int(d * (d+1) / 2)

    # We combine the samples across multiple chains. The pooled array holds exactly
    # num_chains * num_kept_per_chain draws, so no all-zero rows are left behind
    # when num_thin_samples is not divisible by num_chains.
    Sigma_train = np.zeros((num_chains * num_kept_per_chain, n, d, d)) # array to store samples across chains.
    Sigmas_chains = np.zeros((num_thin_samples, num_chains, n * len_l))
    for repetition in range(num_chains):
        # Set priors and initialize model:
        priors = dict(kernel = dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())),
                      likelihood = dict(L_vec=dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))))
        model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)

        # Inference with MCMC (mode='gibbs'); every chain gets its own PRNG key:
        start = time.time()
        key = jrnd.PRNGKey(repetition)
        states = model.inference(key, mode='gibbs', sampling_parameters=dict(num_burn=num_burn,
                                                                            num_samples=num_samples,
                                                                            num_thin=num_thin))
        end = time.time()

        # Compute posterior distribution:
        Sigma_chain = jax.vmap(construct_wishart_Lvec)(model.states.position['f'], model.states.position['likelihood']['L_vec'])
        Sigma_chain_np = np.asarray(Sigma_chain)
        # Keep only the upper-triangular entries of every covariance matrix and flatten
        # them per draw: this is the per-chain input of the R-hat diagnostic below.
        Sigmas_chains[:, repetition, :] = np.reshape(Sigma_chain_np[:, :, triu_rows, triu_cols], (num_thin_samples, -1))
        # Draw (without replacement) this chain's share of the pooled posterior sample.
        indices = subsample_rng.choice(num_thin_samples, size=num_kept_per_chain, replace=False)
        Sigma_train[repetition * num_kept_per_chain:(repetition+1) * num_kept_per_chain] = Sigma_chain_np[indices, :, :, :]

    # Compute convergence of covariance processes (mean potential scale reduction over all entries):
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_chains))
    print(rhat)

### Variational inference

In [None]:
# Model settings:
num_iterations = 1000000 # until convergence of the ELBO.
num_samples = 1000
num_initializations = 4
n, num_trials, d = Y.shape
nu = d + 1
latent_dim = int(nu * d)
# Number of upper-triangular entries (including the diagonal) of a d x d matrix.
# This is what is stored per time point for the convergence diagnostic; it equals
# latent_dim // 2 only because nu = d + 1.
len_l = d * (d + 1) // 2
triu_rows, triu_cols = np.triu_indices(d)
# Seeded generator for the random initialisations (NumPy draws only; any sampling
# done inside the model/optimiser is not controlled by this seed).
init_rng = np.random.default_rng(0)

# Loop over the simulated datasets:
for trial in range(num_trials):
    # Get training data (the input column is repeated d times):
    x_train, Y_train = np.tile(x, (1, d)), Y[:, trial, :]

    best_elbo = -np.inf
    Sigmas_initializations = np.zeros((num_samples, num_initializations, n * len_l))

    for repetition in range(num_initializations):
        # Set GP kernel function, likelihood, and initialize model:
        lengthscale_rbf = np.exp(init_rng.normal())
        gpkernel = SharedIndependent(SquaredExponential(lengthscales=lengthscale_rbf, variance=1.0), output_dim=latent_dim)
        V = init_rng.normal(size=int(d * nu // 2)) # initialize scale matrix.
        V = Parameter(V)
        inducing_points = tf.identity(x_train)
        inducing_variable = SharedIndependentInducingVariables(InducingPoints(tf.identity(inducing_points)))
        likelihood = WishartLikelihood(d, nu, A=V, N=n, R=3, additive_noise=True, model_inverse=False, model_mean=True)
        model = WishartProcess(gpkernel, likelihood, D=d, nu=nu, inducing_variable=inducing_variable, num_data=n)
        # The inducing points are a copy of the training inputs and are kept fixed during optimisation.
        gpflow.set_trainable(model.inducing_variable, False)

        # Inference until convergence of the ELBO:
        start = time.time()
        logf = run_adam(model, (x_train, Y_train), num_iterations, minibatch_size=n, learning_rate=0.001)
        end = time.time()

        # Store the upper-triangular entries of every sampled covariance matrix, flattened per sample:
        Sigma_train_repetition = model.predict_mc(x_train, num_samples)
        Sigma_samples = np.array(Sigma_train_repetition)
        Sigmas_initializations[:, repetition, :] = np.reshape(Sigma_samples[:, :, triu_rows, triu_cols], (num_samples, -1))
        # Keep the estimates of the initialisation whose last logged objective value is highest:
        if logf[-1] > best_elbo:
            best_elbo = logf[-1]
            Sigma_train = Sigma_samples

    # Compute convergence of covariance processes (initialisations play the role of chains):
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_initializations))
    print(rhat)
        

### Sequential Monte Carlo sampling

In [None]:
# Model settings:
gpkernel = jk.RBF()
num_particles, num_mcmc_steps = 1000, 2000

n_train, num_trials, d = Y.shape
for trial in range(num_trials):
    # Training data of this trial: shared inputs, trial-specific observations.
    x_train = x
    Y_train = Y[:, trial, :]

    # Priors: a standard normal pushed through exp for the kernel lengthscale,
    # and independent standard normals for the len_l entries of L_vec.
    len_l = d * (d + 1) // 2
    lengthscale_prior = dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())
    L_vec_prior = dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))
    priors = {'kernel': {'lengthscale': lengthscale_prior},
              'likelihood': {'L_vec': L_vec_prior}}
    model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)

    # Run the 'gibbs-in-smc' sampler; start/end bracket the wall-clock time of the call.
    smc_settings = {'num_particles': num_particles, 'num_mcmc_steps': num_mcmc_steps}
    start = time.time()
    key = jrnd.PRNGKey(10)
    particles, num_iter, lml = model.inference(key, mode='gibbs-in-smc', sampling_parameters=smc_settings)
    end = time.time()

    # Turn every particle (latent processes f plus L_vec) into a covariance process.
    key = jrnd.PRNGKey(5)
    particle_f = particles.particles['f']
    particle_L_vec = particles.particles['likelihood']['L_vec']
    Sigma_train = jax.vmap(construct_wishart_Lvec)(particle_f, particle_L_vec)