In [1]:
import math
import timeit

import jax
import numpy as np

import pendulums
from pendulums import PendulumMetadata

# Initial state vector: the entries of `theta` first, then the entries of `omega`
# (presumably angles in radians followed by angular velocities).
theta = np.array([0.75 * math.pi, 0.5 * math.pi])
n = theta.shape[0]  # number of entries in theta; reused by later cells
omega = np.array([math.pi] + [0.0] * (n - 1))  # only the first entry is non-zero
state0 = np.hstack((theta, omega))

# Pendulum metadata: every mass is 1 and every length is 1, except the
# second length, which is 1.5.
masses = np.full(n, 1.0)
lengths = np.full(n, 1.0)
lengths[1] = 1.5
metadata = PendulumMetadata(masses=masses, lengths=lengths)

# Value passed as `dt` to the stepper calls in the cells below.
dt = 0.01

In [2]:
# Time one call to pendulums.rk4_step_np, given the n-pendulum ODE function,
# the initial state, a third argument of 0 (presumably the start time), dt and
# the pendulum metadata.
%timeit pendulums.rk4_step_np(pendulums.n_pendulum_ode_np, state0, 0, dt, metadata)

339 μs ± 55.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
# Time one call to pendulums.velocity_verlet_step_np with the same arguments
# as the rk4_step_np timing, so the two steppers can be compared directly.
%timeit pendulums.velocity_verlet_step_np(pendulums.n_pendulum_ode_np, state0, 0, dt, metadata)

195 μs ± 20.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
from jax import numpy as jnp

# Convert the lengths, masses and initial state to JAX arrays once, so the
# JAX calls below all receive the same pre-built arrays.
r_jax, m_jax, state_jax = (
    jnp.array(host_value)
    for host_value in (metadata.lengths, metadata.masses, state0)
)

# Warm-up: call each JAX ODE function once before it is timed, so that JAX
# compilation (if these functions are jitted inside `pendulums`) is not
# measured. The second call is the cell's last expression, so its result is shown.
pendulums.n_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax, n)
pendulums.double_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax)

W1206 14:45:39.223408   27104 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1206 14:45:39.226516   26976 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


Array([ 3.1415927,  0.       , -8.193671 ,  2.9626658], dtype=float32)

In [5]:
%timeit pendulums.n_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax, n)

252 μs ± 41.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
%timeit pendulums.double_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax)

191 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
import jax
from k_n_pendulum import KN_Pendulum_JAX_vispy
# Batched RK4 steppers. vmap maps over axis 0 of the first argument (the batch
# of states) and passes every other argument through unmapped; jit compiles the
# batched function. For the n-pendulum variant the sixth positional argument
# (index 5, `n`) is declared static.
rk4_step_batch_n = jax.jit(
    jax.vmap(pendulums.n_pendulum_rk4_step_jax, in_axes=(0,) + (None,) * 5),
    static_argnums=(5,),
)
rk4_step_batch_double = jax.jit(
    jax.vmap(pendulums.double_pendulum_rk4_step_jax, in_axes=(0,) + (None,) * 4)
)

# Batch of k copies of state0. A uniform perturbation drawn from
# [-P_MAGNITUDE, P_MAGNITUDE) is generated for every state entry, but only the
# first n columns are added to the states; the remaining columns are unchanged.
k = 1_000_000
key = jax.random.PRNGKey(0)
perturbations = jax.random.uniform(
    key,
    shape=(k, len(state0)),
    minval=-KN_Pendulum_JAX_vispy.P_MAGNITUDE,
    maxval=KN_Pendulum_JAX_vispy.P_MAGNITUDE,
)
states = jnp.tile(state0, (k, 1)).at[:, :n].add(perturbations[:, :n])

# Warm-up: run each batched stepper once so compilation is not timed later.
# The trailing semicolons suppress the cell output.
rk4_step_batch_n(states, 0.0, dt, r_jax, m_jax, n);
rk4_step_batch_double(states, 0.0, dt, r_jax, m_jax);

In [None]:
# Still very little difference in performance between n-pendulum and double pendulum
%timeit -n 100 rk4_step_batch_n(states, 0.0, dt, r_jax, m_jax, n)
%timeit -n 100 rk4_step_batch_double(states, 0.0, dt, r_jax, m_jax)

