In [1]:
import os
import numpy as np
import jax.numpy as jnp
import jax.random as random
from scipy.stats import gaussian_kde
from rsnl.metrics import plot_and_save_coverage
from rsnl.examples.contaminated_slcp import calculate_summary_statistics, true_dgp
import matplotlib.pyplot as plt
import pickle as pkl
import arviz as az
import matplotlib.colors as mcolors

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Regenerate and display the observed summary statistics for seed 1.
seed = 1
rng_key = random.PRNGKey(seed)
rng_key, sub_key = random.split(rng_key)
# The true parameters are fixed here rather than sampled from the prior.
true_params = jnp.array([0.7, -2.9, -1.0, -0.9, 0.6])
x_obs = calculate_summary_statistics(true_dgp(sub_key, *true_params))
# Round the summaries to two decimal places before displaying them.
x_obs = jnp.around(x_obs, 2)
print(f'x_obs: {x_obs}')

x_obs: [ 1.2000000e-01 -2.7500000e+00  4.7000000e-01 -1.7700000e+00
 -1.6000000e-01 -3.3199999e+00  1.6999999e+00 -2.5999999e+00
  2.3410000e+01 -1.7848999e+02]


In [None]:
# Load the saved parameter draws from the RSNL and SNL runs and pool each
# method's draws into a single array.
RESULTS_DIR = os.path.join("..", "res", "contaminated_slcp")
RUN_DIR = "seed_1"


def _load_draws(path):
    """Unpickle the object stored at ``path`` and return it as a JAX array.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only open result
    files produced by trusted runs.
    """
    with open(path, "rb") as f:
        return jnp.array(pkl.load(f))


def _pool_draws(draws):
    """Concatenate ``draws`` along its leading axis and squeeze singleton axes.

    Presumably the leading axis indexes separate chains/rounds; TODO confirm
    against the rsnl output format.
    """
    return jnp.squeeze(jnp.concatenate(draws, axis=0))


theta_draws_rsnl = _load_draws(os.path.join(RESULTS_DIR, "rsnl", RUN_DIR, "theta.pkl"))
thetas_rsnl = _pool_draws(theta_draws_rsnl)

theta_draws_snl = _load_draws(os.path.join(RESULTS_DIR, "snl", RUN_DIR, "theta.pkl"))
thetas_snl = _pool_draws(theta_draws_snl)

In [None]:
# Load the pickled RSNL adjustment parameters and concatenate them along the
# leading axis into one array.
adj_params_file = "../res/contaminated_slcp/rsnl/seed_1/adj_params.pkl"
with open(adj_params_file, "rb") as handle:
    adj_params_raw = jnp.array(pkl.load(handle))
adj_params = jnp.concatenate(adj_params_raw, axis=0)


In [None]:
# Global figure styling: Times New Roman (serif) text, Computer Modern
# mathtext, and a large base font size of 35.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'mathtext.fontset': 'cm',
    'font.size': 35,
})

## Adjustment Plots

In [None]:
# For each adjustment parameter, overlay the RSNL posterior density on a
# density of standard Laplace draws and save one PDF per parameter.
# NOTE(review): assumes the adjustment-parameter prior is a standard
# Laplace(0, 1); TODO confirm against the prior used in the RSNL run.
rng_key = random.PRNGKey(0)
n_adj_params = adj_params.shape[1]  # one column per adjustment parameter
# Draw one reference column per adjustment parameter so every index in the
# loop below is valid.
prior_samples = random.laplace(rng_key, shape=(10000, n_adj_params))

PLOT_FONT_SIZE = 35
for i in range(n_adj_params):
    # A fresh figure per parameter, closed after saving, so figures do not
    # accumulate and no empty figure is left behind.
    fig, ax = plt.subplots()
    az.plot_dist(adj_params[:, i],
                 label='Posterior',
                 color='black',
                 ax=ax)
    az.plot_dist(prior_samples[:, i],
                 color=mcolors.CSS4_COLORS['limegreen'],
                 plot_kwargs={'linestyle': 'dashed'},
                 label='Prior',
                 ax=ax)
    # Raw string: "\g" is not a valid Python escape sequence.
    ax.set_xlabel(r"$\gamma_{%s}$" % (i + 1), fontsize=PLOT_FONT_SIZE)
    ax.set_ylabel("Density", fontsize=PLOT_FONT_SIZE)
    ax.set_xlim([-10, 10])
    ax.set_ylim(bottom=0)
    ax.set_xticks([-10, -5, 0, 5, 10])
    ax.tick_params(axis='both', labelsize=PLOT_FONT_SIZE)
    ax.legend(fontsize=PLOT_FONT_SIZE,
              loc='upper left',
              borderpad=0.1, labelspacing=0.1, handletextpad=0.1)

    fig.tight_layout()
    fig.savefig(f'contaminated_slcp_adj_param_{i+1}.pdf', bbox_inches='tight')
    plt.close(fig)

<Figure size 640x480 with 0 Axes>