In [1]:
from jax import config
config.update('jax_enable_x64', True)
from pathlib import Path

import jax
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

from uncprop.models.vsem.runner import run_vsem_experiment
from uncprop.utils.plot import set_plot_theme

import os

colors = set_plot_theme()
# Root directory under which this notebook builds its output paths.
# Set BIP_SURROGATES_DIR to run on another machine; the default keeps the
# original hard-coded location so existing behaviour is unchanged.
base_dir = Path(os.environ.get('BIP_SURROGATES_DIR',
                               '/Users/andrewroberts/Desktop/git-repos/bip-surrogates-paper'))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run the packaged VSEM experiment with no arguments (its defaults); the
# result is unpacked as (results, failed_iters) in the next cell.
test = run_vsem_experiment()

Creating new output directory: /Users/andrewroberts/Desktop/git-repos/bip-surrogates-paper/out/vsem
Creating experiment sub-directory: /Users/andrewroberts/Desktop/git-repos/bip-surrogates-paper/out/vsem/clip_gp_N8
Running replicate 0
Running replicate 1
Running replicate 2
Running replicate 3
0 of 4 replicates failed.


In [None]:
# Split the experiment output into per-replicate results and the failed indices.
[results, failed_iters] = test
print(failed_iters)
print(len(results))

In [None]:
# Inspect the surrogate predictive covariance of the last replicate.
# The run above reported 4 replicates (valid indices 0-3), so the original
# `results[4]` was out of range.
# NOTE(review): confirm `surrogate_pred` is the intended attribute name.
results[-1].surrogate_pred.cov

In [None]:
from uncprop.models.vsem.experiment import VSEMReplicate

# Build a single replicate directly (without the experiment wrapper) so it
# can be inspected interactively.
key = jr.key(9768565)
rep = VSEMReplicate(
    key=key,
    n_design=8,
    noise_sd=1.0,
    n_grid=50,
    jitter=1e-4,
)

In [None]:
from uncprop.models.vsem.experiment import VSEMExperiment, VSEMReplicate

# Configuration for the experiment constructed below.
key = jr.key(9768565)
setup_kwargs = dict(
    n_grid=50,
    n_design=8,
    noise_sd=1.0,
    verbose=False,
    jitter=1e-4,
)
# NOTE(review): `num_reps` and `backup_frequency` are not passed on by the
# following cells, which hard-code 8 and 2 instead; confirm intended values.
num_reps = 20
backup_frequency = 10
experiment_name = 'vsem'
out_dir = base_dir.joinpath('temp_out', experiment_name)

def _make_subdir_name(setup_kwargs, run_kwargs):
    return f'{run_kwargs['surrogate_tag']}_N{setup_kwargs['n_design']}'

# -----------------------------------------------------------------------------
# Set up experiment (it is invoked in the following cell)
# -----------------------------------------------------------------------------
# NOTE(review): `num_reps=8` is hard-coded even though the config cell sets
# `num_reps = 20`; confirm which replicate count is intended.
experiment = VSEMExperiment(
    name=experiment_name,
    num_reps=8,
    base_out_dir=out_dir,
    base_key=key,
    Replicate=VSEMReplicate,
    subdir_name_fn=_make_subdir_name,
    write_to_file=False,  # presumably skips writing outputs to disk — confirm
)


In [None]:
# NOTE(review): backup_frequency=2 differs from the `backup_frequency = 10`
# set in the config cell; confirm which is intended.
test = experiment(
    run_kwargs={'surrogate_tag': 'clip_gp'},
    setup_kwargs=setup_kwargs,
    backup_frequency=2,
)

In [None]:
# Unpack the experiment output and report which replicates failed.
[results, failed_iters] = test
print(failed_iters)

In [None]:
# Pass the per-replicate results and failures to the experiment's collector.
# NOTE(review): this rebinds `test`, replacing the (results, failed_iters)
# tuple it held from the experiment call.
test = experiment.collect_results(results, failed_iters)

In [None]:
# Re-run the packaged experiment with its defaults and unpack directly.
results, failed_iters = run_vsem_experiment()

In [None]:
from uncprop.utils.grid import _is_normalized

