# Copula Wishart processes

Wishart processes are used to model input-dependent covariance matrices. They can naturally be combined with multivariate Gaussian observations, as the Wishart process then simply provides the covariance matrix for the Gaussian at every input point. However, the situation becomes less straightforward when the observations are not Gaussian, or not even continuous. Here, we explore copula models that allow us to separate the multivariate correlation structure from the desired marginal distributions.
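To make the separation of dependence and marginals concrete, here is a minimal sketch of sampling from a Gaussian copula. The helper `sample_gaussian_copula`, the correlation value 0.7, and the exponential and Bernoulli marginals are illustrative assumptions, not part of the model fitted in this notebook.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jrnd
from jax.scipy.stats import norm


def sample_gaussian_copula(key, corr, n):
    """Draw n points with Uniform(0, 1) marginals and Gaussian dependence."""
    p = corr.shape[0]
    z = jrnd.multivariate_normal(key, mean=jnp.zeros(p), cov=corr, shape=(n,))
    # Probability integral transform: Phi(z_j) is uniform for each column j.
    return norm.cdf(z)


corr = jnp.array([[1.0, 0.7], [0.7, 1.0]])  # assumed correlation matrix
u = sample_gaussian_copula(jrnd.PRNGKey(0), corr, n=5000)

# Push the uniforms through marginal quantile functions (illustrative choices).
x_exp = -jnp.log1p(-u[:, 0]) / 2.0           # Exponential with rate 2
x_bern = (u[:, 1] > 0.7).astype(jnp.int32)   # Bernoulli with success probability 0.3
```

The two columns now have a continuous and a discrete marginal respectively, while their dependence is still governed by the single correlation matrix `corr`.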

In [6]:
%load_ext autoreload
%autoreload 2

import os

SELECTED_DEVICE = '6'
print(f'Setting CUDA visible devices to [{SELECTED_DEVICE}]')
os.environ['CUDA_VISIBLE_DEVICES'] = f'{SELECTED_DEVICE}'

Setting CUDA visible devices to [6]


In [7]:
import matplotlib.pyplot as plt

import jax
jax.config.update("jax_enable_x64", True)

import jax.random as jrnd
import jax.numpy as jnp
import distrax as dx
import blackjax

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

import os
import sys

from blackjax import normal_random_walk
from blackjax.diagnostics import potential_scale_reduction, effective_sample_size

from jaxtyping import Array

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../../')))

from bamojax.base import Node, Model
from bamojax.sampling import gibbs_sampler, smc_inference_loop, elliptical_slice_nd, inference_loop

from bamojax.more_distributions import GaussianProcessFactory, RBF, Zero, Wishart

print('Python version:       ', sys.version)
print('Jax version:          ', jax.__version__)
print('BlackJax version:     ', blackjax.__version__)
print('Distrax version:      ', dx.__version__)
print('Jax default backend:  ', jax.default_backend())
print('Jax devices:          ', jax.devices())

Python version:        3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
Jax version:           0.4.35
BlackJax version:      1.2.4
Distrax version:       0.1.5
Jax default backend:   gpu
Jax devices:           [CudaDevice(id=0)]


## Copulas

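Before turning to copulas, we first check that we can recover a known covariance matrix from Gaussian observations. We parameterize $\Sigma = LL^\top$ by the entries of its lower-triangular factor $L$ and place independent standard-normal priors on them: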

In [8]:
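def vec2tril(v):
    # Fill the lower triangle (including the diagonal) of a p x p matrix with v.
    # Relies on the global p defined below.
    L_sample = jnp.zeros((p, p))
    return L_sample.at[jnp.tril_indices(p, 0)].set(v)


def posdef_sigma(loc, L_vec):
    # Link function: map the unconstrained vector to Sigma = L L^T.
    L = vec2tril(L_vec)
    return dict(loc=loc, covariance_matrix=jnp.dot(L, L.T))

# Simulation settings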
cov = jnp.array([[1.0, 0.2], [0.2, 1.0]])
p = cov.shape[0]
nu = p + 1
n = 1000
m = p * (p + 1) // 2  # number of free entries in the lower-triangular factor L

key = jrnd.PRNGKey(42)
key, subkey = jrnd.split(key)

Y = jrnd.multivariate_normal(subkey, mean=jnp.zeros(p), cov=cov, shape=(n, ))

copula_model = Model('Copulas')
L_node = copula_model.add_node(name='L_vec', distribution=dx.Normal(loc=jnp.zeros(m), scale=jnp.ones(m)))
Y_node = copula_model.add_node(name='Y', distribution=dx.MultivariateNormalFullCovariance, parents=dict(loc=jnp.zeros(p), L_vec=L_node), observations=Y, link_fn=posdef_sigma)


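**Note**: A naive random-walk proposal made directly on the entries of $\Sigma$ does not enforce positive definiteness. We therefore sample the lower-triangular factor $L$ and set $\Sigma = LL^\top$, which is positive semi-definite by construction and positive definite whenever all diagonal entries of $L$ are non-zero.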
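The benefit of working with $L$ rather than $\Sigma$ can be checked numerically. The sketch below is self-contained and uses a hypothetical helper `vec_to_cov` (a variant of `vec2tril` that takes `p` as an argument); the matrices are made-up examples.

```python
import jax.numpy as jnp
import jax.random as jrnd

p = 3
m = p * (p + 1) // 2  # number of lower-triangular entries


def vec_to_cov(v, p):
    """Map an unconstrained vector of length p(p+1)/2 to Sigma = L L^T."""
    L = jnp.zeros((p, p)).at[jnp.tril_indices(p)].set(v)
    return L @ L.T


# Any real vector gives a symmetric matrix with non-negative eigenvalues.
v = jrnd.normal(jrnd.PRNGKey(0), (m,))
sigma = vec_to_cov(v, p)
min_eig_factored = jnp.linalg.eigvalsh(sigma).min()

# Naive alternative: perturb the off-diagonal of a valid covariance directly.
sigma_naive = jnp.array([[1.0, 0.9], [0.9, 1.0]]) + jnp.array([[0.0, 0.2], [0.2, 0.0]])
min_eig_naive = jnp.linalg.eigvalsh(sigma_naive).min()  # approximately -0.1: not a covariance
```

The perturbed matrix has eigenvalues $1 \pm 1.1$, so a single modest step on the off-diagonal already leaves the set of valid covariance matrices, whereas the factored form cannot.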

In [9]:
%%time

logdensity_fn = lambda state: copula_model.loglikelihood_fn()(state) + copula_model.logprior_fn()(state)
rmh = normal_random_walk(logdensity_fn, sigma=0.03*jnp.eye(m))

num_samples = 50_000
num_burn = 50_000
num_thin = 1
num_chains = 1

rmh_states, rmh_info = inference_loop(key, model=copula_model, kernel=rmh, num_samples=num_samples, num_burn=num_burn, num_chains=num_chains, num_thin=num_thin)

print(f'Acceptance rate: {jnp.mean(1.0*rmh_info.is_accepted):0.3f}')

Acceptance rate: 0.359
CPU times: user 8.25 s, sys: 1.67 s, total: 9.91 s
Wall time: 9.5 s
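For readers unfamiliar with the sampler used above, the following is a minimal sketch of a random-walk Metropolis kernel in plain JAX. It is not BlackJAX's implementation: the names `rwm_step` and `run_rwm`, the scalar `step_size`, and the toy standard-normal target are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jrnd


def rwm_step(key, x, logp_x, logdensity_fn, step_size):
    """One Metropolis step with an isotropic Gaussian proposal."""
    key_prop, key_acc = jrnd.split(key)
    proposal = x + step_size * jrnd.normal(key_prop, x.shape)
    logp_prop = logdensity_fn(proposal)
    # Accept with probability min(1, p(proposal) / p(x)), evaluated in log space.
    accept = jnp.log(jrnd.uniform(key_acc)) < logp_prop - logp_x
    x_new = jnp.where(accept, proposal, x)
    logp_new = jnp.where(accept, logp_prop, logp_x)
    return x_new, logp_new, accept


def run_rwm(key, x0, logdensity_fn, step_size, num_steps):
    """Run the kernel for num_steps iterations and collect all states."""
    def body(carry, step_key):
        x, logp_x = carry
        x, logp_x, accept = rwm_step(step_key, x, logp_x, logdensity_fn, step_size)
        return (x, logp_x), (x, accept)

    keys = jrnd.split(key, num_steps)
    _, (samples, accepted) = jax.lax.scan(body, (x0, logdensity_fn(x0)), keys)
    return samples, accepted


# Toy target: a two-dimensional standard normal (unnormalized log density).
toy_logdensity = lambda x: -0.5 * jnp.sum(x ** 2)
samples, accepted = run_rwm(jrnd.PRNGKey(0), jnp.zeros(2), toy_logdensity, 1.0, 5000)
acceptance_rate = jnp.mean(accepted)
```

The step size plays the same role as the proposal scale passed to `normal_random_walk` above: larger steps explore faster but are rejected more often, which is what the reported acceptance rate summarizes.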


In [10]:
mvn_params = jax.vmap(posdef_sigma, in_axes=(None, 0))(jnp.zeros(p), rmh_states.position['L_vec'])

jnp.mean(mvn_params['covariance_matrix'], axis=0)

Array([[0.99330278, 0.21618209],
       [0.21618209, 0.97584776]], dtype=float64)
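Since a Gaussian copula is parameterized by a correlation matrix rather than a covariance matrix, it may be useful to normalize each posterior covariance sample. The sketch below uses a hypothetical helper `cov_to_corr` and synthetic stand-in samples rather than the actual posterior draws.

```python
import jax
import jax.numpy as jnp


def cov_to_corr(cov):
    """Rescale a covariance matrix to unit diagonal."""
    d = jnp.sqrt(jnp.diag(cov))
    return cov / jnp.outer(d, d)


# Stand-in for posterior draws: small perturbations of a Cholesky factor.
base_L = jnp.linalg.cholesky(jnp.array([[1.0, 0.2], [0.2, 1.0]]))
offsets = jnp.linspace(-0.05, 0.05, 5)
L_samples = base_L[None] + offsets[:, None, None] * jnp.tril(jnp.ones((2, 2)))
cov_samples = L_samples @ jnp.swapaxes(L_samples, -1, -2)

# Normalize every sample, then average over the sample axis.
corr_samples = jax.vmap(cov_to_corr)(cov_samples)
mean_corr = corr_samples.mean(axis=0)
```

In this notebook the same function could be applied as `jax.vmap(cov_to_corr)(mvn_params['covariance_matrix'])`.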