In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".20"

In [2]:
import numpy as np
import jax.numpy as jnp
from jax import jit, random

In [3]:
@jit
def j_single_deltaW(x, ts):
    """Generate Wiener increments for m independent Wiener
    From the sdeint implementation
    The result is multiplied by the ts as low time step
    means that noise shall not affect a lot
    """
    seed = np.random.randint(4294967295)
    key = random.PRNGKey(seed)
    return random.normal(key, x.shape) * ts


@jit
def j_noise(x, c):
    # Controls noise proportional to
    # square root of activity
    return c * jnp.sqrt(jnp.abs(x))


# TODO generate this function by predetermining the shape
@jit
def j_step_euler_sde(quantities, derivatives, noise_param, ts):
    # d = len(quantities)
    dW = j_single_deltaW(quantities, ts)
    y_ = quantities \
         + derivatives * ts \
         + jnp.multiply(j_noise(quantities, noise_param), dW)
    # Ensure positive terms
    y_ = jnp.maximum(y_, 0)

    return y_

In [4]:
q = jnp.zeros(5)
d = jnp.zeros(5)
n = jnp.zeros(5)
ts = 0.1

In [5]:
j_step_euler_sde(q, d, n, ts)

DeviceArray([0., 0., 0., 0., 0.], dtype=float32)

# Must RERUN from scratch

In [6]:
@jit
def test_batch(batch):
    return [b * 3 for b in batch]

In [7]:
def test_batch_down():
    batch = [jnp.zeros(10) for i in range(100)]
    for i in range(1000):
        test_batch(batch[:i])

In [8]:
def test_batch_constant():
    batch = [jnp.zeros(10) for i in range(100)]
    for i in range(100):
        test_batch(batch)

In [9]:
import time

In [10]:
b = time.time()
test_batch_constant()
print(time.time() - b)

0.2887866497039795


In [11]:
b = time.time()
test_batch_down()
print(time.time() - b)

13.557604312896729
