In [13]:
import timeit
import numpy as np
import math
import pendulums
from pendulums import PendulumMetadata

# Initial state vector: the angles `theta` followed by the rates `omega`.
theta = np.array([0.75 * math.pi, 0.5 * math.pi])
n = theta.size
omega = np.zeros_like(theta)
omega[0] = math.pi  # only the first element starts with a nonzero rate
state0 = np.hstack([theta, omega])

# Pendulum parameters: unit masses everywhere; element 1 gets length 1.5.
masses, lengths = np.ones(n), np.ones(n)
lengths[1] = 1.5
metadata = PendulumMetadata(masses=masses, lengths=lengths)

# Step size passed to the integrators below.
dt = 1e-2

In [10]:
# Time one call to pendulums.rk4_step_np driving the NumPy ODE.
# NOTE(review): a full RK4 step presumably evaluates the ODE several times, so this
# is not directly comparable with the single-ODE-call JAX timings further down.
%timeit pendulums.rk4_step_np(pendulums.n_pendulum_ode_np, state0, 0, dt, metadata)

677 μs ± 20.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
# Time one call to pendulums.velocity_verlet_step_np driving the NumPy ODE.
# NOTE(review): this times a whole integrator step (presumably more than one ODE
# evaluation), not a single ODE call like the JAX timings further down.
%timeit pendulums.velocity_verlet_step_np(pendulums.n_pendulum_ode_np, state0, 0, dt, metadata)

336 μs ± 7.09 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
import jax
from jax import numpy as jnp

# Convert the NumPy inputs to JAX arrays.
# NOTE: the displayed result below has dtype float32, so JAX is running in its
# default 32-bit mode here. The NumPy timings above use float64 inputs, so the
# two backends are not computing at the same precision.
r_jax = jnp.array(metadata.lengths)
m_jax = jnp.array(metadata.masses)
state_jax = jnp.array(state0)

# Warm up BOTH JAX functions so one-off compilation cost is excluded from the
# %timeit measurements that follow. JAX dispatches work asynchronously, so block
# until each result is actually computed before moving on.
jax.block_until_ready(pendulums.n_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax, n))
jax.block_until_ready(pendulums.double_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax))

Array([ 3.1415927,  0.       , -8.193671 ,  2.9626648], dtype=float32)

In [15]:
%timeit pendulums.n_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax, n)

180 μs ± 14.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%timeit pendulums.double_pendulum_ode_jax(0.0, state_jax, r_jax, m_jax)

232 μs ± 45.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
