In [3]:
from jax import config
config.update('jax_enable_x64', True)
from pathlib import Path

import jax
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from uncprop.utils.experiment import Experiment
from uncprop.models.vsem.experiment import VSEMReplicate
from uncprop.utils.plot import set_plot_theme
from uncprop.utils.grid import Grid, DensityComparisonGrid
from uncprop.utils.grid import get_grid_coverage_mask, plot_2d_mask

# Apply the project's plot theme and keep whatever it returns for later cells.
colors = set_plot_theme()

# Root of the local repository checkout. Built from the current user's home
# directory instead of a hard-coded `/Users/<name>/...` prefix, so the notebook
# does not embed one machine's username; for the original author this resolves
# to the same path as before.
# NOTE(review): still assumes the repo lives under ~/Desktop/git-repos — adjust
# here if your checkout is elsewhere.
base_dir = Path.home() / 'Desktop' / 'git-repos' / 'bip-surrogates-paper'

In [None]:
# --- Experiment configuration -------------------------------------------------
# Base PRNG key for the experiment (also split again by a later sampling cell).
key = jr.key(246432)
num_reps = 3
setup_kwargs = dict(n_grid=50, n_design=7)
out_dir = base_dir / 'out'

# --- Build and run ------------------------------------------------------------
# NOTE(review): the saved output of this cell is
#   FileExistsError: Experiment out dir already exists
# so re-running presumably requires removing/renaming the existing output
# directory first — confirm against the Experiment implementation.
experiment = Experiment(
    name='vsem_experiment',
    num_reps=num_reps,
    base_out_dir=out_dir,
    base_key=key,
    Replicate=VSEMReplicate,
)

# Calling the experiment with the setup options yields an iterable of
# replicates (it is looped over in the next cell).
results = experiment(setup_kwargs=setup_kwargs)

FileExistsError: Experiment out dir already exists: out/vsem_experiment

In [None]:
# One density-comparison figure per replicate, with that replicate's design
# inputs overlaid as points.
plot_options = dict(normalized=True, log_scale=False, max_cols=4)
for rep in results:
    rep.density_comparison.plot(
        **plot_options,
        points=rep.surrogate_posterior.surrogate.design.X,
    )

In [None]:
# NOTE(review): `posterior` and `surrogate_posterior` are not assigned in any
# cell visible in this notebook; this cell relies on them already existing in
# the kernel. Confirm where they should come from (e.g. one of the replicates
# in `results`).

# 50 x 50 evaluation grid spanning the posterior's support, with axes named
# after the prior's parameter names.
support_low = posterior.support[0]
support_high = posterior.support[1]
grid = Grid(
    low=support_low,
    high=support_high,
    n_points_per_dim=[50, 50],
    dim_names=posterior.prior.par_names,
)

# True posterior with the surrogate's design points overlaid.
design = surrogate_posterior.surrogate.design
log_density_fn = posterior._get_log_density_function()
fig, ax = grid.plot(f=log_density_fn, title='posterior', points=design.X)

In [None]:
# Evaluate the surrogate at every grid point, then show its predictive mean
# and standard deviation (sqrt of the predictive variance) with the design
# points overlaid.
pred = surrogate_posterior.surrogate(grid.flat_grid)
surrogate_mean = pred.mean
surrogate_sd = jnp.sqrt(pred.variance)

grid.plot(z=surrogate_mean, title='surrogate mean', points=design.X)
grid.plot(z=surrogate_sd, title='surrogate sd', points=design.X)

In [None]:
# Sample the posterior and plot a histogram of the first parameter.

# Split off a dedicated subkey for the sampler. `key` is rebound to the fresh
# key returned by the split; the sampler must consume `key_mcmc`, not `key`,
# so that the carried-forward key is never used twice.
key, key_mcmc = jr.split(key, 2)

# Per the original author's note, this returns an HMCState whose
# `samp.position` has shape (n, 1, 2).
samp = posterior.sample(key_mcmc, n=3000, num_warmup_steps=500)

# Drop the singleton middle axis so each row is one draw and each column one
# parameter. Assumes the (n, 1, 2) layout noted above — TODO confirm.
samples = samp.position[:, 0, :]

plt.hist(samples[:, 0])
plt.show()

In [None]:
# Compare the selected densities, overlaying the design points.
# NOTE(review): `density_comparison` is not assigned in any visible cell —
# confirm which object (e.g. a replicate's `density_comparison`) is intended.
density_names = ['exact', 'mean', 'eup', 'ep']
density_comparison.plot(
    density_names,
    normalized=True,
    log_scale=False,
    max_cols=4,
    points=design.X,
)

In [None]:
# Coverage mask for the 'ep' normalized log-density grid, evaluated at 30
# evenly spaced probability levels from 0.1 to 0.9.
# (`get_grid_coverage_mask` and `plot_2d_mask` are imported in the top import
# cell rather than mid-notebook.)
mask = get_grid_coverage_mask(log_prob=density_comparison.log_dens_norm_grid['ep'],
                              probs=jnp.linspace(0.1, 0.9, 30))

# Plot the mask on the 2-D grid for three entries.
# NOTE(review): prob_idx presumably indexes into `probs` above (0 -> 0.1,
# 15 -> ~0.51, 20 -> ~0.65) — confirm against plot_2d_mask.
fig, ax = plot_2d_mask(mask, grid.shape, prob_idx=[0, 15, 20])