# Sanity-check each replicate: print '<replicate index> - <density name>' for any
# normalized log-density grid that contains infinities or fails the
# `_is_normalized` check.
for idx, rep in enumerate(results):
    for nm in ['exact', 'mean', 'eup', 'ep']:
        log_dens = rep.density_comparison.log_dens_norm_grid[nm]
        has_inf = jnp.isinf(log_dens)
        # `~` is a logical NOT only for boolean arrays. If `_is_normalized`
        # returns a Python bool, `~True == -2` (truthy) and every entry would be
        # flagged. logical_not is correct for both cases.
        not_normalized = jnp.logical_not(_is_normalized(log_dens))
        if jnp.any(jnp.logical_or(has_inf, not_normalized)):
            print(f'{idx} - {nm}')

In [None]:
# Drill into one replicate's (unnormalized) 'ep' log-density grid and display it.
# NOTE(review): index 15 assumes `results` holds at least 16 replicates.
rep = results[15]
(l := rep.density_comparison.log_dens_grid['ep'])

In [None]:
# Evaluate this replicate's GP surrogate on the flattened density grid.
gp = rep.surrogate_posterior_gp.posterior_surrogate.surrogate
flat_grid = rep.density_comparison.grid.flat_grid
pred = gp(flat_grid)

In [None]:
from jax.numpy.linalg import cholesky
from jax.lax.linalg import cholesky as cholesky_lax

# Factor the jittered predictive covariance with both the jnp and lax
# Cholesky front-ends and report how many NaN entries each factor contains.
d = pred.dim
C = pred.cov.at[jnp.diag_indices(d)].add(1e-3)  # add 1e-3 to the diagonal

L, L2 = cholesky(C, upper=False), cholesky_lax(C, symmetrize_input=True)

print(f'jnp: {jnp.isnan(L).sum()}')
print(f'lax: {jnp.isnan(L2).sum()}')

In [None]:
from uncprop.models.vsem.surrogate import _print_gp_fit_info

# Print fit diagnostics for the `gp` surrogate extracted from the selected replicate.
_print_gp_fit_info(gp.gp, rep.fit_info)

In [None]:
# Display the prior support of the selected replicate's posterior.
rep.posterior.prior.support

In [None]:
# Coverage results: stack the first element returned by calc_coverage
# (against the exact baseline) for every replicate along a new leading axis.
log_coverage = jnp.stack(
    [r.density_comparison.calc_coverage(baseline='exact')[0] for r in results]
)

In [None]:
# Configuration for a small, file-backed experiment run.
# NOTE(review): unlike the earlier config, setup_kwargs has no 'jitter' key.
key = jr.key(9768565)
setup_kwargs = dict(n_grid=50, n_design=4, noise_sd=1.0, verbose=False)
num_reps = 3
experiment_name = 'vsem'
out_dir = base_dir.joinpath('out', experiment_name)

# 3 cases: n = 4, 8, and one other (maybe 16)

def _make_subdir_name(setup_kwargs, run_kwargs):
    return f'{run_kwargs['surrogate_tag']}_N{setup_kwargs['n_design']}'

# Build the experiment and run it for the clipped-GP surrogate.
# `Experiment` is not imported in any cell shown here (a NameError on a fresh
# kernel). `VSEMExperiment`, imported earlier and constructed with the same
# keyword arguments, is the class intended here.
experiment = VSEMExperiment(name=experiment_name,
                            num_reps=num_reps,
                            base_out_dir=out_dir,
                            base_key=key,
                            Replicate=VSEMReplicate,
                            subdir_name_fn=_make_subdir_name)

# results_gp, failed_iters_gp = experiment(run_kwargs={'surrogate_tag': 'gp'}, 
#                                          setup_kwargs=setup_kwargs)

results_clip_gp, failed_iters_clip_gp = experiment(run_kwargs={'surrogate_tag': 'clip_gp'}, 
                                                   setup_kwargs=setup_kwargs)

