# Dynamics between mental states as a function of venlafaxine dosage

In this notebook, we illustrate how to use the Wishart process to model the dynamic covariance between five mental states, using the dosage of venlafaxine (an antidepressant) as the input variable. The dataset that we use here is available at https://osf.io/j4fg8/ and is described in the following paper:

Kossakowski, J. J., Groot, P. C., Haslbeck, J. M., Borsboom, D., & Wichers, M. (2017). Data from ‘critical slowing down as a personalized early warning signal for depression’. Journal of Open Psychology Data, 5(1), 1-1.

### Preprocessing
The data come from a single subject who had been diagnosed with Major Depressive Disorder and who monitored his mental state over the course of 237 days by filling in a questionnaire about daily life experiences several times a day. The subject had been using venlafaxine for 8.5 years, and this dosage was reduced gradually. We based our pre-processing on the following paper:

Wichers, M., Groot, P. C., Psychosystems, E. S. M., & EWS Group. (2016). Critical slowing down as a personalized early warning signal for depression. Psychotherapy and psychosomatics, 85(2), 114-116.

In [None]:
# Imports and runtime configuration.
# The order of the statements below matters: the environment variables are set
# before `import jax`, and `sys.path` is extended before the `uicsmodels` imports.
import os
# NOTE(review): SELECTED_DEVICE is not read anywhere in the visible part of this notebook.
SELECTED_DEVICE = None
# An empty CUDA_VISIBLE_DEVICES together with JAX_PLATFORM_NAME='cpu' is presumably
# meant to keep JAX on the CPU — confirm against the JAX version in use.
os.environ['CUDA_VISIBLE_DEVICES'] = f''
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
import jax
import sys
# Use the first device reported by JAX as the default device.
jax.config.update("jax_default_device", jax.devices()[0])
# Switch JAX to its x64 mode (64-bit types instead of the 32-bit default).
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrnd
import distrax as dx
import jaxkern as jk
import pandas as pd
import numpy as np
from factor_analyzer import FactorAnalyzer
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
# Add the sibling directory '../bayesianmodels/' to the import path (only once);
# presumably this is where the `uicsmodels` package imported below lives.
module_path = os.path.abspath(os.path.join('../bayesianmodels/'))
if module_path not in sys.path:
    sys.path.append(module_path)
from uicsmodels.gaussianprocesses.fullwp import FullLatentWishartModel
from uicsmodels.gaussianprocesses.wputil import vec2tril, tril2vec, construct_wishart
from uicsmodels.gaussianprocesses.likelihoods import AbstractLikelihood
import time

# Load and select data:
# NOTE(review): the CSV path is an absolute path on the author's machine; adjust it
# to wherever ESMdata.csv was downloaded.
data = pd.read_csv('/home/heshui/bayesianmodels/simulations/ESMdata.csv')
columns_affect = ['mood_irritat', 'mood_satisfi', 'mood_lonely', 'mood_anxious', 'mood_enthus', 
                  'mood_cheerf', 'mood_guilty', 'mood_doubt', 'mood_strong', 'pat_restl', 'pat_agitate']
# Keep the input column, the affect items and the two extra items; rows with any
# missing value in these columns are dropped.
data_subset = data[['concentrat'] + columns_affect + ['pat_worry', 'mood_suspic']].dropna()
# The model input x is the 'concentrat' column (presumably the venlafaxine
# concentration/dosage referred to in the introduction — confirm against the codebook).
x = np.array(data_subset['concentrat'])

# A few variables have Likert scales from -3 to 3, so we make sure all Likert scales have the same range
# (adding 4 maps -3..3 onto 1..7):
other_likert_scale = ['mood_lonely', 'mood_anxious', 'mood_guilty']
data_subset[other_likert_scale] += 4

# Apply PCA with oblique (promax) rotation to the affect items; the number of
# components is left at the FactorAnalyzer default:
fa = FactorAnalyzer(method='principal', rotation='promax')
fa.fit(data_subset[columns_affect].to_numpy())
Y_pca = fa.transform(data_subset[columns_affect].to_numpy())
Y = np.concatenate((Y_pca, data_subset[['pat_worry', 'mood_suspic']].to_numpy()), axis=1) # Append suspicious and worrying.

# Remove slow non-linear trends by regressing every column of Y on a degree-5
# polynomial expansion and keeping the residuals.
# NOTE(review): the regressor is x (the 'concentrat' column), not a time/day index,
# and np.tile produces Y.shape[1] identical copies of x, so the polynomial expansion
# contains many duplicate columns — confirm this is intended.
pf = PolynomialFeatures(degree=5)
xp = pf.fit_transform(np.tile(x, (Y.shape[1], 1)).T)
md2 = LinearRegression()
md2.fit(xp, Y)
trend = md2.predict(xp)
Y = Y - trend

# Scale the input x to the range [0, 1] (min-max scaling):
minx = np.min(x)
maxx = np.max(x)
x = (x - minx) / (maxx - minx)

# We want to use only 25% of the data, so we keep every fourth observation:
idx = np.arange(0, Y.shape[0], 4)
Y = Y[idx, :]
x = x[idx]
# Make x a column vector of shape (number of observations, 1).
x = np.reshape(x, (-1, 1))

