# Dynamics between mental states over time

In this notebook, we illustrate how to use the Wishart process to model the dynamic covariance between five mental states over time. The dataset that we use in this notebook is available at https://osf.io/j4fg8/ and is described in the following paper:

Kossakowski, J. J., Groot, P. C., Haslbeck, J. M., Borsboom, D., & Wichers, M. (2017). Data from ‘critical slowing down as a personalized early warning signal for depression’. Journal of Open Psychology Data, 5(1), 1-1.

### Preprocessing
The data come from a single subject who had been diagnosed with Major Depressive Disorder and who monitored his mental state over the course of 237 days by filling in a questionnaire about daily life experiences several times a day. The subject had been using venlafaxine for 8.5 years, and the dosage was reduced gradually. We based our preprocessing on the following paper:

Wichers, M., Groot, P. C., Psychosystems, E. S. M., & EWS Group. (2016). Critical slowing down as a personalized early warning signal for depression. Psychotherapy and psychosomatics, 85(2), 114-116.

In [None]:
# Imports
import os
SELECTED_DEVICE = None
os.environ['CUDA_VISIBLE_DEVICES'] = f''
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
import sys
jax.config.update("jax_default_device", jax.devices()[0])
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrnd
import distrax as dx
import jaxkern as jk
import pandas as pd
import numpy as np
from factor_analyzer import FactorAnalyzer
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
module_path = os.path.abspath(os.path.join('../bayesianmodels/'))
if module_path not in sys.path:
    sys.path.append(module_path)
from uicsmodels.gaussianprocesses.fullwp import FullLatentWishartModel
from uicsmodels.gaussianprocesses.wputil import vec2tril, tril2vec, construct_wishart, construct_wishart_Lvec
from uicsmodels.gaussianprocesses.likelihoods import AbstractLikelihood
from uicsmodels.gaussianprocesses.gputil import sample_predictive
import time
module_path = os.path.abspath(os.path.join('../BANNER/'))
if module_path not in sys.path:
    sys.path.append(module_path)
import gpflow
from gpflow.kernels import SharedIndependent, SquaredExponential, Matern12
from gpflow import Parameter
from gpflow.inducing_variables import InducingPoints, SharedIndependentInducingVariables
import tensorflow as tf
from src.likelihoods.WishartProcessLikelihood import WishartLikelihood
from src.models.WishartProcess import WishartProcess
from util.training_util import run_adam
from scipy.stats import multivariate_normal

# Load and select data.
# The location of the CSV file can be overridden with the ESM_DATA_PATH environment variable, so that the
# notebook also runs on machines where the original absolute path does not exist (the default is unchanged).
DATA_PATH = os.environ.get('ESM_DATA_PATH', '/home/heshui/bayesianmodels/simulations/ESMdata.csv')
data = pd.read_csv(DATA_PATH)
columns_affect = ['mood_irritat', 'mood_satisfi', 'mood_lonely', 'mood_anxious', 'mood_enthus', 
                  'mood_cheerf', 'mood_guilty', 'mood_doubt', 'mood_strong', 'pat_restl', 'pat_agitate']
# Keep only rows in which all selected columns are observed:
data_subset = data[['concentrat'] + columns_affect + ['pat_worry', 'mood_suspic']].dropna()
# NOTE(review): the model input x is taken from the 'concentrat' column, while a comment further down
# speaks of day numbers - confirm that this is the intended input.
x = np.array(data_subset['concentrat'])

# A few variables have Likert scales from -3 to 3, so we make sure all Likert scales have the same range:
other_likert_scale = ['mood_lonely', 'mood_anxious', 'mood_guilty']
data_subset[other_likert_scale] += 4

# Apply PCA with oblique rotation:
fa = FactorAnalyzer(method='principal', rotation='promax')
fa.fit(data_subset[columns_affect].to_numpy())
Y_pca = fa.transform(data_subset[columns_affect].to_numpy())
Y = np.concatenate((Y_pca, data_subset[['pat_worry', 'mood_suspic']].to_numpy()), axis=1) # Append suspicious and worrying.

# Remove slow non-linear time trends by regressing Y on a degree-5 polynomial of x and keeping the residuals.
# NOTE(review): the design matrix is built from Y.shape[1] identical copies of x, so many of the polynomial
# features are duplicates of each other - a single column of x would presumably span the same space.
pf = PolynomialFeatures(degree=5)
xp = pf.fit_transform(np.tile(x, (Y.shape[1], 1)).T)
md2 = LinearRegression()
md2.fit(xp, Y)
trend = md2.predict(xp)
Y = Y - trend

