In [1]:
%matplotlib inline


# Example: Predator-Prey Model

This example replicates the excellent case study [1], which uses the Lotka–Volterra
equations [2] to describe the population dynamics of the Canada lynx (predator) and the
snowshoe hare (prey). We use the dataset obtained from [3] and run MCMC to infer the
parameters of the differential equations governing these dynamics.

**References:**

1. Bob Carpenter (2018), ["Predator-Prey Population Dynamics: the Lotka-Volterra model in Stan"](https://mc-stan.org/users/documentation/case-studies/lotka-volterra-predator-prey.html).
2. https://en.wikipedia.org/wiki/Lotka-Volterra_equations
3. http://people.whitman.edu/~hundledr/courses/M250F03/M250.html

<img src="../_static/img/examples/ode.png" align="center">


In [6]:
import os
import sys

import matplotlib
import matplotlib.pyplot as plt

import jax
from jax.experimental.ode import odeint
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import LYNXHARE, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive

# Use the non-interactive Agg backend only when running as a plain script.
# Inside a Jupyter kernel this call would replace the inline backend enabled by
# `%matplotlib inline`, and the posterior-predictive figure would never be shown.
if "ipykernel" not in sys.modules:
    matplotlib.use("Agg")


def dz_dt(z, t, theta):
    """
    Right-hand side of the Lotka–Volterra system, in the form expected by ``odeint``.

    The real positive parameters ``alpha``, ``beta``, ``gamma`` and ``delta``
    describe how the two species interact.

    :param z: current state; ``z[0]`` is the prey population ``u`` and ``z[1]``
        the predator population ``v``.
    :param t: current time. It is not used because the system is autonomous,
        but ``odeint`` always passes it.
    :param theta: parameters ``alpha, beta, gamma, delta`` along the last axis.
    :return: ``jnp.stack([du/dt, dv/dt])``.
    """
    prey, predator = z[0], z[1]
    alpha, beta, gamma, delta = (theta[..., idx] for idx in range(4))
    # prey grow at rate alpha and are removed in proportion to predator numbers
    prey_rate = prey * (alpha - beta * predator)
    # predators die at rate gamma and grow in proportion to available prey
    predator_rate = predator * (delta * prey - gamma)
    return jnp.stack([prey_rate, predator_rate])


def model(N, y=None):
    """
    Bayesian Lotka–Volterra model for the hare/lynx population series.

    :param int N: number of measurement times (the ODE is evaluated at
        t = 0, 1, ..., N-1)
    :param numpy.ndarray y: measured populations with shape (N, 2); when ``None``
        the ``"y"`` site is sampled instead of observed (e.g. by ``Predictive``)
    """
    # initial populations of both species: log-normal prior with median 10
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1).expand([2]))
    # unit-spaced measurement grid
    times = jnp.arange(float(N))
    # alpha, beta, gamma, delta (in that order), truncated to be non-negative
    theta_loc = jnp.array([1.0, 0.05, 1.0, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0.0, loc=theta_loc, scale=theta_scale),
    )
    # solve dz/dt on the measurement grid -> trajectory of shape N x 2
    trajectory = odeint(
        dz_dt, z_init, times, theta, rtol=1e-6, atol=1e-5, mxstep=1000
    )
    # per-species observation noise on the log scale
    sigma = numpyro.sample("sigma", dist.LogNormal(-1, 1).expand([2]))
    # observed populations are log-normal around the ODE solution
    numpyro.sample("y", dist.LogNormal(jnp.log(trajectory), sigma), obs=y)


