# Predator-Prey model with triangular transport map
Based on "Transport map accelerated Markov chain Monte Carlo" by Parno and Marzouk (SIAM/ASA Journal on Uncertainty Quantification 2018)

In [1]:
from numpyro.util import safe_mul, safe_div

%env JAX_DEBUG_NANS=True
import jax
import jax.ops
import jax.numpy as jnp
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from functools import partial

import numpyro
from numpyro import distributions as dist
from numpyro.distributions.transforms import MultivariateAffineTransform
from numpyro.infer.guide import WrappedGuide
from numpyro.contrib.autoguide import AutoDelta
from numpyro.examples.runge_kutta import runge_kutta_4
from numpyro.infer import init_to_uniform, init_with_noise, SVI, Stein, ELBO
from numpyro.infer.kernels import RBFKernel
from numpyro.callbacks import Progbar
from numpyro.optim import Adam

env: JAX_DEBUG_NANS=True




In [2]:
rng_key = jax.random.PRNGKey(242)

## Predator Prey Model

In [3]:
def predator_prey_step(t, state, r=0.6, k=100, s=1.2, a=25, u=0.5, v=0.3):
    prey = state[..., 0]
    predator = state[..., 1]
    sh = safe_div(safe_mul(prey, predator), a + prey)
    prey_upd = r * safe_mul(prey, safe_div(1 - prey, k)) - s * sh
    predator_upd = u * sh - v * predator
    return jnp.stack((prey_upd, predator_upd), axis=-1)
num_time = 5
step_size = 0.1
num_steps = int(num_time / step_size)
dampening_rate = 0.9
lyapunov_scale = 1e-3
clip = lambda x: jnp.clip(x, -10.0, 10.0)
predator_prey = runge_kutta_4(predator_prey_step, step_size, num_steps, dampening_rate, 
                              lyapunov_scale, clip,
                              unconstrain_fn=lambda _, x: jnp.where(x < 10, jnp.log(jnp.expm1(x)), x),
                              constrain_fn=lambda _, x: jax.nn.softplus(x))
predator_prey = partial(predator_prey, rng_key)

In [4]:
indices = jnp.array([1, 11, 21, 31, 41])
res, lyapunov_loss = predator_prey(jnp.array([50., 5.]))
# res = np.reshape(res, (num_time, num_steps // num_time, -1))[:, 0, :]
noise = jax.random.normal(rng_key, (1000,5,2)) * 10
data = (indices, res[indices] + noise)
data

