In [1]:
from jax import config
config.update('jax_enable_x64', True)
from pathlib import Path

import jax
import jax.numpy as jnp
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from numpyro.distributions import MultivariateNormal

from uncprop.core.distribution import GaussianFromNumpyro
from uncprop.core.samplers import sample_distribution
from uncprop.utils.plot import set_plot_theme, smart_subplots
from uncprop.utils.grid import plot_coverage_curve_reps
from uncprop.models.elliptic_pde.inverse_problem import (
    PDESettings,
    generate_pde_inv_prob_rep,
    plot_inverse_problem_setup,
)

colors = set_plot_theme()

# Root of the paper repository. Overridable via the BIP_SURROGATES_DIR environment
# variable so the notebook is not tied to one machine; the default preserves the
# original author's local path.
import os
base_dir = Path(os.environ.get(
    'BIP_SURROGATES_DIR',
    '/Users/andrewroberts/Desktop/git-repos/bip-surrogates-paper',
))

from pde_model import (
    get_discrete_source, 
    solve_pde, 
    solve_pde_vmap,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# settings

SEED = 5232214
key = jr.key(SEED)
# Split off a dedicated key for problem generation. The carry `key` is split again
# in later cells, and a JAX key must never be both consumed and split (otherwise the
# generator's internal splits can coincide with keys drawn later).
key, key_inv_prob = jr.split(key)

noise_sd = 1e-2                              # observation noise standard deviation
n_kl_modes = 6                               # presumably KL modes in the field expansion -- confirm
obs_locations = jnp.array([10, 30, 60, 75])  # NOTE(review): assumed to be grid indices -- confirm

inv_prob_settings = {
    'noise_cov': noise_sd**2 * jnp.identity(len(obs_locations)),
    'n_kl_modes': n_kl_modes,
    'obs_locations': obs_locations,
    'settings': PDESettings(),
    'key': key_inv_prob,
}

In [3]:
# Build the inverse problem; unpack the four returned pieces by name.
inv_prob_outputs = generate_pde_inv_prob_rep(**inv_prob_settings)
posterior, gp_prior, eig_info, ground_truth = inv_prob_outputs

In [None]:
# Fresh subkey for the random draws used by the setup figure.
key, key_plot = jr.split(key)

setup_plot_kwargs = dict(
    key=key_plot,
    posterior=posterior,
    ground_truth=ground_truth,
    observation=posterior.likelihood.observation,
    n_samp=3,
)
fig, ax = plot_inverse_problem_setup(**setup_plot_kwargs)

In [None]:
# Split the carry key: one subkey for the initial point, one for the sampler.
key, key_prior, key_mcmc = jr.split(key, 3)

n_mcmc_samples = 10_000
n_mcmc_warmup = 10_000

# Initialise the chain at a draw from the prior (squeezed to drop singleton axes).
initial_position = posterior.prior.sample(key_prior).squeeze()

positions, states, warmup_samp, prop_cov = sample_distribution(
    key=key_mcmc,  # previously the carry `key` was passed here, leaving key_mcmc unused
    dist=posterior,
    initial_position=initial_position,
    n_samples=n_mcmc_samples,
    n_warmup=n_mcmc_warmup,
)

In [None]:
# Trace of every posterior coordinate over the MCMC run, to eyeball mixing.
fig_trace, ax_trace = plt.subplots()
for dim_idx in range(posterior.dim):
    ax_trace.plot(positions[:, dim_idx], label=f'dim {dim_idx}')
ax_trace.set(xlabel='MCMC iteration', ylabel='value', title='MCMC trace plots')
ax_trace.legend(ncol=2, fontsize='small');

In [6]:
from uncprop.utils.distribution import _gaussian_log_density_tril, _gaussian_log_det_term_tril

In [36]:
# Toy inputs for checking the batched log-density: one 2-d point `y`,
# three candidate means (rows of `m`) and three 2x2 covariances.
y = jnp.array([-1, 0])
m = jnp.array([
    [1, 1],
    [0, 0],
    [0, 1],
])

C1 = jnp.array([[1.0, 0.8],
                [0.8, 1.0]])
C2 = jnp.array([[2.0, 0.5],
                [0.5, 2.0]])
C3 = jnp.array([[0.5, -0.1],
                [-0.1, 1.0]])
C = jnp.stack((C1, C2, C3), axis=0)

# Lower-triangular Cholesky factors, one per covariance: L[k] @ L[k].T == C[k].
L = jnp.linalg.cholesky(C, upper=False)

In [59]:
# Broadcast check: `y` as a (1, d) row combined with every row of `m`
# (note: the density function below subtracts `m`; this shows the sum).
y[jnp.newaxis, :] + m

Array([[ 0,  1],
       [-1,  0],
       [-1,  1]], dtype=int64)

In [56]:
# Show the raw means for comparison with the broadcast result above.
m

Array([[1, 1],
       [0, 0],
       [0, 1]], dtype=int64)

In [66]:
from jax.scipy.linalg import solve_triangular

def _gaussian_log_density_tril_new(x, m, L):
    """Row-wise Gaussian log-density parameterized by lower Cholesky factor(s).

    `x` is promoted to at least 2-D and centred by `m` (with broadcasting), giving
    residuals whose first axis indexes rows and second axis the dimension d.
    `L` is broadcast to (n_rows, d, d), so it may be a single (d, d) factor shared by
    every row or a stack of n_rows factors, in which case row i is paired with L[i].

    Returns the term computed by `_gaussian_log_det_term_tril` on the broadcast
    factors minus half the squared Mahalanobis distance of each row (one value
    per row, assuming `_gaussian_log_det_term_tril` returns one value per factor).
    """
    resid = jnp.atleast_2d(x) - m
    n_rows = resid.shape[0]
    dim = resid.shape[1]
    L_batch = jnp.broadcast_to(L, (n_rows, dim, dim))

    # Batched forward substitution: whitened[i] = L_batch[i]^{-1} @ resid[i].
    whitened = solve_triangular(L_batch, resid, lower=True)
    sq_mahalanobis = (whitened ** 2).sum(axis=1)

    return _gaussian_log_det_term_tril(L_batch) - 0.5 * sq_mahalanobis

In [67]:
# Row k: log-density of y under each mean in `m`, using the single factor L[k].
ref_per_cov = jnp.stack([_gaussian_log_density_tril(y, m, L_k) for L_k in L])
new_per_cov = jnp.stack([_gaussian_log_density_tril_new(y, m, L_k) for L_k in L])
# Batched call: mean i is paired with factor L[i].
new_batched = _gaussian_log_density_tril_new(y, m, L)

# The new implementation must reproduce the reference for a shared factor ...
assert jnp.allclose(ref_per_cov, new_per_cov), 'new != reference for shared L'
# ... and the batched call must equal the (i, i) pairings, i.e. the diagonal.
assert jnp.allclose(new_batched, jnp.diagonal(ref_per_cov)), 'batched pairing mismatch'

print(ref_per_cov, new_per_cov, new_batched, sep='\n')



[-3.82705144 -2.71594033 -1.882607  ] [-3.56542165 -2.76542165 -2.89875499] [-6.48120212 -2.50161029 -3.215896  ]
[-3.82705144 -2.71594033 -1.882607  ] [-3.56542165 -2.76542165 -2.89875499] [-6.48120212 -2.50161029 -3.215896  ]
[-3.82705144 -2.76542165 -3.215896  ]


In [None]:
from uncprop.utils.distribution import _gaussian_log_density_tril


# Evaluate the reference log-density of `y` with means `pred.mean.T` and factor `L_x`.
# TODO(review): neither `pred` nor `L_x` is defined in any earlier cell of this
# notebook, and this cell has never been executed (In [None]); it will raise
# NameError under Restart & Run All. Define them above or remove this cell.
_gaussian_log_density_tril(y, m=pred.mean.T, L=L_x)