# Scale the input x (described as day numbers in the original notebook) to the range [0, 1]:
minx = np.min(x)
maxx = np.max(x)
x = (x - minx) / (maxx - minx)

# We use only 25% of the data: keep every fourth observation.
idx = np.arange(0, Y.shape[0], 4)
Y = Y[idx, :]
x = x[idx]
x = np.reshape(x, (-1, 1)) # column vector of shape (number of observations, 1).

### Define likelihood with EMA mean
We assume $Y_i \sim \mathcal{MVN}_d \left( \mu_i, \Sigma_i \right)$, where the mean $\mu_i$ is an exponential moving average (EMA) of the preceding observations: $\text{EMA}(y_{i+1, j}) = \alpha \left[ y_{i, j} + (1 - \alpha) y_{i-1, j} + (1 - \alpha)^2 y_{i-2, j} + \ldots + (1 - \alpha)^{k-1} y_{i-(k-1), j} \right]$, with smoothing factor $\alpha = 2 / (k + 1)$ and window size $k = 10$.

In [None]:
# Define a mean function (this is mu_i for Y_i ~ MVN(mu_i, Sigma_i)):
def ema(Y, k=10):
    """ Exponential moving average mean function, computed separately for every variable.

    Row i of the result equals alpha * sum_{m=0}^{k-1} (1 - alpha)^m * Y[i - m] with alpha = 2 / (k + 1);
    terms that would refer to rows before the first observation are left out.

    Args:
        Y: a matrix with observations of shape (number of observations, number of variables).
        k: an integer representing the number of most recent observations (the current one included) that
            enter each average (optional).

    Returns:
        A matrix of shape (number of observations, number of variables) with the moving average.
    """
    num_obs = Y.shape[0]
    smoothing = 2 / (k + 1)
    # Geometric weights (1 - alpha)^0, ..., (1 - alpha)^(k-1); the first weight belongs to the current row.
    weights = jnp.power(1 - smoothing, jnp.arange(k))

    def smooth_column(column):
        # Full convolution cut off after num_obs entries, so that entry i only uses rows up to and including i.
        return smoothing * jnp.convolve(column, weights)[:num_obs]

    # Map over the variables (axis 1) and put the smoothed series back as columns.
    return jax.vmap(smooth_column, in_axes=1, out_axes=1)(Y)


class Wishart_with_EMA(AbstractLikelihood):
    """Multivariate normal likelihood whose mean is the exponential moving average of the observations.

    The covariance matrix is either passed in directly or constructed from the latent values f and the
    parameter vector L_vec via construct_wishart.
    """

    def __init__(self, nu, d, Y, k=10):
        """Store the reshape sizes and precompute the EMA mean of Y.

        Args:
            nu: second axis of the reshaped latent values (presumably the Wishart degrees of freedom).
            d: last axis of the reshaped latent values, also passed to vec2tril (presumably the number of variables).
            Y: observations, forwarded to ema.
            k: window size forwarded to ema (optional).
        """
        self.nu, self.d = nu, d
        self.mean = ema(Y, k=k)

    def link_function(self, f):
        """Identity function
        """
        return f

    def likelihood(self, params, f=None, Sigma=None, mean=None):
        """Return the multivariate normal distribution of the observations.

        Args:
            params: dictionary with entry 'L_vec'; only read when Sigma is not given.
            f: latent values, reshaped to (-1, nu, d) unless they are zero-dimensional; only used when Sigma is not given.
            Sigma: covariance matrices; when given, f and params are ignored.
            mean: mean of the distribution; defaults to the EMA computed at construction.

        Returns:
            A dx.MultivariateNormalFullCovariance with the selected mean and covariance.
        """
        assert f is not None or Sigma is not None, 'Provide either f or Sigma'
        if mean is None:
            mean = self.mean
        if Sigma is None:
            # Build the covariance from the latent values and the lower-triangular matrix encoded by L_vec.
            latent = jnp.reshape(f, (-1, self.nu, self.d)) if jnp.ndim(f) else f
            scale_tril = vec2tril(params['L_vec'], self.d)
            Sigma = construct_wishart(F=latent, L=scale_tril)
        return dx.MultivariateNormalFullCovariance(loc=mean, covariance_matrix=Sigma)

### 10-fold cross validation with MCMC

In [None]:
# Array to save performance over 10 folds:
ll_mcmc = np.zeros((10))

