In [10]:
import jax
import jax.numpy as jnp
import numpy as np
import numpy.random as npr
from jax.lax import scan

import numpyro
import numpyro.distributions as dist
from numpyro.callbacks import Progbar
from numpyro.infer import SVI, ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Adam

In [11]:
def arms_race(x0, y0, *, k, a, g, l, b, h,
              num_iterations=10, step_size=0.02):
    """Integrate a two-variable linear arms-race system with forward Euler.

    The state ``(x, y)`` is advanced with the increments::

        dx = k * y - a * x + g
        dy = l * x - b * y + h

    i.e. ``x <- x + step_size * dx`` and ``y <- y + step_size * dy``, both
    computed from the state at the start of the step.

    Args:
        x0, y0: Initial values of the two state variables.
        k, a, g: Coefficients of the ``x`` increment.
        l, b, h: Coefficients of the ``y`` increment.
        num_iterations: Total simulated time. Must be a concrete (non-traced)
            number because it determines the length of the scan.
        step_size: Euler step. Must be a concrete, strictly positive number.

    Returns:
        ``(xs, ys)``: the state after every Euler step, stacked along a new
        leading axis of length ``ceil(num_iterations / step_size)``. The
        initial state itself is not included.

    Raises:
        ValueError: If ``step_size`` is not strictly positive.
    """
    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size}")

    # Derive the number of steps as an integer up front. Building a float
    # ``arange`` grid just to obtain its length is fragile: rounding in
    # ``num_iterations / step_size`` can add or drop one element. Rounding the
    # ratio before taking the ceiling removes that noise while keeping the
    # previous length for ratios that are genuinely non-integer.
    num_steps = int(np.ceil(np.round(num_iterations / step_size, 9)))

    def arms_race_up(state, _):
        # The dynamics are autonomous, so the per-step scan input is unused.
        x, y = state
        dx = k * y - a * x + g
        dy = l * x - b * y + h
        new_state = (x + step_size * dx, y + step_size * dy)
        # Carry the new state forward and also emit it as this step's output.
        return new_state, new_state

    _, (xs, ys) = scan(arms_race_up, (x0, y0), None, length=num_steps)
    return xs, ys


In [12]:
# Simulation horizon and Euler step shared by the ground truth and the model.
num_iterations = 30
step_size = 0.02

# Ground-truth coefficients, grouped by the increment they belong to.
gt_k, gt_a, gt_g = 10., 20., .1   # dx = k * y - a * x + g
gt_l, gt_b, gt_h = 10., 3., 6.    # dy = l * x - b * y + h

# Ground-truth trajectories, both state variables starting at 10.
gt_xs, gt_ys = arms_race(
    10., 10.,
    num_iterations=num_iterations, step_size=step_size,
    k=gt_k, a=gt_a, g=gt_g,
    l=gt_l, b=gt_b, h=gt_h,
)

In [13]:
# Split each trajectory into `num_iterations` equal contiguous chunks and keep
# only the first sample of every chunk.
gt_xs, gt_ys = (jnp.reshape(trajectory, (num_iterations, -1))[:, 0]
                for trajectory in (gt_xs, gt_ys))


In [14]:
num_datapoints = 100
noise_scale = 3
# Dedicated, seeded generator so that Restart & Run All reproduces the same
# synthetic observations (the global `npr.randn` stream was never seeded).
noise_rng = npr.default_rng(0)
# One noisy copy of each ground-truth trajectory per data point: the noise has
# shape (num_datapoints, *trajectory_shape) and broadcasts against gt_xs / gt_ys.
obs_xs = np.array(gt_xs + noise_rng.standard_normal((num_datapoints, *np.shape(gt_xs))) * noise_scale)
obs_ys = np.array(gt_ys + noise_rng.standard_normal((num_datapoints, *np.shape(gt_ys))) * noise_scale)