def main(args):
    """
    Fit the Lotka–Volterra model to the lynx/hare data and plot the posterior predictive.

    Runs NUTS on ``model`` and prints the MCMC summary. It then draws
    posterior-predictive populations, plots them against the observations and
    saves the figure to ``ode_plot.pdf``.

    :param args: run configuration exposing ``num_warmup``, ``num_samples`` and
        ``num_chains``.
    :return: the saved matplotlib ``Figure``, so a notebook caller can display
        or further customise it.
    """
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # Choose the chain method explicitly. With fewer devices than chains,
    # NumPyro would otherwise warn and fall back to sequential sampling anyway.
    chain_method = (
        "parallel" if jax.local_device_count() >= args.num_chains else "sequential"
    )
    # use dense_mass for better mixing rate
    mcmc = MCMC(
        NUTS(model, dense_mass=True),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        chain_method=chain_method,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(PRNGKey(1), N=data.shape[0], y=data)
    mcmc.print_summary()

    # predict populations
    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    mu = jnp.mean(pop_pred, 0)
    # 10th and 90th percentiles bound the central 80% predictive interval
    pi = jnp.percentile(pop_pred, jnp.array([10, 90]), 0)

    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    ax.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    ax.plot(year, data[:, 1], "bx", label="true lynx")
    ax.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    ax.plot(year, mu[:, 1], "b--", label="pred lynx")
    ax.fill_between(year, pi[0, :, 0], pi[1, :, 0], color="k", alpha=0.2)
    ax.fill_between(year, pi[0, :, 1], pi[1, :, 1], color="b", alpha=0.3)
    ax.set(ylim=(0, 160), xlabel="year", ylabel="population (in thousands)")
    ax.set_title("Posterior predictive (80% CI) with predator-prey pattern.")
    ax.legend()

    fig.savefig("ode_plot.pdf")
    return fig


if __name__ == "__main__":
    # This notebook was written against NumPyro 0.9.0; fail fast on another version.
    assert numpyro.__version__.startswith("0.9.0")

    class Args:
        """Run configuration (stands in for the argparse namespace of the script version)."""

        num_samples = 1000
        num_warmup = 1000
        num_chains = 3
        device = "cpu"

    args = Args()

    numpyro.set_platform(args.device)
    # NOTE: set_host_device_count only takes effect if JAX has not yet initialised
    # its backends in this process. In a long-lived kernel it may be too late;
    # restart the kernel if fewer devices than chains are reported below.
    numpyro.set_host_device_count(args.num_chains)
    print(f"JAX local devices available: {jax.local_device_count()}")
    main(args)

1


  mcmc = MCMC(
sample: 100%|████████████| 2000/2000 [01:32<00:00, 21.71it/s, 255 steps of size 7.63e-03. acc. prob=0.96]
sample: 100%|█████████████| 2000/2000 [00:06<00:00, 294.70it/s, 7 steps of size 3.29e-01. acc. prob=0.86]
sample: 100%|████████████| 2000/2000 [00:51<00:00, 38.63it/s, 255 steps of size 1.57e-02. acc. prob=0.95]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  sigma[0]      1.04      0.08      1.03      0.90      1.17   2658.68      1.00
  sigma[1]      0.52      0.04      0.52      0.45      0.59   2841.27      1.00
  theta[0]      0.48      0.11      0.48      0.32      0.65   2594.34      1.00
  theta[1]      0.02      0.00      0.02      0.01      0.03   2691.18      1.00
  theta[2]      0.99      0.22      0.96      0.62      1.30   2580.78      1.00
  theta[3]      0.03      0.01      0.03      0.02      0.05   2512.23      1.00
 z_init[0]     51.15      9.46     50.24     35.88     65.76   2351.98      1.00
 z_init[1]     34.54      5.90     34.24     24.90     43.91   1907.89      1.00

Number of divergences: 3


In [7]:
# full transporter model


# vrxn1 = vol*(rxn1_k1*IF-rxn1_k2*OF)
# vrxn2 = vol*(rxn2_k1*OF*H_out-rxn2_k2*OF_Hb)
# vrxn3 = vol*(rxn3_k1*OF_Sb-rxn3_k2*OF*S_out)
# vrxn4 = vol*(rxn4_k1*OF_Hb-rxn4_k2*IF_Hb)
# vrxn5 = vol*(rxn5_k1*OF_Hb_Sb-rxn5_k2*OF_Hb*S_out)
# vrxn6 = vol*(rxn6_k1*IF_Sb-rxn6_k2*OF_Sb)
# vrxn7 = vol*(rxn7_k1*OF_Sb*H_out-rxn7_k2*OF_Hb_Sb)
# vrxn8 = vol*(rxn8_k1*OF_Hb_Sb-rxn8_k2*IF_Hb_Sb)
# vrxn9 = vol*(rxn9_k1*IF_Hb-rxn9_k2*IF*H_in)
# vrxn10 = vol*(rxn10_k1*IF*S_in-rxn10_k2*IF_Sb)
# vrxn11 = vol*(rxn11_k1*IF_Hb*S_in-rxn11_k2*IF_Hb_Sb)
# vrxn12 = vol*(rxn12_k1*IF_Hb_Sb-rxn12_k2*IF_Sb*H_in)

# dOF/dt = vrxn1 - vrxn2 + vrxn3
# dOF_Hb/dt = vrxn2 - vrxn4 + vrxn5
# dIF_Hb/dt = vrxn4 - vrxn9 - vrxn11
# dS_in/dt = -vrxn10 - vrxn11
# dIF_Hb_Sb/dt = vrxn8 + vrxn11 - vrxn12
# dH_in/dt = vrxn9 + vrxn12
# dIF_Sb/dt = -vrxn6 + vrxn10 + vrxn12
# dOF_Sb/dt = -vrxn3 + vrxn6 - vrxn7
# dIF/dt = -vrxn1 + vrxn9 - vrxn10
# dOF_Hb_Sb/dt = -vrxn5 + vrxn7 - vrxn8


def transporter(c, k, vol=1.0):
    """
    Right-hand side of the full transporter kinetic model.

    Implements the twelve reversible mass-action reactions ``vrxn1`` ... ``vrxn12``
    and the ten species balance equations written in the comments above.

    :param c: concentrations indexed along the first axis in this order:
        ``OF, OF_Hb, IF_Hb, S_in, IF_Hb_Sb, H_in, IF_Sb, OF_Sb, IF, OF_Hb_Sb,
        H_out, S_out``. The external pools ``H_out`` and ``S_out`` are held fixed:
        their derivatives are always zero.
    :param k: 24 rate constants along the last axis, ordered
        ``rxn1_k1, rxn1_k2, rxn2_k1, rxn2_k2, ..., rxn12_k1, rxn12_k2``.
    :param float vol: volume factor multiplying every reaction rate.
    :return: ``dc/dt`` stacked in the same species order as ``c``.

    NOTE(review): the state/rate layout above is a convention chosen here, not
    fixed elsewhere in the notebook. To use this with ``odeint``, wrap it, e.g.
    ``lambda c, t, k: transporter(c, k)``.
    """
    (OF, OF_Hb, IF_Hb, S_in, IF_Hb_Sb, H_in,
     IF_Sb, OF_Sb, IF, OF_Hb_Sb, H_out, S_out) = (c[i] for i in range(12))
    # reaction j (1-based) has forward constant k[..., 2(j-1)] and reverse k[..., 2(j-1)+1]
    fwd = {j: k[..., 2 * (j - 1)] for j in range(1, 13)}
    rev = {j: k[..., 2 * (j - 1) + 1] for j in range(1, 13)}

    v1 = vol * (fwd[1] * IF - rev[1] * OF)
    v2 = vol * (fwd[2] * OF * H_out - rev[2] * OF_Hb)
    v3 = vol * (fwd[3] * OF_Sb - rev[3] * OF * S_out)
    v4 = vol * (fwd[4] * OF_Hb - rev[4] * IF_Hb)
    v5 = vol * (fwd[5] * OF_Hb_Sb - rev[5] * OF_Hb * S_out)
    v6 = vol * (fwd[6] * IF_Sb - rev[6] * OF_Sb)
    v7 = vol * (fwd[7] * OF_Sb * H_out - rev[7] * OF_Hb_Sb)
    v8 = vol * (fwd[8] * OF_Hb_Sb - rev[8] * IF_Hb_Sb)
    v9 = vol * (fwd[9] * IF_Hb - rev[9] * IF * H_in)
    v10 = vol * (fwd[10] * IF * S_in - rev[10] * IF_Sb)
    v11 = vol * (fwd[11] * IF_Hb * S_in - rev[11] * IF_Hb_Sb)
    v12 = vol * (fwd[12] * IF_Hb_Sb - rev[12] * IF_Sb * H_in)

    dOF = v1 - v2 + v3
    dOF_Hb = v2 - v4 + v5
    dIF_Hb = v4 - v9 - v11
    dS_in = -v10 - v11
    dIF_Hb_Sb = v8 + v11 - v12
    dH_in = v9 + v12
    dIF_Sb = -v6 + v10 + v12
    dOF_Sb = -v3 + v6 - v7
    dIF = -v1 + v9 - v10
    dOF_Hb_Sb = -v5 + v7 - v8
    # external proton/substrate pools are treated as constant reservoirs
    dH_out = jnp.zeros_like(H_out)
    dS_out = jnp.zeros_like(S_out)

    return jnp.stack([
        dOF, dOF_Hb, dIF_Hb, dS_in, dIF_Hb_Sb, dH_in,
        dIF_Sb, dOF_Sb, dIF, dOF_Hb_Sb, dH_out, dS_out,
    ])


SyntaxError: cannot assign to expression here. Maybe you meant '==' instead of '='? (1640022148.py, line 16)