In [None]:
# Per-replicate normalized-density and coverage plots for the plain-GP surrogate.
# NOTE(review): `results_gp` is only produced by the commented-out `gp`
# experiment call in the experiment cell, so this cell raises NameError as-is.
# NOTE(review): the design path `surrogate_posterior_gp.surrogate` differs from
# `surrogate_posterior_gp.posterior_surrogate.surrogate` used elsewhere; confirm.
for rep in results_gp:
    rep.density_comparison.plot(
        normalized=True,
        log_scale=False,
        max_cols=4,
        points=rep.surrogate_posterior_gp.surrogate.design.X,
    )
    rep.density_comparison.plot_coverage(baseline='exact')

In [None]:
# Per-replicate normalized-density and coverage plots for the clipped-GP
# surrogate, overlaying that replicate's design points.
for rep in results_clip_gp:
    rep.density_comparison.plot(
        normalized=True,
        log_scale=False,
        max_cols=4,
        points=rep.surrogate_posterior_clip_gp.surrogate.design.X,
    )
    rep.density_comparison.plot_coverage(baseline='exact')

In [None]:
from uncprop.utils.grid import plot_coverage_curve_reps

# Coverage curves across replicates for the clipped-GP surrogate.
# calc_coverage is indexed as [0] (per-replicate log coverage) and [1] (the
# probability levels). Compute it once per replicate instead of calling it
# a second time on replicate 0 just to recover the probability levels.
coverage_by_rep = [rep.density_comparison.calc_coverage(baseline='exact')
                   for rep in results_clip_gp]

log_coverage_reps = jnp.stack([cov[0] for cov in coverage_by_rep], axis=0)
probs = coverage_by_rep[0][1]

fig, ax = plot_coverage_curve_reps(log_coverage_reps, probs=probs, names=['mean', 'eup', 'ep'])

In [None]:
# Summarize the surrogate's predictive mean and standard deviation on the grid.
# NOTE(review): `surrogate_posterior`, `grid` and `design` are not defined in any
# cell shown here; this cell is stale and fails on a fresh kernel.
pred = surrogate_posterior.surrogate(grid.flat_grid)

grid.plot(z=pred.mean,
          title='surrogate mean',
          points=design.X)
grid.plot(z=jnp.sqrt(pred.variance),
          title='surrogate sd',
          points=design.X)

In [None]:
# Sample the posterior with MCMC and histogram the first parameter.
# NOTE(review): `posterior` is not defined in any cell shown here
# (cf. `rep.posterior` elsewhere); bind it before running this cell.
key, key_mcmc = jr.split(key, 2)

# Use the freshly split subkey. Passing `key` again would reuse the key that
# is carried forward for later splits, correlating subsequent randomness.
samp = posterior.sample(key_mcmc, n=3000, num_warmup_steps=500) # returns HMCState with samp.position (n, 1, 2)

# The original note gives samp.position as (n, 1, 2). Collapse the trailing
# axes so each column is one parameter; this also works if position is (n, 2).
samples = samp.position.reshape(samp.position.shape[0], -1)

fig, ax = plt.subplots()
ax.hist(samples[:, 0])
ax.set(title='posterior samples (first parameter)', xlabel='parameter 0', ylabel='count')
plt.show()

In [None]:
# Compare the exact posterior density against the three surrogate-based ones.
# NOTE(review): `density_comparison` and `design` are not defined in any cell
# shown here; this cell is stale and fails on a fresh kernel.
density_comparison.plot(
    ['exact', 'mean', 'eup', 'ep'],
    normalized=True,
    log_scale=False,
    max_cols=4,
    points=design.X,
)

In [None]:
from uncprop.utils.grid import get_grid_coverage_mask, plot_2d_mask

# Coverage-region masks for the 'ep' density at 30 evenly spaced probability
# levels in [0.1, 0.9]; plot the masks at level indices 0, 15 and 20.
# NOTE(review): `density_comparison` and `grid` are not defined in any cell
# shown here; this cell is stale and fails on a fresh kernel.
mask = get_grid_coverage_mask(
    log_prob=density_comparison.log_dens_norm_grid['ep'],
    probs=jnp.linspace(0.1, 0.9, 30),
)

fig, ax = plot_2d_mask(mask, grid.shape, prob_idx=[0, 15, 20])

# TEMP

In [None]:
import blackjax

In [None]:
# Alias blackjax's `rmh` sampler initializer; "mwg" presumably stands for
# Metropolis-within-Gibbs (work in progress) — confirm.
mwg_init_x = blackjax.rmh.init