In [15]:
def model(dxs, dys):
    """NumPyro model for noisy observations of the arms-race trajectories.

    Both initial values and the six coefficients forwarded to ``arms_race``
    are sampled from a default ``HalfNormal`` prior. The system is simulated
    with the notebook-level ``num_iterations`` and ``step_size``, reduced to
    one sample per chunk, and used as the mean of a ``Normal`` likelihood with
    scale 100 for every observed trajectory.

    Args:
        dxs: Observed ``x`` trajectories; the size of the leading axis is used
            as the size of the ``data`` plate.
        dys: Observed ``y`` trajectories, conditioned on inside the same plate.
    """
    def first_of_each_chunk(trajectory):
        # Split into `num_iterations` equal chunks and keep the first sample.
        return jnp.reshape(trajectory, (num_iterations, -1))[:, 0]

    x0 = numpyro.sample('x0', dist.HalfNormal())
    y0 = numpyro.sample('y0', dist.HalfNormal())
    # The tuple fixes the order in which the coefficient sites are sampled.
    coefficients = {site: numpyro.sample(site, dist.HalfNormal())
                    for site in ('k', 'a', 'g', 'l', 'b', 'h')}

    sim_xs, sim_ys = arms_race(x0, y0, num_iterations=num_iterations,
                               step_size=step_size, **coefficients)
    sim_xs = first_of_each_chunk(sim_xs)
    sim_ys = first_of_each_chunk(sim_ys)

    with numpyro.plate('data', dxs.shape[0], dim=-1):
        # A leading axis of size 1 lets the simulated trajectory broadcast
        # across the plate; `to_event(1)` makes the trailing axis an event dim.
        x_likelihood = dist.Normal(loc=sim_xs[None, :], scale=100.).to_event(1)
        numpyro.sample('xs', x_likelihood, obs=dxs)
        y_likelihood = dist.Normal(loc=sim_ys[None, :], scale=100.).to_event(1)
        numpyro.sample('ys', y_likelihood, obs=dys)

In [16]:
# Fixed-seed PRNG key; it is handed to `svi.train` in the training cell.
rng_key = jax.random.PRNGKey(1337)

In [17]:
# Stochastic variational inference over `model` with an AutoDelta guide,
# Adam (step size 1.0) and the ELBO loss.
# NOTE(review): AutoDelta presumably produces point estimates of the latent
# sites -- confirm against the numpyro build used here.
svi = SVI(model, AutoDelta(model), Adam(1.0), ELBO())
# Train on the noisy observations with a progress-bar callback. The recorded
# progress bar shows 100/100, so the second positional argument appears to be
# the number of optimisation steps.
# NOTE(review): `svi.train(..., callbacks=...)` and `numpyro.callbacks` come
# from the numpyro variant installed for this notebook -- verify the signature
# before upgrading.
state, loss = svi.train(rng_key, 100, obs_xs, obs_ys, callbacks=[Progbar()])

SVI inf: 100%|██████████| 100/100 [00:00<00:00, 127.56it/s]


In [21]:
# NOTE(review): this imports a private helper (leading underscore) in the
# middle of the notebook; it may change without notice and the cell fails if
# `jaxinterp` is not installed.
from jaxinterp.interpreter import _make_jaxpr_with_consts
# Split the key stored in the training state and keep the second half for
# evaluating the loss.
_, rng_key_eval = jax.random.split(state.rng_key)
params = svi.get_params(state)
# Trace the loss as a function of the parameters only (key, model, guide and
# data are closed over) and display the result as the cell output.
# NOTE(review): the meaning of the positional `False` flag is not visible
# here -- check the helper's signature.
_make_jaxpr_with_consts(lambda x: svi.loss.loss(rng_key_eval, x, svi.model, svi.guide, obs_xs, obs_ys), False)(params)

({ lambda cu da df dj dq ds dx eb ; a b c d e f g h.
   let i = sub g 0.0
       j = div i 1.0
       k = integer_pow[ y=2 ] j
       l = log1p k
       m = sub -1.1447298526763916 l
       n = add m 0.6931471824645996
       o = reduce_sum[ axes=() ] n
       p = add 0.0 o
       q = sub h 0.0
       r = div q 1.0
       s = integer_pow[ y=2 ] r
       t = log1p s
       u = sub -1.1447298526763916 t
       v = add u 0.6931471824645996
       w = reduce_sum[ axes=() ] v
       x = add p w
       y = sub e 0.0
       z = div y 1.0
       ba = integer_pow[ y=2 ] z
       bb = log1p ba
       bc = sub -1.1447298526763916 bb
       bd = add bc 0.6931471824645996
       be = reduce_sum[ axes=() ] bd
       bf = add x be
       bg = sub a 0.0
       bh = div bg 1.0
       bi = integer_pow[ y=2 ] bh
       bj = log1p bi
       bk = sub -1.1447298526763916 bj
       bl = add bk 0.6931471824645996
       bm = reduce_sum[ axes=() ] bl
       bn = add bf bm
       bo = sub c 0.0
       bp = div 