In [1]:
from matplotlib import pyplot as plt
import jax
from jax.lax import scan
import jax.numpy as jnp
import numpy as np
import numpy.random as npr
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import Adam
from numpyro.callbacks import Progbar

In [2]:
def arms_race(x0, y0, *, k, a, g, l, b, h,
              num_iterations=10, step_size=0.02):
    """Simulate a two-party arms race with forward-Euler integration.

    The state (x, y) follows the linear ODE system

        dx/dt = k * y - a * x + g
        dy/dt = l * x - b * y + h

    integrated with a fixed step ``step_size`` from t = 0 up to
    t = ``num_iterations``. Despite its name, ``num_iterations`` is the total
    simulated time, not the number of Euler steps; the number of steps is
    ``num_iterations / step_size`` (rounded up).

    Args:
        x0, y0: Initial state at t = 0.
        k, a, g, l, b, h: Coefficients of the linear system above.
        num_iterations: Length of the simulated time interval.
        step_size: Euler step size.

    Returns:
        (xs, ys): arrays with one entry per Euler step, holding the state
        *after* each step (times step_size, 2 * step_size, ...). The initial
        state (x0, y0) itself is not included.
    """
    # Number of Euler steps needed to cover [0, num_iterations]. This is
    # computed directly from the ratio instead of taking the length of a
    # float ``arange``, whose element count can be off by one due to
    # rounding; the small tolerance absorbs floating-point noise when the
    # ratio is (nearly) a whole number. A non-positive horizon gives 0 steps.
    num_steps = max(int(np.ceil(num_iterations / step_size - 1e-6)), 0)

    def euler_step(state, _):
        x, y = state
        dx = k * y - a * x + g
        dy = l * x - b * y + h
        x = x + step_size * dx
        y = y + step_size * dy
        # Carry the new state forward and also emit it as this step's output.
        return (x, y), (x, y)

    # The step function ignores its per-step input, so scan only needs a length.
    _, (xs, ys) = scan(euler_step, (x0, y0), None, length=num_steps)
    return xs, ys


In [3]:
# Simulation settings. `num_iterations` is the simulated time horizon and
# `step_size` the Euler step; `model` below reads both as globals.
num_iterations = 30
step_size = 0.02
# Ground-truth initial state, named so it can be compared directly with the
# fitted 'x0' / 'y0' sites.
gt_x0 = 10.
gt_y0 = 10.
# Ground-truth coefficients of the arms-race dynamics.
gt_k = 10.
gt_a = 20.
gt_g = .1
gt_l = 10.
gt_b = 3.
gt_h = 6.
gt_xs, gt_ys = arms_race(gt_x0, gt_y0, k=gt_k, a=gt_a, g=gt_g,
                         l=gt_l, b=gt_b, h=gt_h, num_iterations=num_iterations,
                         step_size=step_size)

In [4]:
# Subsample the ground-truth trajectories: view each one as `num_iterations`
# equal-length chunks of Euler steps and keep the first entry of every chunk.
gt_xs = gt_xs.reshape(num_iterations, -1)[:, 0]
gt_ys = gt_ys.reshape(num_iterations, -1)[:, 0]


In [5]:
# Synthetic observations: `num_datapoints` noisy copies of each subsampled
# ground-truth series, with i.i.d. Gaussian noise of std `noise_scale`.
num_datapoints = 100
noise_scale = 3
# Dedicated, seeded generator so the data are identical on every
# Restart & Run All (the legacy global `npr.randn` was unseeded).
data_seed = 0
data_rng = npr.default_rng(data_seed)
obs_xs = np.array(gt_xs + data_rng.standard_normal((num_datapoints, *np.shape(gt_xs))) * noise_scale)
obs_ys = np.array(gt_ys + data_rng.standard_normal((num_datapoints, *np.shape(gt_ys))) * noise_scale)

