318318We'll use the same iterative approach to solving the Bellman equations that we
319319adopted in the {doc}` first job search lecture <mccall_model> ` .
320320
321- Since we have reduced the problem to a single scalar equation {eq}` bell_scalar ` ,
322- we only need to iterate on $h$.
321+ In this case we only need to iterate on the single scalar equation {eq}` bell_scalar ` .
323322
324323The iteration rule is
325324
@@ -332,18 +331,9 @@ h_{n+1} = u(c) + \beta \sum_{w' \in \mathbb W}
332331
333332starting from some initial condition $h_0$.
334333
335- Once convergence is achieved, we can compute $v_e$ from {eq}` v_e_closed ` :
334+ Once convergence is achieved, we can compute $v_e$ from {eq}` v_e_closed ` .
336335
337- ``` {math}
338- :label: bell_v_e_final
339-
340- v_e(w) = \frac{u(w) + \alpha(h - u(c))}{1 - \beta(1-\alpha)}
341- ```
342-
343- This approach is simpler than iterating on both $h$ and $v_e$ simultaneously, as
344- we now only need to track a single scalar value.
345-
346- (Convergence can be established via the Banach contraction mapping theorem.)
336+ (It is possible to prove that {eq}` bell_iter ` converges via the Banach contraction mapping theorem.)
347337
348338## Implementation
349339
@@ -405,28 +395,32 @@ def update_h(model, h):
405395 v_e = compute_v_e(model, h)
406396 h_new = u(c) + β * (jnp.maximum(v_e, h) @ q)
407397 return h_new
398+ ```
399+
400+ Using this iteration rule, we can write our model solver.
408401
402+ ``` {code-cell} ipython3
409403@jax.jit
410404def solve_model(model, tol=1e-5, max_iter=2000):
411405 " Iterates to convergence on the Bellman equations. "
412406
413- def cond_fun(state ):
414- h, i, error = state
407+ def cond(loop_state ):
408+ h, i, error = loop_state
415409 return jnp.logical_and(error > tol, i < max_iter)
416410
417- def body_fun(state ):
418- h, i, error = state
411+ def update(loop_state ):
412+ h, i, error = loop_state
419413 h_new = update_h(model, h)
420414 error_new = jnp.abs(h_new - h)
421415 return h_new, i + 1, error_new
422416
423- # Initial state: (h, i, error)
417+ # Initialize
424418 h_init = u(model.c) / (1 - model.β)
425419 i_init = 0
426420 error_init = tol + 1
427-
428421 init_state = (h_init, i_init, error_init)
429- final_state = jax.lax.while_loop(cond_fun, body_fun, init_state)
422+
423+ final_state = jax.lax.while_loop(cond, update, init_state)
430424 h_final, _, _ = final_state
431425
432426 # Compute v_e from the converged h
@@ -461,7 +455,11 @@ plt.show()
461455
462456The value $v_e$ is increasing because higher $w$ generates a higher wage flow conditional on staying employed.
463457
464- ### The Reservation Wage: Computation
458+
459+ The reservation wage is the $w$ where these lines meet.
460+
461+
462+ ### Computing the Reservation Wage
465463
466464Here's a function ` compute_reservation_wage ` that takes an instance of ` Model `
467465and returns the associated reservation wage.
@@ -483,6 +481,8 @@ def compute_reservation_wage(model):
483481
484482Next we will investigate how the reservation wage varies with parameters.
485483
484+
485+
486486## Impact of Parameters
487487
488488In each instance below, we'll show you a figure and then ask you to reproduce it in the exercises.
0 commit comments