Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion lectures/jax_intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -832,16 +832,31 @@ def compute_call_price_jax(β=β,

s = jnp.full(M, np.log(S0))
h = jnp.full(M, h0)
for t in range(n):

def update(i, loop_state):
s, h, key = loop_state
key, subkey = jax.random.split(key)
Z = jax.random.normal(subkey, (2, M))
s = s + μ + jnp.exp(h) * Z[0, :]
h = ρ * h + ν * Z[1, :]
new_loop_state = s, h, key
return new_loop_state

initial_loop_state = s, h, key
final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state)
s, h, key = final_loop_state

expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))

return β**n * expectation
```

```{note}
We use `jax.lax.fori_loop` instead of a Python `for` loop.
This allows JAX to compile the loop efficiently without unrolling it,
which significantly reduces compilation time for large arrays.
```

Let's run it once to compile it:

```{code-cell} ipython3
Expand Down
Loading