
Commit bf68305

jstac and claude committed

Use vmap strategy for T operator to match T_σ implementation

- Replace vectorized B with vmap-based B(v, model, i, j, ip)
- Add staged vmap application: B_1, B_2, B_vmap
- Update T and get_greedy to use B_vmap with index arrays
- Consistent with T_σ implementation, which also uses vmap
- Performance: ~6.7x speedup (slightly better than vectorized version)

This makes the codebase more consistent by using the same vmap strategy for both the Bellman operator and the policy operator.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
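The trade-off the commit message describes — a scalar-index function lifted with `vmap` versus hand-broadcast arrays — can be sketched in isolation. The names below (`r`, `grid`) are hypothetical stand-ins for `B` and the model grids, not code from the lecture:

```python
import jax
import jax.numpy as jnp

# A toy scalar reward over two grids, standing in for B's role in the diff.
grid = jnp.linspace(0.0, 1.0, 5)

def r(i, ip):
    # Scalar function of two grid indices.
    return grid[i] - 0.5 * grid[ip]

# vmap-based version: lift over ip first, then over i,
# mirroring the staged B_1, B_2, B_vmap construction.
r_vmap = jax.vmap(jax.vmap(r, in_axes=(None, 0)), in_axes=(0, None))

# Hand-vectorized version using reshape/broadcasting, as in the old B.
def r_broadcast():
    a = jnp.reshape(grid, (len(grid), 1))    # a[i]   -> a[i, ip]
    ap = jnp.reshape(grid, (1, len(grid)))   # ap[ip] -> ap[i, ip]
    return a - 0.5 * ap

idx = jnp.arange(len(grid))
out = r_vmap(idx, idx)          # shape (5, 5)
assert jnp.allclose(out, r_broadcast())
```

Both versions compile to the same kind of batched computation; the `vmap` version keeps the scalar logic readable while JAX handles the batching.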
1 parent 966f0d4 commit bf68305

File tree: 1 file changed, +21 −20 lines


lectures/ifp_opi.md

Lines changed: 21 additions & 20 deletions
````diff
@@ -109,49 +109,50 @@ We repeat some functions from {doc}`ifp_discrete`.
 Here is the right hand side of the Bellman equation:
 
 ```{code-cell} ipython3
-def B(v, model):
+def B(v, model, i, j, ip):
     """
-    A vectorized version of the right-hand side of the Bellman equation
-    (before maximization), which is a 3D array representing
+    The right-hand side of the Bellman equation before maximization, which takes
+    the form
 
         B(a, y, a′) = u(Ra + y - a′) + β Σ_y′ v(a′, y′) Q(y, y′)
 
-    for all (a, y, a′).
+    The indices are (i, j, ip) -> (a, y, a′).
     """
-
-    # Unpack
     β, R, γ, a_grid, y_grid, Q = model
-    a_size, y_size = len(a_grid), len(y_grid)
-
-    # Compute current rewards r(a, y, ap) as array r[i, j, ip]
-    a = jnp.reshape(a_grid, (a_size, 1, 1))    # a[i]   ->  a[i, j, ip]
-    y = jnp.reshape(y_grid, (1, y_size, 1))    # z[j]   ->  z[i, j, ip]
-    ap = jnp.reshape(a_grid, (1, 1, a_size))   # ap[ip] -> ap[i, j, ip]
+    a, y, ap = a_grid[i], y_grid[j], a_grid[ip]
     c = R * a + y - ap
+    EV = jnp.sum(v[ip, :] * Q[j, :])
+    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+```
 
-    # Calculate continuation rewards at all combinations of (a, y, ap)
-    v = jnp.reshape(v, (1, 1, a_size, y_size)) # v[ip, jp] -> v[i, j, ip, jp]
-    Q = jnp.reshape(Q, (1, y_size, 1, y_size)) # Q[j, jp]  -> Q[i, j, ip, jp]
-    EV = jnp.sum(v * Q, axis=3)                # sum over last index jp
+Now we successively apply `vmap` to vectorize over all indices:
 
-    # Compute the right-hand side of the Bellman equation
-    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+```{code-cell} ipython3
+B_1 = jax.vmap(B, in_axes=(None, None, None, None, 0))
+B_2 = jax.vmap(B_1, in_axes=(None, None, None, 0, None))
+B_vmap = jax.vmap(B_2, in_axes=(None, None, 0, None, None))
 ```
 
 Here's the Bellman operator:
 
 ```{code-cell} ipython3
 def T(v, model):
     "The Bellman operator."
-    return jnp.max(B(v, model), axis=2)
+    a_indices = jnp.arange(len(model.a_grid))
+    y_indices = jnp.arange(len(model.y_grid))
+    B_values = B_vmap(v, model, a_indices, y_indices, a_indices)
+    return jnp.max(B_values, axis=-1)
 ```
 
 Here's the function that computes a $v$-greedy policy:
 
 ```{code-cell} ipython3
 def get_greedy(v, model):
     "Computes a v-greedy policy, returned as a set of indices."
-    return jnp.argmax(B(v, model), axis=2)
+    a_indices = jnp.arange(len(model.a_grid))
+    y_indices = jnp.arange(len(model.y_grid))
+    B_values = B_vmap(v, model, a_indices, y_indices, a_indices)
+    return jnp.argmax(B_values, axis=-1)
 ```
 
 Now we define the policy operator $T_\sigma$, which is the Bellman operator with
````
