# JAX Speed Test

This notebook compares the run time of a plain Python loop over NumPy operations with an equivalent JAX function compiled by `jax.jit`.

In [None]:
import time
import numpy as np
import jax
import jax.numpy as jnp

def python_fn(x, n_iter=1000):
    """Apply ``x -> sin(x) + cos(x)`` elementwise ``n_iter`` times with NumPy.

    This is the reference (non-JAX) side of the benchmark: each iteration is a
    Python-level loop step that dispatches two NumPy ufuncs and an addition.

    Args:
        x: Array-like input; ``np.sin``/``np.cos`` are applied elementwise.
        n_iter: Number of iterations to run (default 1000, the benchmark size).

    Returns:
        The value after ``n_iter`` iterations; ``x`` itself when ``n_iter`` is 0.
    """
    for _ in range(n_iter):
        x = np.sin(x) + np.cos(x)
    return x

def jax_fn(x, n_iter=1000):
    """Apply ``x -> sin(x) + cos(x)`` elementwise ``n_iter`` times with JAX.

    Intended to be wrapped in ``jax.jit``. The loop is expressed with
    ``jax.lax.fori_loop`` instead of a Python ``for``: when ``jax.jit`` traces
    a Python loop it unrolls every iteration into the program, so 1000
    iterations would produce 1000 copies of the body and a very slow compile.

    Args:
        x: Array-like input; converted with ``jnp.asarray`` (under JAX's
            default config a float64 NumPy input becomes float32).
        n_iter: Number of iterations to run (default 1000, the benchmark size).

    Returns:
        A JAX array holding the value after ``n_iter`` iterations.
    """
    def body(_, val):
        # One iteration of the map; the loop index is unused.
        return jnp.sin(val) + jnp.cos(val)

    # Convert up front so the loop carry has a single, consistent JAX type.
    return jax.lax.fori_loop(0, n_iter, body, jnp.asarray(x))

jit_fn = jax.jit(jax_fn)

x_np = np.random.rand(1000)
# NOTE(review): with JAX's default config (x64 disabled) x_jax is float32
# while x_np is float64, so the two sides do not run at the same precision.
x_jax = jnp.array(x_np)

# perf_counter is the clock intended for measuring short intervals;
# time.time() is wall-clock time and may be adjusted while the cell runs.
start = time.perf_counter()
python_fn(x_np)
python_time = time.perf_counter() - start

# The first call to a jitted function traces and compiles it. Time that
# separately so compilation is not counted as execution time.
start = time.perf_counter()
jit_fn(x_jax).block_until_ready()
compile_time = time.perf_counter() - start

# Steady-state call reusing the compiled executable. block_until_ready()
# waits for JAX's asynchronous dispatch to finish before the clock stops.
start = time.perf_counter()
jit_fn(x_jax).block_until_ready()
jax_time = time.perf_counter() - start

python_time, jax_time