# Toy 1D Example for the Review Paper

In [19]:
from jax import config
config.update('jax_enable_x64', True)
from pathlib import Path

import jax
import jax.numpy as jnp
import jax.random as jr
from scipy.stats import qmc
from numpy.random import default_rng

from uncprop.utils.grid import Grid, DensityComparisonGrid
from uncprop.core.inverse_problem import Prior, Posterior
from uncprop.utils.other import _numpy_rng_seed_from_jax_key
from uncprop.utils.plot import (
    set_plot_theme, 
    smart_subplots,
)

# Apply the project's plot theme; the return value is kept as `colors`
# (presumably a color palette — confirm in uncprop.utils.plot.set_plot_theme).
colors = set_plot_theme()

import os

# Root of the paper repository. Defaults to the original author's checkout;
# set the BIP_SURROGATES_PAPER_DIR environment variable to run this notebook
# from a different location without editing the code.
base_dir = Path(
    os.environ.get(
        'BIP_SURROGATES_PAPER_DIR',
        '/Users/andrewroberts/Desktop/git-repos/bip-surrogates-paper',
    )
).expanduser()
# Output directory for the results of this example (not created here).
out_dir = base_dir / 'out' / 'review_final'

In [20]:
key = jr.key(325234)  # root PRNG key; split later to draw the observation noise

# plot settings
n_grid = 100  # number of grid points per dimension for the evaluation grid

# exact inverse problem
noise_sd = 1.0  # std. dev. of the additive Gaussian observation noise
time_horizon = (-10, 10)  # (start, end) of the observation time grid
support = (0, 1)  # (lower, upper) bounds of the uniform prior on u
true_param = 0.4  # parameter value used to simulate the synthetic observation
n_time = 8  # number of equally spaced observation times

# forward model surrogate
# placeholder: surrogate settings not specified yet
...

# log-posterior surrogate
# placeholder: surrogate settings not specified yet
... 

Ellipsis

## Setup

In [6]:
# Forward Model
# Observation times: n_time equally spaced points spanning time_horizon.
times = jnp.linspace(*time_horizon, n_time)
# Variance of the Gaussian observation noise.
noise_var = pow(noise_sd, 2)

def forward(u):
    """Forward model: mean of exp(u * t) over the observation times ``times``.

    ``u`` is promoted with ``jnp.atleast_2d`` and each row is broadcast
    against ``times``; the result holds one averaged value per row.

    NOTE(review): a 1-D input of length n is promoted to shape (1, n), not
    (n, 1), so batches of parameter values should be passed as (n, 1).
    """
    trajectories = jnp.exp(jnp.atleast_2d(u) * times)
    return trajectories.mean(axis=1)

In [7]:
# Ground truth: simulate a single noisy observation y at true_param.
key, key_noise = jr.split(key)

# Additive Gaussian noise draw and noise-free model output.
true_noise = noise_sd * jr.normal(key_noise)
true_observable = forward(true_param)

y = true_observable + true_noise

In [17]:
# exact inverse problem 

def log_lik(u):
    """Gaussian log-likelihood log N(y | forward(u), noise_var).

    Returns one value per row of the ``atleast_2d``-promoted input ``u``.
    """
    log_norm = -0.5 * jnp.log(2*jnp.pi*noise_var)
    sq_resid = (y - forward(u))**2
    return log_norm - 0.5 * sq_resid / noise_var

class Prior1d(Prior):
    """Uniform prior on the scalar parameter ``u`` over the interval ``support``.

    ``log_density`` promotes its input with ``jnp.atleast_2d`` and treats it
    as an (n_inputs, dim) batch; the samplers return arrays of shape (n, dim).
    """

    @property
    def dim(self):
        return 1

    @property
    def support(self):
        # (lower, upper) bounds, read from the module-level `support` setting.
        return support

    @property
    def par_names(self):
        return ['u']

    def log_density(self, u):
        """Log density of the uniform prior, one value per row of ``u``.

        Rows lying inside [a, b] (bounds inclusive) get ``-log(b - a)``;
        rows with any coordinate outside the support get ``-inf``.
        """
        u = jnp.atleast_2d(u)
        a, b = self.support
        log_dens = -jnp.log(b - a)
        # A uniform density is zero off its support. Returning -inf there
        # keeps out-of-support points from being scored like interior ones.
        in_support = jnp.all((u >= a) & (u <= b), axis=1)
        return jnp.where(in_support, log_dens, -jnp.inf)

    def sample(self, key, n: int = 1):
        """Draw ``n`` i.i.d. uniform samples on the support, shape (n, dim)."""
        a, b = self.support
        return jr.uniform(key, shape=(n, self.dim), minval=a, maxval=b)

    def sample_lhc(self, key, n: int = 1):
        """Latin hypercube sample of size ``n`` scaled to the support, shape (n, dim)."""
        # scipy's QMC engines need a NumPy Generator, so derive a seed from the JAX key.
        rng_key = _numpy_rng_seed_from_jax_key(key)
        rng = default_rng(seed=rng_key)
        lhc = qmc.LatinHypercube(d=self.dim, rng=rng)

        samp = lhc.random(n=n)  # points in the unit hypercube [0, 1)^dim
        a, b = self.support
        return jnp.asarray(a + samp * (b - a))
    
# Build the uniform prior and the exact posterior formed from it and the
# Gaussian log-likelihood `log_lik`.
prior = Prior1d()
posterior = Posterior(prior, log_lik)

In [21]:
# Evaluation grid spanning the prior support, with n_grid points per
# dimension; dimension names are taken from the prior.
grid = Grid(low=support[0], high=support[1], 
            n_points_per_dim=n_grid, 
            dim_names=prior.par_names)

In [29]:
# Sanity check: shape of the grid coordinates along the first (and, since
# prior.dim == 1, only) dimension — presumably (n_grid,).
grid.grid_arrays[0].shape

(100,)