# Model settings:
gpkernel = jk.RBF() + jk.Matern12()
num_burn = 5000000
num_samples = 1000000
num_thin = 1000
num_chains = 4
num_thin_samples = num_samples // num_thin # chain length after thinning; all buffers below are sized with it.

# 10 fold cross validation:
n, d = Y.shape
# Number of entries in the upper triangle (including the diagonal) of a d x d matrix. This has to be known
# before the first buffer is allocated; it used to be assigned only inside the chain loop (NameError).
len_l = int(d * (d+1) / 2)
for fold in range(1, 11):
    # Get training and testing data: a growing training window, followed by at most 10 test observations.
    # NOTE(review): for fold 10 the test slice only holds the last n % 10 observations (possibly none) -
    # confirm that this is intended.
    n_train = (n // 10) * fold
    x_train, Y_train = x[:n_train], Y[:n_train]
    x_test, Y_test = x[n_train:n_train+10], Y[n_train:n_train+10]

    # We combine the samples across multiple chains:
    Sigma_pred = np.zeros((num_thin_samples, Y_test.shape[0], d, d)) # array to store samples across chains.
    Sigmas_chains = np.zeros((num_thin_samples, num_chains, Y_test.shape[0] * len_l)) # used for convergence.
    for repetition in range(num_chains):
        # Set priors (log-normal lengthscales, standard normal entries of L_vec) and initialize model:
        priors = dict(kernel = [dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())),
                                dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp()))],
                    likelihood = dict(L_vec=dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))))
        model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)
        model.likelihood = Wishart_with_EMA(d+1, d, Y_train)

        # Inference with MCMC (Gibbs sampling); every chain uses its own seed:
        start = time.time()
        key = jrnd.PRNGKey(repetition)
        states = model.inference(key, mode='gibbs', sampling_parameters=dict(num_burn=num_burn,
                                                                            num_samples=num_samples,
                                                                            num_thin=num_thin))
        end = time.time()

        # Compute posterior distribution.
        # One PRNG key per thinned draw: all arguments that are vmapped over axis 0 need the same leading
        # size, and the buffers above assume num_thin_samples draws per chain.
        key, subkey = jrnd.split(key, 2)
        keys = jrnd.split(subkey, num_thin_samples)
        kernel_params = [{'lengthscale': model.states.position['kernel'][0]['lengthscale'], 'variance': np.ones((num_thin_samples,))},
                         {'lengthscale': model.states.position['kernel'][1]['lengthscale'], 'variance': np.ones((num_thin_samples,))}]
        f_pred = jax.vmap(sample_predictive,
                        in_axes=(0, None, None, 0, None, None, 0))(keys,
                                                                    x_train[:, 0], x_test[:, 0],
                                                                    model.states.position['f'], gpkernel, None,
                                                                    kernel_params)
        Sigma_chain = jax.vmap(construct_wishart_Lvec)(f_pred, model.states.position['likelihood']['L_vec'])
        # Keep the upper triangle of every predicted covariance matrix, flattened per draw, for R-hat:
        Sigmas_chains[:, repetition, :] = np.reshape(jax.vmap(lambda sigma: jax.vmap(lambda s: s[np.triu_indices(d)])(sigma))(Sigma_chain), (num_thin_samples, -1))
        # Every chain contributes an equally sized random subset of its draws to the pooled sample:
        indices = np.random.choice(np.arange(num_thin_samples), size=(num_thin_samples // num_chains,), replace=False)
        Sigma_pred[repetition * (num_thin_samples // num_chains):(repetition+1) * (num_thin_samples // num_chains)] = Sigma_chain[indices, :, :, :]

    # Compute convergence of covariance processes:
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_chains))
    print(rhat)

    # Get posterior predictive distribution (using the EMA function):
    Y_pred = np.zeros((num_thin_samples, Y_test.shape[0], d))
    mean_distribution = np.zeros((num_thin_samples, Y_test.shape[0], d))
    for s in range(num_thin_samples):
        Y_temp = Y_train
        for i in range(Y_test.shape[0]): # we predict Y_test one step ahead, and then update the EMA.
            mu = ema(Y_temp)
            mean_distribution[s, i, :] = mu[-1, :]
            Y_pred[s, i, :] = np.random.multivariate_normal(mu[-1, :], Sigma_pred[s, i])
            Y_temp = np.concatenate((Y_temp, np.reshape(Y_pred[s, i, :], (1, -1))), axis=0)
    # NOTE(review): this is the sum of the test log-densities, whereas log_likelihood in the variational
    # inference cell divides by the number of test observations - confirm before comparing the methods.
    ll_mcmc[fold-1] = np.sum(model.likelihood.likelihood(None, Sigma=np.mean(Sigma_pred, axis=0), mean=np.mean(mean_distribution, axis=0)).log_prob(Y_test))

### 10-fold cross validation with variational inference

In [None]:
# Define log_likelihood
def log_likelihood(y_true, cov_est, mean_est=None):
    """ Average Gaussian log-density of the observations under the sample-averaged covariance matrices.

    Args:
        y_true: observations of shape (number of observations, number of variables).
        cov_est: covariance samples; they are averaged over axis 0, after which entry n is used as the
            covariance matrix of y_true[n].
        mean_est: means, indexed per observation like y_true (optional). A zero mean is used when omitted;
            previously the default of None raised a TypeError.

    Returns:
        The sum of the per-observation log-densities divided by the number of observations.
    """
    y_true = np.asarray(y_true)
    if mean_est is None:
        mean_est = np.zeros(y_true.shape)
    cov_est = np.mean(cov_est, axis=0) # average over the samples first.
    ll = np.zeros((y_true.shape[0]))
    for i in range(y_true.shape[0]):
        ll[i] = multivariate_normal.logpdf(y_true[i], mean_est[i], cov_est[i])
    return np.sum(ll) / y_true.shape[0]


# Array to save performance over 10 folds:
ll_vi = np.zeros((10))

# Model settings:
num_iterations = 1000000 # until convergence of the ELBO.
num_samples = 1000
num_initializations = 4
n, d = Y.shape
nu = d + 1
latent_dim = int(nu * d)

# One random (log-normal) initial lengthscale per initialization and kernel component, shared by all folds:
lengthscale_rbf = [np.exp(np.random.normal()) for rep in range(num_initializations)]
lengthscale_m12 = [np.exp(np.random.normal()) for rep in range(num_initializations)]

# 10 fold cross validation:
for fold in range(1, 11):
    # Get training and testing data; the single input column is repeated d times:
    n_train = (n // 10) * fold
    x_train, Y_train = x[:n_train], Y[:n_train]
    x_train = np.tile(x_train, (1, d))
    x_test, Y_test = x[n_train:n_train+10], Y[n_train:n_train+10]
    x_test = np.tile(x_test, (1, d))
    len_l = int(d * (d+1) / 2) # number of entries in the upper triangle (including the diagonal) of a d x d matrix.

    best_elbo = -np.inf
    Sigmas_initializations = np.zeros((num_samples, num_initializations, n_train * len_l))
    for repetition in range(num_initializations):
        # Set GP kernel function, likelihood, and initialize model:
        gpkernel = SharedIndependent(SquaredExponential(lengthscales=lengthscale_rbf[repetition], variance=1.0)
                                     + Matern12(lengthscales=lengthscale_m12[repetition], variance=1.0), output_dim=latent_dim) 
        V = np.random.normal(size=int(d * nu // 2)) # initialize scale matrix.
        V = Parameter(V)
        inducing_points = tf.identity(x_train)
        inducing_variable = SharedIndependentInducingVariables(InducingPoints(tf.identity(inducing_points)))
        likelihood = WishartLikelihood(d, nu, A=V, N=n_train, R=3, additive_noise=True, model_inverse=False, model_mean=True) # mean function.
        model = WishartProcess(gpkernel, likelihood, D=d, nu=nu, inducing_variable=inducing_variable, num_data=n_train)
        gpflow.set_trainable(model.inducing_variable, False) # keep the inducing points fixed at the training inputs.

        # Inference until convergence of the ELBO:
        start = time.time()
        logf = run_adam(model, (x_train, Y_train), num_iterations, minibatch_size=n_train, learning_rate=0.001)
        end = time.time()

        # Store the training estimates of EVERY initialization for the convergence diagnostic. This used to
        # happen only for initializations that improved the ELBO, which left all-zero columns in the R-hat input.
        Sigma_chain = np.array(model.predict_mc(x_train, num_samples)) 
        Sigmas_initializations[:, repetition, :] = np.reshape(jax.vmap(lambda sigma: jax.vmap(lambda s: s[np.triu_indices(d)])(sigma))(Sigma_chain), (num_samples, -1))

        # Keep the estimates and test predictions of the initialization with the best ELBO:
        if logf[-1] > best_elbo:
            best_elbo = logf[-1]
            Sigma_train = Sigma_chain
            Sigma_pred = np.array(model.predict_mc(x_test, num_samples)) 

    # Get posterior predictive distribution (using the EMA function):
    Y_pred = np.zeros((num_samples, Y_test.shape[0], d))
    mean_distribution = np.zeros((num_samples, Y_test.shape[0], d))
    for s in range(num_samples): 
        Y_temp = Y_train
        for i in range(Y_test.shape[0]): # we predict Y_test one step ahead, and then update the EMA.
            mu = ema(Y_temp)
            mean_distribution[s, i, :] = mu[-1, :]
            Y_pred[s, i, :] = np.random.multivariate_normal(mu[-1, :], Sigma_pred[s, i])
            Y_temp = np.concatenate((Y_temp, np.reshape(Y_pred[s, i, :], (1, -1))), axis=0)

    # Compute convergence of covariance processes and log-likelihood:
    rhat = np.mean(tfp.mcmc.diagnostic.potential_scale_reduction(Sigmas_initializations))
    print(rhat)
    ll_vi[fold-1] = log_likelihood(Y_test, Sigma_pred, np.mean(mean_distribution, axis=0))

### 10-fold cross validation with SMC

In [None]:
# Array to save performance over 10 folds (one summed test log-likelihood per fold):
ll_smc = np.zeros((10))

# Model settings:
gpkernel = jk.RBF() + jk.Matern12()
num_particles = 1000
num_mcmc_steps = 5000

# 10 fold cross validation:
n, d = Y.shape
for fold in range(1, 11):
    # Get training and testing data: a growing training window, followed by at most 10 test observations.
    # NOTE(review): for fold 10 the test slice only holds the last n % 10 observations (possibly none) -
    # confirm that this is intended.
    n_train = (n // 10) * fold
    x_train, Y_train = x[:n_train], Y[:n_train]
    x_test, Y_test = x[n_train:n_train+10], Y[n_train:n_train+10]
    
    # Set priors and initialize model.
    # Both kernel lengthscales get a log-normal prior (standard normal pushed through Exp), and each of the
    # len_l entries of L_vec gets a standard normal prior.
    len_l = int(d * (d+1) / 2)
    priors = dict(kernel = [dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())),
                            dict(lengthscale=dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp()))],
                likelihood = dict(L_vec=dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, )))))
    model = FullLatentWishartModel(x_train, Y_train, cov_fn=gpkernel, priors=priors)
    # Replace the model's likelihood by the EMA-mean likelihood with nu = d + 1:
    model.likelihood = Wishart_with_EMA(d+1, d, Y_train)

    # Inference with SMC (the same seed is used in every fold); start and end record the wall-clock time:
    start = time.time()
    key = jrnd.PRNGKey(10)
    particles, num_iter, lml = model.inference(key, mode='gibbs-in-smc', 
                                                sampling_parameters=dict(num_particles=num_particles, 
                                                                        num_mcmc_steps=num_mcmc_steps))
    end = time.time()

    # Compute posterior distribution and the posterior predictive distribution (using the EMA function):
    key = jrnd.PRNGKey(5)
    # presumably Sigma_pred has shape (num_particles, number of test observations, d, d) - it is indexed as
    # Sigma_pred[s, i] below; verify against predict_Sigma.
    Sigma_pred = model.predict_Sigma(key, x_test)
    Y_pred = np.zeros((num_particles, Y_test.shape[0], d))
    mean_distribution = np.zeros((num_particles, Y_test.shape[0], d))
    for s in range(num_particles): 
        # Every particle starts again from the observed training data:
        Y_temp = Y_train
        for i in range(Y_test.shape[0]): # we predict Y_test one step ahead, and then update the EMA.
            mu = ema(Y_temp)
            mean_distribution[s, i, :] = mu[-1, :] # the EMA at the most recent row is the predictive mean.
            Y_pred[s, i, :] = np.random.multivariate_normal(mu[-1, :], Sigma_pred[s, i])
            # Append the simulated observation, so that the next EMA is based on it:
            Y_temp = np.concatenate((Y_temp, np.reshape(Y_pred[s, i, :], (1, -1))), axis=0)
    # Score the test data under the particle-averaged mean and covariance (summed over the test observations):
    ll_smc[fold-1] = np.sum(model.likelihood.likelihood(None, Sigma=np.mean(Sigma_pred, axis=0), mean=np.mean(mean_distribution, axis=0)).log_prob(Y_test))