### Define likelihood with EMA mean
We assume $Y_i \sim \mathcal{MVN}_d \left( \mu_i, \Sigma_i \right)$, where the mean $\mu_i$ is an exponential moving average (EMA) over a window of $k$ observations: $\text{EMA}(y_{i+1, j}) = \alpha [y_{ij} + (1 - \alpha)y_{i-1, j} + (1 - \alpha)^2 y_{i-2,j} + \ldots + (1 - \alpha)^{k-1}y_{i-(k-1),j}]$, with smoothing factor $\alpha = 2/(k+1)$ and window length $k=10$.

In [None]:
# Mean function for the likelihood (mu_i in Y_i ~ MVN(mu_i, Sigma_i)):
def ema(Y, k=10):
    """Compute a truncated exponential moving average of every column of Y.

    Each column is convolved with the weights (1 - alpha)**j for j = 0..k-1,
    where alpha = 2 / (k + 1), and the result is multiplied by alpha. Row i of
    the output therefore combines rows i, i-1, ..., i-(k-1) of Y; rows before
    the start of the series contribute nothing.

    NOTE(review): the window for row i includes Y[i] itself, whereas the
    formula in the accompanying markdown defines the EMA for y_{i+1} from y_i
    and earlier observations. Confirm which alignment is intended.

    Args:
        Y: matrix of observations, shape (number of observations, number of variables).
        k: number of observations in the averaging window (optional).

    Returns:
        Matrix with the same shape as Y holding the moving average per column.
    """
    n_obs = Y.shape[0]
    smoothing = 2 / (k + 1)
    weights = jnp.power(1 - smoothing, jnp.arange(k))

    def smooth_column(column):
        # Full convolution has length n_obs + k - 1; keep the first n_obs entries.
        return smoothing * jnp.convolve(column, weights)[:n_obs]

    # Map over the variables (axis 1) and stack the results back along axis 1,
    # so the output keeps the (observations, variables) layout.
    return jax.vmap(smooth_column, in_axes=1, out_axes=1)(Y)


class Wishart_with_EMA(AbstractLikelihood):
    """Multivariate normal likelihood with a Wishart-constructed covariance.

    The mean is fixed at construction time to the exponential moving average
    of the observations (see `ema`); the covariance is either supplied
    directly or built from latent values `f` and the parameter `L_vec`.
    """

    def __init__(self, nu, d, Y, k=10):
        """Store the dimensions and precompute the EMA mean.

        Args:
            nu: size of the middle axis used when reshaping `f` to (-1, nu, d).
            d: size of the last axis of the reshaped `f`; also passed to `vec2tril`.
            Y: observations, shape (number of observations, number of variables).
            k: window length forwarded to `ema` (optional).
        """
        self.mean = ema(Y, k=k)
        self.d = d
        self.nu = nu

    def link_function(self, f):
        """Return `f` unchanged (identity link)."""
        return f

    def likelihood(self, params, f=None, Sigma=None, mean=None):
        """Build the multivariate normal distribution of the observations.

        Args:
            params: mapping that contains 'L_vec'; only read when `Sigma` is None.
            f: latent values; reshaped to (-1, nu, d) unless it is a scalar (optional).
            Sigma: covariance matrices; when given, `f` and `params` are not used (optional).
            mean: mean to use instead of the stored EMA mean (optional).

        Returns:
            A `dx.MultivariateNormalFullCovariance` with the chosen mean and covariance.
        """
        assert f is not None or Sigma is not None, 'Provide either f or Sigma'
        if Sigma is None:
            # Scalars (ndim == 0) are passed through without reshaping.
            latent = jnp.reshape(f, (-1, self.nu, self.d)) if jnp.ndim(f) else f
            scale_tril = vec2tril(params['L_vec'], self.d)
            Sigma = construct_wishart(F=latent, L=scale_tril)
        loc = mean if mean is not None else self.mean
        return dx.MultivariateNormalFullCovariance(loc=loc, covariance_matrix=Sigma)

### Modelling of dynamic covariance between mental states using the Wishart process with SMC


In [None]:
# Model settings: the GP kernel is the sum of an RBF and a Matern-1/2 kernel.
gpkernel = jk.RBF() + jk.Matern12()
num_particles = 1000
num_mcmc_steps = 5000

# Problem size: n observations of d variables; a d x d lower-triangular matrix
# has d * (d + 1) / 2 free entries.
n, d = Y.shape
len_l = d * (d + 1) // 2

# Priors: a log-normal prior on the lengthscale of each of the two kernel
# components, and a standard normal prior on every entry of L_vec.
priors = {
    'kernel': [
        {'lengthscale': dx.Transformed(dx.Normal(loc=0., scale=1.), tfb.Exp())}
        for _ in range(2)
    ],
    'likelihood': {
        'L_vec': dx.Normal(loc=jnp.zeros((len_l, )), scale=jnp.ones((len_l, ))),
    },
}

# Initialize the model, then replace its likelihood by the EMA-mean variant
# with nu = d + 1.
model = FullLatentWishartModel(x, Y, cov_fn=gpkernel, priors=priors)
model.likelihood = Wishart_with_EMA(d+1, d, Y)

# Run inference with SMC and record the wall-clock time around the call:
start = time.time()
key = jrnd.PRNGKey(10)
particles, num_iter, lml = model.inference(
    key,
    mode='gibbs-in-smc',
    sampling_parameters={
        'num_particles': num_particles,
        'num_mcmc_steps': num_mcmc_steps,
    },
)
end = time.time()

# Construct the covariance process on 100 evenly spaced inputs spanning the
# observed range of x:
key = jrnd.PRNGKey(5)
x_pred = np.linspace(np.min(x), np.max(x), num=100)
Sigma_pred = model.predict_Sigma(key, x_pred)