# Lotka-Volterra

![](./lotka_volterra.png)

# JAX Concepts Used

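- `jit` -> compiles a function with XLA so that repeated calls run faster
- `fori_loop` -> JAX's structured `for`-loop primitive. A plain Python `for`-loop inside a `jit`-ed function is unrolled during tracing, whereas `fori_loop` is traced once and compiled as a single loop, which keeps compilation manageable when you `jit` an entire function that contains a loop. See [exe_03_primitives.ipynb](../../exercises/exe_03_primitives.ipynb) for further discussion
- `scan` -> like `fori_loop`, it threads a state (the "carry") through every iteration, but it also collects a per-step output and stacks the results, which makes it a natural fit for recording a trajectory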
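The JAX documentation describes both loop primitives by their pure-Python meaning. The sketch below restates that meaning as two illustrative helpers (`fori_loop_reference` and `scan_reference` are hypothetical names, not JAX APIs) and compares them with `jax.lax` on a toy running sum. The reference versions run eagerly in Python and are only a mental model, not drop-in replacements.

```python
import numpy as np
import jax
import jax.numpy as jnp


def fori_loop_reference(lower, upper, body_fun, init_val):
    # Hypothetical helper: the pure-Python meaning of jax.lax.fori_loop
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def scan_reference(f, init, xs=None, length=None):
    # Hypothetical helper: the pure-Python meaning of jax.lax.scan
    # (handles a single array output only, to keep it short)
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def add_index(i, total):
    # fori_loop body: receives the loop index, returns only the new state
    return total + i


def running_sum(carry, x):
    # scan body: returns (new carry, per-step output)
    carry = carry + x
    return carry, carry


xs = jnp.arange(5.0)

# Both sum the indices 0 + 1 + 2 + 3 + 4
print(fori_loop_reference(0, 5, add_index, jnp.int32(0)))
print(jax.lax.fori_loop(0, 5, add_index, jnp.int32(0)))

# Both return the final carry plus the stacked running sums
print(scan_reference(running_sum, jnp.float32(0.0), xs))
print(jax.lax.scan(running_sum, jnp.float32(0.0), xs))
```

The only structural difference is the return value: `fori_loop` hands back the final state, while `scan` hands back the final carry and every intermediate output.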

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
from functools import partial

In [None]:
"""
Notes:
    1) We added in partial function to pass the static_argnames to the jit 
    2) we added in static_argnames to specify that these are constant and should not result in a recompilation
"""
@partial(jax.jit, static_argnames=["alpha", "beta", "gamma", "delta", "dt"])
def _lotka_volterra_step(
    x, y,
    alpha, beta, gamma, delta, dt
):
    
    dxdt = alpha * x - beta * x * y
    dydt = delta * x * y - gamma * y
    
    x_new = x + dxdt * dt
    y_new = y + dydt * dt
    
    return x_new, y_new
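For reference, `_lotka_volterra_step` takes one forward-Euler step of the classic Lotka-Volterra system. Reading $x$ as the prey population and $y$ as the predator population is the usual interpretation of these equations and an assumption here; the code itself does not name them.

$$
\frac{dx}{dt} = \alpha x - \beta x y, \qquad \frac{dy}{dt} = \delta x y - \gamma y
$$

$$
x_{n+1} = x_n + \Delta t\,(\alpha x_n - \beta x_n y_n), \qquad y_{n+1} = y_n + \Delta t\,(\delta x_n y_n - \gamma y_n)
$$

With the parameters below ($\alpha = 1.1$, $\beta = 0.4$, $\gamma = 0.4$, $\delta = 0.1$, $\Delta t = 0.1$) and $x_0 = 10$, $y_0 = 5$, the first step gives $x_1 = 10 + 0.1\,(11 - 20) = 9.1$ and $y_1 = 5 + 0.1\,(5 - 2) = 5.3$.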

In [None]:
# Parameters
alpha = 1.1
beta = 0.4
gamma = 0.4
delta = 0.1
dt = 0.1
num_steps = 20


lotka_volterra_step = partial(
    _lotka_volterra_step,
    alpha=alpha, beta=beta, gamma=gamma, delta=delta, dt=dt
)

# Initial populations
x_prev = 10.0
y_prev = 5.0
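Because `lotka_volterra_step` binds the parameters as static arguments, each distinct set of parameter values gets its own compiled program. A minimal sketch of that trade-off, using a hypothetical one-parameter `decay_step` in a static and a dynamic variant. Python side effects run only while JAX traces a function, so the counters below count how often each variant is (re)compiled.

```python
from functools import partial

import jax

trace_counts = {"static": 0, "dynamic": 0}


@partial(jax.jit, static_argnames=["rate"])
def decay_step_static(x, rate):
    trace_counts["static"] += 1  # runs only during tracing
    return x - rate * x


@jax.jit
def decay_step_dynamic(x, rate):
    trace_counts["dynamic"] += 1  # runs only during tracing
    return x - rate * x


for rate in (0.1, 0.2, 0.3):
    decay_step_static(1.0, rate=rate)  # new static value -> retrace
    decay_step_dynamic(1.0, rate)      # same shape and dtype -> cached

print(trace_counts)  # {'static': 3, 'dynamic': 1}
```

For a parameter sweep, passing the parameters as ordinary traced arguments avoids one compilation per value; static arguments pay off mainly when a value changes the structure of the program, such as a loop length or an array shape.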



In [None]:
xs = []
ys = []
for i in range(num_steps):
    x_new, y_new = lotka_volterra_step(
        x_prev, y_prev
    )

    xs.append(x_new)
    ys.append(y_new)

    x_prev = x_new
    y_prev = y_new


In [None]:
for (x, y) in zip(xs, ys):
    print(f"x: {x:.3f}, y: {y:.3f}")


# `fori_loop`

`jax.lax.fori_loop` calls its body as `body_fun(i, val)` and expects the updated `val` back, so we need to redefine our Lotka-Volterra step function to match that signature. There are two things we can do:

1) get only the last value
2) get the trajectory

## `fori_loop`: Only getting the last value

In [None]:
@partial(jax.jit, static_argnames=["alpha", "beta", "gamma", "delta", "dt"])
def _lotka_volterra_step(
    _, state,
    alpha, beta, gamma, delta, dt
):
    x, y = state 
    dxdt = alpha * x - beta * x * y
    dydt = delta * x * y - gamma * y
    
    x_new = x + dxdt * dt
    y_new = y + dydt * dt
    
    return x_new, y_new


lotka_volterra_step = partial(
    _lotka_volterra_step,
    alpha=alpha, beta=beta, gamma=gamma, delta=delta, dt=dt
)

# Initial populations
x_prev = 10.0
y_prev = 5.0

# If we're only concerned with having the last value, we can do the following
x, y = jax.lax.fori_loop(
    lower=0, upper=num_steps,
    body_fun=lotka_volterra_step, init_val=(x_prev, y_prev)
)
print(f"x: {x:.3f}, y: {y:.3f}")


## `fori_loop`: Getting the trajectory

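We store the trajectory in an array that we pass in as the loop state, and each iteration writes the next step into it by index with `.at[...].set(...)`. Note that this method may be **slow**: the whole buffer is carried through every iteration and indexed manually (under `jit`, XLA can often turn these indexed updates into in-place writes, but that is a compiler optimization rather than something the code guarantees). From a functional-programming standpoint it is also less natural, since the same computation is expressed more elegantly with `scan`, which we see later.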
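The trajectory version relies on JAX's functional indexed updates: JAX arrays are immutable, so `x.at[idx].set(value)` returns a new array and leaves `x` untouched. A minimal illustration with a small buffer shaped like the trajectory:

```python
import jax.numpy as jnp

buffer = jnp.zeros((3, 2))
updated = buffer.at[1].set(jnp.array([10.0, 5.0]))  # write row 1

print(buffer)   # still all zeros
print(updated)  # row 1 now holds [10., 5.]
```

This is why the loop body must return the updated array: assigning into the old one in place is not possible.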


In [None]:
@partial(jax.jit, static_argnames=["alpha", "beta", "gamma", "delta", "dt"])
def _lotka_volterra_step(
        trajectory_idx, state,
        alpha, beta, gamma, delta, dt
):
    x, y = state[trajectory_idx]
    dxdt = alpha * x - beta * x * y
    dydt = delta * x * y - gamma * y

    x_new = x + dxdt * dt
    y_new = y + dydt * dt
    
    state = state.at[trajectory_idx + 1].set([x_new, y_new])
    return state


lotka_volterra_step = partial(
    _lotka_volterra_step,
    alpha=alpha, beta=beta, gamma=gamma, delta=delta, dt=dt
)

# Initial populations
x_prev = 10.0
y_prev = 5.0

trajectory = jnp.zeros((num_steps+1, 2))
trajectory = trajectory.at[0].set([x_prev, y_prev])


# Run the loop; each iteration writes the next state into `trajectory`
trajectory = jax.lax.fori_loop(
    lower=0, upper=num_steps,
    body_fun=lotka_volterra_step, init_val=trajectory
)

for x, y in np.asarray(trajectory):
    print(f"x: {x:.3f}, y: {y:.3f}")


# `scan`

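This is similar to the `fori_loop` version where we stored the trajectory in the state, except that we no longer need `state = state.at[trajectory_idx + 1].set([x_new, y_new])` to record each step. The function passed to `scan` has the signature `f(carry, x) -> (new_carry, y)`: the carry is threaded from one step to the next, and `scan` stacks every `y` into its output for us.

This is a more "natural" way to pass along the information.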

Documentation: [jax.lax.scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
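The code below calls `scan` with `xs=None` and `length=num_steps`, so the second argument of the step function is unused. When per-step inputs are needed, they go in `xs`, and `scan` slices them along the leading axis, one slice per step. A sketch under that assumption, with a hypothetical array of step sizes `dts` (constant here, so it should reproduce the `xs=None` result):

```python
import jax
import jax.numpy as jnp
import numpy as np

alpha, beta, gamma, delta = 1.1, 0.4, 0.4, 0.1
num_steps = 20


def step_with_dt(state, dt):
    # carry = (x, y); the per-step input from xs is this step's dt
    x, y = state
    x_new = x + (alpha * x - beta * x * y) * dt
    y_new = y + (delta * x * y - gamma * y) * dt
    return (x_new, y_new), (x_new, y_new)


init = (jnp.float32(10.0), jnp.float32(5.0))

# Per-step inputs through xs: scan slices dts along its leading axis
dts = jnp.full((num_steps,), 0.1)
final_xs, traj_xs = jax.lax.scan(step_with_dt, init, dts)

# Same system with xs=None and dt fixed inside the body
final_none, traj_none = jax.lax.scan(
    lambda state, _: step_with_dt(state, 0.1), init, xs=None, length=num_steps
)

print(np.allclose(traj_xs, traj_none, rtol=1e-5))
```

Feeding `dt` through `xs` would let each step use a different step size without changing the body.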


In [None]:
@partial(jax.jit, static_argnames=["alpha", "beta", "gamma", "delta", "dt"])
def _lotka_volterra_step(
        state, _,
        alpha, beta, gamma, delta, dt
):
    x, y = state
    dxdt = alpha * x - beta * x * y
    dydt = delta * x * y - gamma * y

    x_new = x + dxdt * dt
    y_new = y + dydt * dt

    return (x_new, y_new), (x_new, y_new)


lotka_volterra_step = partial(
    _lotka_volterra_step,
    alpha=alpha, beta=beta, gamma=gamma, delta=delta, dt=dt
)

# Initial populations
x_prev = 10.0
y_prev = 5.0

final_state, trajectory = jax.lax.scan(
    f=lotka_volterra_step,
    init=(x_prev, y_prev),
    xs=None,
    length=num_steps
)



In [None]:
np.asarray(final_state)

In [None]:
# `trajectory` is a tuple (xs, ys); `.T` turns it into rows of (x, y)
for (_x, _y) in np.asarray(trajectory).T:
    print(f"x: {_x:.3f}, y: {_y:.3f}")
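`scan` returns the stacked outputs with the same tuple structure the step function emits, so `trajectory` in the `scan` version is a pair of shape-`(num_steps,)` arrays, whereas the `fori_loop` version produced a single `(num_steps + 1, 2)` array whose row 0 is the initial state. A self-contained sketch reconciling the two layouts, with the step restated as a hypothetical `euler_step`:

```python
import jax
import jax.numpy as jnp
import numpy as np

alpha, beta, gamma, delta, dt = 1.1, 0.4, 0.4, 0.1, 0.1
num_steps = 20


def euler_step(state):
    # One forward-Euler step of the Lotka-Volterra system
    x, y = state
    return (x + (alpha * x - beta * x * y) * dt,
            y + (delta * x * y - gamma * y) * dt)


def scan_body(state, _):
    new_state = euler_step(state)
    return new_state, new_state  # (carry, per-step output)


def fori_body(i, traj):
    # Read row i, write the next state into row i + 1
    new_row = jnp.stack(euler_step((traj[i, 0], traj[i, 1])))
    return traj.at[i + 1].set(new_row)


init = (jnp.float32(10.0), jnp.float32(5.0))
_, (x_traj, y_traj) = jax.lax.scan(scan_body, init, xs=None, length=num_steps)
scan_traj = jnp.stack([x_traj, y_traj], axis=1)  # rows of (x, y)

fori_traj = jnp.zeros((num_steps + 1, 2)).at[0].set(jnp.array([10.0, 5.0]))
fori_traj = jax.lax.fori_loop(0, num_steps, fori_body, fori_traj)

print(scan_traj.shape, fori_traj.shape)  # (20, 2) (21, 2)
# scan omits the initial state, so compare against rows 1..num_steps
print(np.allclose(scan_traj, fori_traj[1:], rtol=1e-5))
```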