From 50de16b59e2d6d8cf3997d865a7230fa17513af8 Mon Sep 17 00:00:00 2001 From: mmcky Date: Fri, 28 Nov 2025 10:03:50 +1100 Subject: [PATCH 1/3] Fix jax_intro timeout: use lax.fori_loop instead of Python for loop The compute_call_price_jax function was timing out during cache.yml builds because JAX unrolls Python for loops during JIT compilation. With large arrays (M=10M), this causes excessive compilation time. Solution: Replace Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Fixes cell execution timeout in jax_intro.md --- lectures/jax_intro.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index b4114630..641bbe5e 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -832,16 +832,28 @@ def compute_call_price_jax(β=β, s = jnp.full(M, np.log(S0)) h = jnp.full(M, h0) - for t in range(n): + + def loop_body(i, state): + s, h, key = state key, subkey = jax.random.split(key) Z = jax.random.normal(subkey, (2, M)) s = s + μ + jnp.exp(h) * Z[0, :] h = ρ * h + ν * Z[1, :] + return s, h, key + + s, h, key = jax.lax.fori_loop(0, n, loop_body, (s, h, key)) + expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0)) return β**n * expectation ``` +```{note} +We use `jax.lax.fori_loop` instead of a Python `for` loop. +This allows JAX to compile the loop efficiently without unrolling it, +which significantly reduces compilation time for large arrays. +``` + Let's run it once to compile it: ```{code-cell} ipython3 From afc2d365bea93bab11c6152ea052607ce386ea5b Mon Sep 17 00:00:00 2001 From: mmcky Date: Fri, 28 Nov 2025 10:28:47 +1100 Subject: [PATCH 2/3] style: use jstac's fori_loop naming conventions - loop_body -> update - state -> loop_state - Added explicit new_loop_state and final_loop_state variables - More verbose but clearer for first-time fori_loop readers --- lectures/jax_intro.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 641bbe5e..1890e74d 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -833,15 +833,18 @@ def compute_call_price_jax(β=β, s = jnp.full(M, np.log(S0)) h = jnp.full(M, h0) - def loop_body(i, state): - s, h, key = state + def update(i, loop_state): + s, h, key = loop_state key, subkey = jax.random.split(key) Z = jax.random.normal(subkey, (2, M)) s = s + μ + jnp.exp(h) * Z[0, :] h = ρ * h + ν * Z[1, :] - return s, h, key + new_loop_state = s, h, key + return new_loop_state - s, h, key = jax.lax.fori_loop(0, n, loop_body, (s, h, key)) + loop_state = s, h, key + final_loop_state = jax.lax.fori_loop(0, n, update, loop_state) + s, h, key = final_loop_state expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0)) From 9bceb2d135befba14faf2c11ad95aae9009f4510 Mon Sep 17 00:00:00 2001 From: mmcky Date: Fri, 28 Nov 2025 10:41:03 +1100 Subject: [PATCH 3/3] style: loop_state -> initial_loop_state --- lectures/jax_intro.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 1890e74d..f8fe265d 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -842,8 +842,8 @@ def compute_call_price_jax(β=β, new_loop_state = s, h, key return new_loop_state - loop_state = s, h, key - final_loop_state = jax.lax.fori_loop(0, n, update, loop_state) + initial_loop_state = s, h, key + final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state) s, h, key = final_loop_state expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))