In [1]:
import jax.numpy as jnp
from jax import grad, jit, lax

In [2]:
# Runge-Kuta
def rk4(f, t0, t1, steps, y0):
    T = jnp.linspace(t0, t1, steps, endpoint=False)  # time stamps
    h = (t1 - t0) / steps

    @jit
    def step(carry, t):
        y_prev = carry

        k1 = h * f(t, y_prev)
        k2 = h * f(t + h/2, y_prev + k1/2)
        k3 = h * f(t + h/2, y_prev + k2/2)
        k4 = h * f(t + h/2, y_prev + k3)

        new_y = y_prev + (k1 + 2*k2 + 2*k3 + k4)/6

        return new_y, ()

    y1, _ = lax.scan(step, y0, T)

    return y1

In [4]:
f = lambda t, y: t*y
t0 = jnp.array([0., 1., 3.])
t1 = jnp.sqrt(2)
steps = 1000
y0 = jnp.array([1., 4., 7.])

import time
s = time.time()
y1 = rk4(f, t0, t1, steps, y0)
print(f'time: {time.time() - s} s')
print(y1)

time: 0.02400350570678711 s
[2.7178273  6.5947914  0.21133728]
