## "For" loops

In [8]:
import numpy as np
import numba

@numba.jit(nopython=True)
def numba_loops(arr):
    """Return the sum of ``arr[i] ** 2`` over the first axis, compiled by Numba.

    With ``nopython=True`` the whole loop is compiled to machine code on the
    first call (that call also pays the compilation time), so the per-element
    interpreter overhead of a plain Python loop disappears.
    """
    length = arr.shape[0]
    acc = 0.0
    idx = 0
    while idx < length:
        acc += arr[idx] ** 2  # square each element and accumulate
        idx += 1
    return acc

In [9]:
%%timeit
arr = np.random.rand(1000000)
numba_loops(arr)  # Very fast

10.4 ms ± 533 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
import jax.numpy as jnp
from jax import jit

@jit
def jax_loops(arr):
    """Sum of squares written as a Python ``for`` loop inside ``jit``.

    Under ``jit`` the Python loop does not become a compiled loop. It is
    unrolled while tracing into one index/square/add per element, so trace
    and compile cost grow with ``arr.shape[0]``. This is the anti-pattern
    being demonstrated.
    """
    total = 0.0
    for idx in range(arr.shape[0]):  # unrolled at trace time
        total += arr[idx] ** 2
    return total

In [11]:
%%time
# This is the first call of jax_loops, so the wall time includes tracing and
# XLA compilation. The Python `for` inside the @jit function is unrolled into
# one indexing/square/add per element at trace time, so this cost grows with
# the array length.
# NOTE(review): only 100 elements here vs 1,000,000 in the Numba cell, so the
# two timings are not directly comparable.
arr = jnp.array(np.random.rand(100))
print(jax_loops(arr))  # Not as fast as Numba (see the size note above)

34.125214
CPU times: total: 31.2 ms
Wall time: 157 ms


In [12]:
def python_loops(arr):
    """Sum of squares with a plain, uncompiled Python loop.

    Each ``arr[idx]`` uses the array type's normal Python indexing. When
    ``arr`` is a JAX array, every index, square and add is dispatched as a
    separate JAX operation, which makes this very slow.
    """
    length = arr.shape[0]
    acc = 0.0
    idx = 0
    while idx < length:
        acc += arr[idx] ** 2
        idx += 1
    return acc

In [13]:
%%time
# python_loops is plain Python (no jit), but `arr` is a JAX array. Each
# `arr[i]`, `** 2` and `+=` is therefore dispatched as a separate JAX
# operation, and that per-element overhead is typically much larger than
# looping over a NumPy array.
# NOTE(review): 10**4 elements here vs 100 (jit cell) and 1,000,000 (Numba
# cell), so the timings are not directly comparable.
arr = jnp.array(np.random.rand(10**4))
print(python_loops(arr))  # Not as fast as Numba (see the size note above)

3322.3618
CPU times: total: 719 ms
Wall time: 1.61 s


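Python `for` loops and JAX combine very badly. Under `@jit` the loop is unrolled at trace time into one operation per element, so compilation cost grows with the array length. Without `jit`, every `arr[i]` on a JAX array is dispatched as a separate JAX operation. Note that the three cells above use different input sizes (10^6, 100 and 10^4 elements), so their timings are not directly comparable.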

## ODE solver

#### JAX approach

In [14]:
import jax.numpy as jnp
import jax.lax as lax
from jax import jit

def f(x, t):
    """Right-hand side of the example ODE dx/dt = -x (exponential decay).

    ``t`` is accepted to match the general ``f(x, t)`` signature but is unused.
    """
    rate = -x
    return rate

def step(carry, t):
    """One explicit-Euler update, shaped as a ``lax.scan`` body.

    Parameters
    ----------
    carry : tuple
        ``(x, h)``: current state and step size. ``h`` is passed through
        unchanged so it stays part of the scan carry.
    t : scalar
        Time of the current state, forwarded to ``f``.

    Returns
    -------
    tuple
        ``((x_next, h), x_next)``: the new carry and the per-step output that
        ``lax.scan`` stacks.
    """
    state, dt = carry
    slope = f(state, t)
    state_next = state + dt * slope  # Euler: x_{k+1} = x_k + h * f(x_k, t_k)
    return (state_next, dt), state_next

In [17]:
@jit
def solve_euler(x0, h, t_array):
    """Explicit-Euler solution of dx/dt = f(x, t) on the grid ``t_array``.

    The per-step update is ``step``, driven by ``lax.scan``. The loop is
    therefore compiled as a single loop rather than unrolled per time step.

    Parameters
    ----------
    x0 : array-like
        Initial condition at ``t_array[0]``.
    h : scalar
        Constant step size (presumably ``t_array[1] - t_array[0]``).
    t_array : 1-D array
        Time grid; its length fixes the number of returned states.

    Returns
    -------
    jax.Array
        Trajectory aligned with ``t_array``: element ``i`` approximates
        x(t_array[i]), and element 0 is ``x0`` itself. This is the same layout
        as ``euler_naive`` and ``solve_euler_numba``.
    """

    def emit_then_step(carry, t):
        # Emit the state *before* the update. Emitting the post-step value
        # (what ``step`` outputs) would shift the trajectory one slot to the
        # left of ``t_array``, drop ``x0``, and step past the end of the grid.
        next_carry, _ = step(carry, t)
        return next_carry, carry[0]

    # The final carry (one extra, unused state) is discarded.
    _, x_values = lax.scan(emit_then_step, (x0, h), t_array)
    return x_values  # shape: (len(t_array),) + shape of x0

In [85]:
n_steps = 100_000  # number of time-grid points (10**5)

In [86]:
x0 = jnp.array(1.0)                     # initial condition x(0) = 1
t_array = jnp.linspace(0, 10, n_steps)  # uniform time grid on [0, 10] with n_steps points
h = t_array[1] - t_array[0]             # constant step size of that grid

In [87]:
%%timeit

solution = solve_euler(x0, h, t_array)

179 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


#### Naïve approach

In [88]:
def euler_naive(x0, h, t_array, rhs=None):
    """Explicit-Euler solve of dx/dt = rhs(x, t) with a plain Python loop.

    Parameters
    ----------
    x0 : scalar
        Initial condition, stored as the first element of the result.
    h : scalar
        Constant step size.
    t_array : sequence of times
        Time grid; ``len(t_array) - 1`` Euler steps are taken.
    rhs : callable, optional
        Right-hand side ``rhs(x, t)``. Defaults to the module-level ``f``,
        looked up at call time.

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(t_array)``; element ``i`` approximates
        x(t_array[i]).

    Note: if ``h`` or ``t_array`` are JAX arrays, every update runs as JAX
    scalar operations. That dispatch overhead, not the Python loop itself, can
    then dominate the runtime; pass NumPy arrays or Python floats for a
    pure-Python baseline.
    """
    if rhs is None:
        rhs = f
    x = np.zeros(len(t_array))
    x[0] = x0
    for i, t in enumerate(t_array[:-1]):
        x[i + 1] = x[i] + h * rhs(x[i], t)  # Euler update
    # `x` is already a fresh local ndarray; wrapping it in np.array() only copied it.
    return x

In [89]:
%%timeit

solution = euler_naive(x0, h, t_array)

3.05 s ± 101 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Numba

In [90]:
import numpy as np
import numba

@numba.njit
def solve_euler_numba(x0, h, t_array):
    """Explicit-Euler solution of dx/dt = -x on ``t_array``, compiled by Numba.

    The right-hand side is written inline (the same -x as ``f``) rather than
    by calling ``f``, because ``f`` is a plain Python function.

    Returns a float64 array of length ``t_array.shape[0]``; element 0 is
    ``x0`` and element ``k`` approximates x(t_array[k]).
    """
    n_points = t_array.shape[0]
    traj = np.empty(n_points, dtype=np.float64)
    traj[0] = x0
    for k in range(1, n_points):
        prev = traj[k - 1]
        traj[k] = prev + h * (-prev)  # Euler step with dx/dt = -x
    return traj

In [93]:
x0_np = np.array(1.0)                     # initial condition x(0) = 1 (0-d array)
t_array_np = np.linspace(0, 10, n_steps)  # float64 time grid on [0, 10] with n_steps points
h_np = t_array_np[1] - t_array_np[0]      # constant step size of that grid

In [94]:
%%timeit
solution = solve_euler_numba(x0_np, h_np, t_array_np)

246 µs ± 4.55 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### JAX gives the best results for many time steps when used wisely (`lax.scan` instead of Python loops)