Skip to content

Commit 966f0d4

Browse files
jstacclaude
andcommitted
Remove @jax.jit decorators from intermediate functions for code simplicity
- Remove @jax.jit from B, T, get_greedy, T_σ_vec, and iterate_policy_operator - Keep @jax.jit on main solver functions (value_function_iteration, optimistic_policy_iteration) - Performance testing shows no significant difference (within measurement noise) - Simplifies code while maintaining ~6x OPI speedup over VFI 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 79c9099 commit 966f0d4

File tree

1 file changed

+0
-5
lines changed

1 file changed

+0
-5
lines changed

lectures/ifp_opi.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ We repeat some functions from {doc}`ifp_discrete`.
109109
Here is the right hand side of the Bellman equation:
110110

111111
```{code-cell} ipython3
112-
@jax.jit
113112
def B(v, model):
114113
"""
115114
A vectorized version of the right-hand side of the Bellman equation
@@ -142,7 +141,6 @@ def B(v, model):
142141
Here's the Bellman operator:
143142

144143
```{code-cell} ipython3
145-
@jax.jit
146144
def T(v, model):
147145
"The Bellman operator."
148146
return jnp.max(B(v, model), axis=2)
@@ -151,7 +149,6 @@ def T(v, model):
151149
Here's the function that computes a $v$-greedy policy:
152150

153151
```{code-cell} ipython3
154-
@jax.jit
155152
def get_greedy(v, model):
156153
"Computes a v-greedy policy, returned as a set of indices."
157154
return jnp.argmax(B(v, model), axis=2)
@@ -194,7 +191,6 @@ Apply vmap to vectorize:
194191
T_σ_1 = jax.vmap(T_σ, in_axes=(None, None, None, None, 0))
195192
T_σ_vmap = jax.vmap(T_σ_1, in_axes=(None, None, None, 0, None))
196193
197-
@jax.jit
198194
def T_σ_vec(v, σ, model):
199195
"""Vectorized version of T_σ."""
200196
a_size, y_size = len(model.a_grid), len(model.y_grid)
@@ -206,7 +202,6 @@ def T_σ_vec(v, σ, model):
206202
Now we need a function to apply the policy operator m times:
207203

208204
```{code-cell} ipython3
209-
@jax.jit
210205
def iterate_policy_operator(σ, v, m, model):
211206
"""
212207
Apply the policy operator T_σ exactly m times to v.

0 commit comments

Comments
 (0)