(DeviceArray([ 1, 11, 21, 31, 41], dtype=int32),
 DeviceArray([[[ 33.311028 , -12.732149 ],
               [ 37.39188  ,   6.7296047],
               [ 22.41841  ,   2.8242142],
               [ 34.01062  ,  11.274597 ],
               [  7.3329477,  -5.1669645]],
 
              [[ 64.81349  ,   8.4971285],
               [ 53.886345 ,   6.9032702],
               [  1.8843441,  -3.759818 ],
               [ 21.644583 ,  11.018448 ],
               [ -3.675232 ,   5.044437 ]],
 
              [[ 39.057636 ,  -8.654196 ],
               [ 39.12978  ,  10.124067 ],
               [ 17.123663 ,  19.667484 ],
               [ 18.93769  ,  22.730627 ],
               [  3.0498734,   9.167135 ]],
 
              ...,
 
              [[ 59.46758  ,  17.306845 ],
               [ 33.735657 ,   5.6848707],
               [ 31.406258 ,   6.3421493],
               [ 23.054996 ,  14.960936 ],
               [  7.580385 ,   2.0696585]],
 
              [[ 54.357674 ,  -6.5606894],
               

In [5]:
def model(indices, observations):
    prior_dist = dist.HalfNormal(1000)
    prey0 = numpyro.sample('prey0', prior_dist)
    predator0 = numpyro.sample('predator0', prior_dist)
    r = numpyro.sample('r', prior_dist)
    k = numpyro.sample('k', prior_dist)
    s = numpyro.sample('s', prior_dist)
    a = numpyro.sample('a', prior_dist)
    u = numpyro.sample('u', prior_dist)
    v = numpyro.sample('v', prior_dist)
    ppres, lyapunov_loss = predator_prey(jnp.array([prey0, predator0]), r=r, k=k, s=s, a=a, u=u, v=v)
    # ppres = np.reshape(ppres, (num_time, num_time // num_steps, -1))
    numpyro.factor('lyapunov_loss', lyapunov_loss)
    numpyro.sample('obs', dist.Normal(ppres[indices], 10.0).to_event(2), obs=observations)

### SVI

In [None]:
svi = SVI(model, AutoDelta(model), Adam(0.001), ELBO())
state = svi.init(rng_key, *data)
pbar = tqdm(range(1000))
prev_state = state
for i in pbar:
    prev_state = state
    state, loss = svi.update(state, *data)
    pbar.set_description(f'SVI {loss}')

In [7]:
from jaxinterp.interpreter import interpret
_, rng_key_debug = jax.random.split(prev_state.rng_key)
params = svi.optim.get_params(prev_state.optim_state)
interpret(lambda params, *data:
          jax.value_and_grad(lambda x: svi.loss.loss(rng_key_debug, svi.constrain_fn(x),
                                           svi.model, svi.guide, *data))(params),
          stage_out=True)(params, *data)

KeyboardInterrupt: 

In [None]:
svi.get_params(state)

### Guide and Stein with Transport Maps

In [None]:
svgd = Stein(model, AutoDelta(model),
             Adam(0.0001), ELBO(),
             RBFKernel(), num_particles=100,
             repulsion_temperature=0.001 * data[1].shape[0])
state, loss = svgd.train(rng_key, 10000, *data, callbacks=[Progbar()])  # rounds 10000

In [None]:
sample_frame = pd.DataFrame.from_dict(svgd.predict(state, *data))
g = sns.pairplot(sample_frame, corner=True, diag_kind='kde')
g.map_lower(sns.kdeplot, lw=2)

In [None]:
def transmap_guide(indices, observations):
    param_keys = {'prey0', 'preadtor0', 'r', 'k', 's', 'a', 'u', 'v'}
    n = len(param_keys)
    tmapp = numpyro.param('tmapp', jnp.zeros(n * (n + 1) // 2))
    tril_idx = jnp.tril_indices(n)
    tmap = jax.ops.index_update(jnp.zeros((n, n)), tril_idx, tmapp)
    tloc = numpyro.param('tloc', jnp.zeros(n))
    vals = numpyro.param('vals', jnp.zeros(n),
                         particle_transform=MultivariateAffineTransform(tloc, tmap),
                         constraint=dist.constraints.greater_than(0.1))
    for pk, val in zip(param_keys, vals):
        numpyro.sample(pk, dist.Delta(val))

In [None]:
svgd = Stein(model, WrappedGuide(transmap_guide, init_strategy=init_with_noise(init_to_uniform())),
             Adam(0.01), ELBO(),
             RBFKernel(), repulsion_temperature=0.01 * data[1].shape[0],
             num_particles=100, classic_guide_params_fn=lambda n: n in {'tmapp', 'tloc'})
state, loss = svgd.train(rng_key, 10000, *data, callbacks=[Progbar()])

In [None]:
sample_frame = pd.DataFrame(svgd.predict(state, *data))
g = sns.pairplot(sample_frame, corner=True, diag_kind='kde')
g.map_lower(sns.kdeplot, lw=2)

### NUTS MCMC Sampling

In [None]:
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 100, 500, chain_method='vectorized')
mcmc.run(rng_key, *data)
mcmc.print_summary()

In [None]:
samples = mcmc.get_samples()
sample_frame = pd.DataFrame(samples)
g = sns.pairplot(sample_frame, corner=True, diag_kind='kde')
g.map_lower(sns.kdeplot, lw=2)