In [6]:
def model(dxs, dys, obs_scale=1000.):
    """NumPyro model for the observed arms-race trajectories.

    Samples the initial state ``x0``, ``y0`` and the coefficients
    ``k, a, g, l, b, h`` from independent ``HalfCauchy()`` priors, simulates
    the system with ``arms_race`` (using the notebook-level ``num_iterations``
    and ``step_size``), subsamples the simulated trajectory, and scores the
    observations under a Normal likelihood centred on that trajectory.

    Args:
        dxs: Observed x-series. Axis 0 indexes independent realisations (the
            'data' plate); the remaining axis is compared, as one event,
            against the subsampled simulation. In this notebook the argument
            has shape (num_datapoints, num_iterations).
        dys: Observed y-series, same layout as ``dxs``.
        obs_scale: Standard deviation of the Normal observation noise.
            Defaults to the previously hard-coded value 1000.
            NOTE(review): the synthetic data in this notebook are generated
            with ``noise_scale = 3``; confirm the much wider likelihood is
            intentional.
    """
    x0 = numpyro.sample('x0', dist.HalfCauchy())
    y0 = numpyro.sample('y0', dist.HalfCauchy())
    k = numpyro.sample('k', dist.HalfCauchy())
    a = numpyro.sample('a', dist.HalfCauchy())
    g = numpyro.sample('g', dist.HalfCauchy())
    l = numpyro.sample('l', dist.HalfCauchy())
    b = numpyro.sample('b', dist.HalfCauchy())
    h = numpyro.sample('h', dist.HalfCauchy())
    # `sim_ys` (not `sys`) avoids shadowing the stdlib module name.
    sim_xs, sim_ys = arms_race(x0, y0, k=k, a=a, g=g, l=l, b=b, h=h,
                               num_iterations=num_iterations, step_size=step_size)

    def subsample(traj):
        # Split the trajectory into `num_iterations` equal chunks and keep the
        # first entry of each, the same reduction applied to the ground truth.
        return jnp.reshape(traj, (num_iterations, -1))[:, 0]

    sim_xs = subsample(sim_xs)
    sim_ys = subsample(sim_ys)
    with numpyro.plate('data', dxs.shape[0], dim=-1):
        # loc has shape (1, T): the leading axis broadcasts over the plate and
        # to_event(1) scores the T time points of each row jointly.
        numpyro.sample('xs', dist.Normal(loc=jnp.expand_dims(sim_xs, 0), scale=obs_scale).to_event(1), obs=dxs)
        numpyro.sample('ys', dist.Normal(loc=jnp.expand_dims(sim_ys, 0), scale=obs_scale).to_event(1), obs=dys)

In [7]:
# Fixed JAX PRNG key so SVI initialisation is reproducible across runs.
rng_key = jax.random.PRNGKey(seed=1337)

In [9]:
# Fit point estimates of every latent site: the AutoDelta guide puts a point
# mass on each latent variable, and SVI optimises the ELBO with Adam.
# NOTE(review): the recorded run took ~11 min for 1e6 steps and its progress
# bar ended at ~2.8e38 (presumably the loss), which suggests the optimisation
# diverged; a smaller Adam step size than 1.0 may be worth trying.
svi = SVI(model,
          AutoDelta(model),
          Adam(1.0),
          ELBO())
state, loss = svi.train(rng_key, 1_000_000, obs_xs, obs_ys,
                        callbacks=[Progbar()])

SVI 2.8475e+38: 100%|██████████| 1000000/1000000 [11:25<00:00, 1458.28it/s]


In [10]:
# Display the fitted guide parameters. In the recorded output there is one
# 'auto_<site>' point estimate per latent site of `model`.
# NOTE(review): the recorded estimates are far from the ground truth set
# earlier (e.g. a ≈ 1.3 vs. gt_a = 20), consistent with the suspicious loss.
svi.get_params(state)


{'auto_a': DeviceArray(1.2954643, dtype=float32),
 'auto_b': DeviceArray(0.33825225, dtype=float32),
 'auto_g': DeviceArray(4.9982767, dtype=float32),
 'auto_h': DeviceArray(1.8571754, dtype=float32),
 'auto_k': DeviceArray(1.0389583, dtype=float32),
 'auto_l': DeviceArray(0.4966619, dtype=float32),
 'auto_x0': DeviceArray(62.434444, dtype=float32),
 'auto_y0': DeviceArray(19.377914, dtype=float32)}