13.1 ms ± 626 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.7 ms ± 8.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# 2D-batched RK4 steppers: two nested vmaps, each mapping over the leading axis
# of the state argument only, so the steppers accept a (k_2D, k_2D, 4) grid of
# states. As in the 1D case, `n` (positional index 5) is static for the
# n-pendulum variant.
rk4_step_batch_n_2D = jax.jit(jax.vmap(jax.vmap(pendulums.n_pendulum_rk4_step_jax,
                                               in_axes=(0, None, None, None, None, None)),
                                               in_axes=(0, None, None, None, None, None)),
                                      static_argnums=(5,))
rk4_step_batch_double_2D = jax.jit(jax.vmap(jax.vmap(pendulums.double_pendulum_rk4_step_jax,
                                               in_axes=(0, None, None, None, None)),
                                               in_axes=(0, None, None, None, None)))

# Side length of the square grid, chosen so the grid holds k states in total.
k_2D = np.sqrt(k).astype(int)

# Both angles sweep [-2*pi, 2*pi]. The endpoints must differ: with equal
# endpoints linspace returns a constant array and the grid collapses to a
# single theta2 value.
theta1 = jnp.linspace(-2*np.pi, 2*np.pi, k_2D)
theta2 = jnp.linspace(-2*np.pi, 2*np.pi, k_2D)
theta1_grid, theta2_grid = jnp.meshgrid(theta1, theta2, indexing='xy')

# State entries 0 and 1 hold the two grid angles; entries 2 and 3 stay zero.
states_2D = jnp.zeros(
    shape=(k_2D, k_2D, 4))
states_2D = states_2D.at[:, :, 0].set(theta1_grid)
states_2D = states_2D.at[:, :, 1].set(theta2_grid)

# Warm-up: run each 2D stepper once so compilation is not timed later.
rk4_step_batch_n_2D(states_2D, 0.0, dt, r_jax, m_jax, n);
rk4_step_batch_double_2D(states_2D, 0.0, dt, r_jax, m_jax);

In [None]:
# No measured difference in double vmap of flat array vs. 2D array (1e6 vs 1e3 x 1e3)
%timeit -n 100 rk4_step_batch_n_2D(states_2D, 0.0, dt, r_jax, m_jax, n)
%timeit -n 100 rk4_step_batch_double_2D(states_2D, 0.0, dt, r_jax, m_jax)

13.1 ms ± 459 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.7 ms ± 17.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
def rk4_loop(states_2D, x):
    """Apply ``x`` RK4 steps to ``states_2D`` with a plain Python loop.

    Every iteration calls ``rk4_step_batch_n_2D`` with the notebook globals
    ``dt``, ``r_jax``, ``m_jax`` and ``n``; the second positional argument is
    fixed at 0.0. Returns the result of the last step, or ``states_2D``
    itself when ``x`` is 0.
    """
    current = states_2D
    for _step in range(x):
        current = rk4_step_batch_n_2D(current, 0.0, dt, r_jax, m_jax, n)
    return current

def jax_loop(states_2D, x):
    """Apply ``x`` RK4 steps to ``states_2D`` inside ``jax.lax.fori_loop``.

    The loop body ignores the iteration index and calls
    ``rk4_step_batch_n_2D`` with the notebook globals ``dt``, ``r_jax``,
    ``m_jax`` and ``n``; the second positional argument is fixed at 0.0.
    """
    def advance(_step, carry):
        return rk4_step_batch_n_2D(carry, 0.0, dt, r_jax, m_jax, n)

    return jax.lax.fori_loop(0, x, advance, states_2D)


# Jit the whole loop with `x` as a static argument.
jax_loop = jax.jit(jax_loop, static_argnames=['x'])


# Warm-up before timing. `x` is a static argument of jax_loop, so each distinct
# value is compiled separately: warm up with the same x (6) that the benchmark
# uses, otherwise the first timed call pays for compilation. rk4_loop only
# reuses the already-compiled single-step function, so one step is enough.
# block_until_ready waits for the asynchronous work to finish; the trailing
# semicolons suppress the cell output.
jax.block_until_ready(rk4_loop(states_2D, 1));
jax.block_until_ready(jax_loop(states_2D, 6));

In [None]:
# JAX fori_loop is NOT faster.
%timeit -n 10 rk4_loop(states_2D, 6)
%timeit -n 10 jax_loop(states_2D, 6)

79.7 ms ± 6.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
80.8 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
