Skip to content

Commit ca4439b

Browse files
committed
Fix parameter order in solve_model function calls in ifp_advanced
The solve_model function signature expects (ifp, c_init, a_init) and returns (c_out, a_out), but all call sites were passing parameters in the wrong order (a_init, c_init) and expecting returns in the wrong order (a_out, c_out). This fixes all 5 call sites in the lecture to use the correct parameter and return order: - Line 490: Fixed initial solve_model call - Line 497: Fixed timed solve_model call - Line 645: Fixed simulation section call - Line 737: Fixed return volatility loop call - Line 814: Fixed income volatility loop call Fixes #759
1 parent b90fc02 commit ca4439b

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

lectures/ifp_advanced.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -487,15 +487,15 @@ a_init = σ_init.copy()
487487
Let's generate an approximation solution with JAX:
488488

489489
```{code-cell} ipython3
490-
a_star, σ_star = solve_model(ifp, a_init, σ_init)
490+
σ_star, a_star = solve_model(ifp, σ_init, a_init)
491491
```
492492

493493
Let's try it again with a timer.
494494

495495
```{code-cell} python3
496496
with qe.Timer(precision=8):
497-
a_star, σ_star = solve_model(ifp, a_init, σ_init)
498-
a_star.block_until_ready()
497+
σ_star, a_star = solve_model(ifp, σ_init, a_init)
498+
σ_star.block_until_ready()
499499
```
500500

501501
## Simulation
@@ -642,7 +642,7 @@ s_grid = ifp.s_grid
642642
n_z = len(ifp.P)
643643
a_init = s_grid[:, None] * jnp.ones(n_z)
644644
c_init = a_init
645-
a_vec, c_vec = solve_model(ifp, a_init, c_init)
645+
c_vec, a_vec = solve_model(ifp, c_init, a_init)
646646
assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000)
647647
648648
# Compute Gini coefficient for the plot
@@ -734,8 +734,8 @@ for a_r in a_r_vals:
734734
n_z_temp = len(ifp_temp.P)
735735
a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
736736
c_init_temp = a_init_temp
737-
a_vec_temp, c_vec_temp = solve_model(
738-
ifp_temp, a_init_temp, c_init_temp
737+
c_vec_temp, a_vec_temp = solve_model(
738+
ifp_temp, c_init_temp, a_init_temp
739739
)
740740
741741
# Simulate households
@@ -811,8 +811,8 @@ for a_y in a_y_vals:
811811
n_z_temp = len(ifp_temp.P)
812812
a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp)
813813
c_init_temp = a_init_temp
814-
a_vec_temp, c_vec_temp = solve_model(
815-
ifp_temp, a_init_temp, c_init_temp
814+
c_vec_temp, a_vec_temp = solve_model(
815+
ifp_temp, c_init_temp, a_init_temp
816816
)
817817
818818
# Simulate households

0 commit comments

Comments
 (0)