In [25]:
import jax.numpy as jnp
import numpy as np
from jax import config
import matplotlib.pyplot as plt

from pde_model import (
    get_discrete_source, 
    solve_pde, 
    solve_pde_vmap,
)

import sys
sys.path.append("./../linear_Gaussian/")
from Gaussian import Gaussian

config.update("jax_enable_x64", True)

In [26]:
import seaborn as sns
sns.set_theme(style='white', palette='colorblind')
sns.set_context("paper", font_scale=1.5)

# Specific Paul Tol color scheme when comparing different posteriors
colors = {
    'exact': "#4477AA",
    'mean': "#EE6677",
    'eup': "#228833",
    'ep': "#CCBB44",
    'aux': "#888888"
}

In [10]:
rng = np.random.default_rng(5623423)

In [11]:
# grid 
n_grid = 50
xgrid = jnp.linspace(0.0, 1.0, n_grid)

# diffusivity
k = jnp.abs(0.1 + 0.05 * rng.standard_normal(n_grid))

# boundary conditions
left_flux = -1.0
rightbc = 1.0

# source term
source_wells = jnp.array([0.2, 0.4, 0.6, 0.8])
source_strength = 0.8
source_width = 0.05
source = get_discrete_source(xgrid, well_locations=source_wells, 
                             strength=source_strength, width=source_width)

In [12]:
u = solve_pde(xgrid=xgrid, 
              left_flux=left_flux, 
              k=k, 
              source=source, 
              rightbc=rightbc)

### Prior

In [28]:
from gpjax.kernels.stationary import PoweredExponential

prior_kernel = PoweredExponential(lengthscale=0.3, variance=1.0, power=0.3, n_dims=1)
prior_mean = jnp.tile(1.0, n_grid)
prior_cov = prior_kernel.gram(xgrid.reshape((-1,1))).to_dense()
prior = Gaussian(prior_mean, prior_cov, rng=rng)

In [None]:
def plot_marginals(dist, grid, colors, alpha=0.3):

    fig, ax = plt.subplots()

    # intervals
    sd = np.sqrt(np.diag(dist.cov))
    lower = dist.mean - 2 * sd
    upper = dist.mean + 2 * sd

    ax.fill_between(grid, lower, upper, alpha=alpha, label="+/- 2 sd")

    if inv_prob.u_true is not None:
        ax.plot(grid, inv_prob.u_true, color="black", label="u_true")
    if g_conv_true is not None:
        ax.plot(grid, g_conv_true, color="orange", label="g_true")

    ax.plot(idx_obs, inv_prob.y, "o", color="red", label="y")
    ax.plot(grid, inv_prob.post.mean, color=colors['exact'], label="post mean")
    ax.legend()

    return fig, ax

In [23]:
prior.sample(4)

Array([[ 0.6726668 ,  1.75650566,  0.74384269,  1.54554806, -0.28312074,
         0.33202675,  0.11369178,  1.07904184, -1.64905509,  0.3414461 ,
         0.80893663,  0.14798933,  0.97111467, -1.02117946, -0.27388689,
         0.46696328, -0.28635038, -0.69302215,  0.4969304 ,  1.16979761,
         1.50849603, -0.50828017, -0.20943322,  0.53321726, -0.09290759,
         1.28409685,  1.52832826,  1.11348184,  0.43554385,  0.0477857 ,
         0.39332239,  0.2499602 ,  1.06565903,  1.59798837,  0.92559865,
         1.06917633,  1.45101385,  1.28257831,  1.67239164,  0.91746126,
         0.47665148,  0.82657415,  1.39352575,  0.74195187, -0.02892505,
        -0.15006272,  0.23606172,  0.18541898, -0.39923778, -0.93273425],
       [-0.248677  ,  1.17925829,  1.60971742,  1.81863766,  0.92627127,
         0.34132458,  1.48577967,  0.74366327,  1.38135731,  1.86363246,
         0.57648644,  0.96354225,  2.09945278,  1.44617124,  0.16893023,
         0.79827978,  0.69491193,  0.57784538,  0.