@@ -49,7 +49,7 @@ libraries
4949``` {code-cell} ipython3
5050:tags: [hide-output]
5151
52- !pip install quantecon
52+ !pip install quantecon jax
5353```
5454
5555We use the following imports:
@@ -64,7 +64,7 @@ import matplotlib.pyplot as plt
6464from functools import partial
6565```
6666
67- ## Model Setup
67+ ## Model setup
6868
6969The setting is as follows:
7070
@@ -74,7 +74,7 @@ The setting is as follows:
7474- Unemployed workers receive compensation $c$ per period
7575- Future payoffs are discounted by factor $\beta \in (0,1)$
7676
77- ### Decision Problem
77+ ### Decision problem
7878
7979When unemployed and receiving wage offer $w$, the agent chooses between:
8080
@@ -86,7 +86,7 @@ The wage updates are as follows:
8686* If an unemployed agent rejects offer $w$, then their next offer is drawn from $P(w, \cdot)$
8787* If an employed agent loses a job in which they were paid wage $w$, then their next offer is drawn from $P(w, \cdot)$
8888
89- ### The Wage Offer Process
89+ ### The wage offer process
9090
9191To construct the wage offer process we start with an AR1 process.
9292
@@ -112,7 +112,7 @@ Actually, in practice, we approximate this wage process as follows:
112112
113113
114114
115- ### Value Functions
115+ ### Value functions
116116
117117We let
118118
168168
169169+++
170170
171- ### Optimal Policy
171+ ### Optimal policy
172172
173173Once we have the solutions $v_e$ and $v_u$ to these Bellman equations, we can compute the optimal policy: accept at current wage offer $w$ if
174174
175175$$
176- v_e(w) ≥ u(c) + β(Pv_u )(w)
176+ v_e(w) \geq u(c) + \beta (P v_u )(w)
177177$$
178178
179179The optimal policy turns out to be a reservation wage strategy: accept all wages above some threshold.
@@ -185,7 +185,7 @@ The optimal policy turns out to be a reservation wage strategy: accept all wages
185185
186186Let's now implement the model.
187187
188- ### Set Up
188+ ### Set up
189189
190190The default utility function is a CRRA utility function
191191
@@ -234,7 +234,7 @@ def create_js_with_sep_model(
234234```
235235
236236
237- ### Solution: First Pass
237+ ### Solution: first pass
238238
239239Let's put together a (not very efficient) routine for calculating the
240240reservation wage.
@@ -244,7 +244,7 @@ reservation wage.
244244It works by starting with guesses for $v_e$ and $v_u$ and iterating to
245245convergence.
246246
247- Here's are Bellman operators that update $v_u$ and $v_e$ respectively.
247+ Here are Bellman operators that update $v_u$ and $v_e$ respectively.
248248
249249
250250``` {code-cell} ipython3
@@ -313,7 +313,7 @@ def solve_model_first_pass(
313313```
314314
315315
316- ### Road Test
316+ ### Road test
317317
318318Let's solve the model:
319319
@@ -348,9 +348,9 @@ The reservation wage is at the intersection of $v_e$, and the continuation value
348348function, which is the value of rejecting.
349349
350350
351- ## Improving Efficiency
351+ ## Improving efficiency
352352
353- The solution method desribed above works fine but we can do much better.
353+ The solution method described above works fine but we can do much better.
354354
355355First, we use the employed worker's Bellman equation to express
356356$v_e$ in terms of $Pv_u$
@@ -495,7 +495,7 @@ The result is the same as before but we only iterate on one array --- and also
495495our JAX code is more efficient.
496496
497497
498- ## Sensitivity Analysis
498+ ## Sensitivity analysis
499499
500500Let's examine how reservation wages change with the separation rate.
501501
@@ -523,7 +523,7 @@ Can you provide an intuitive economic story behind the outcome that you see in t
523523
524524+++
525525
526- ## Employment Simulation
526+ ## Employment simulation
527527
528528Now let's simulate the employment dynamics of a single agent under the optimal policy.
529529
@@ -691,15 +691,15 @@ often leads a high new draw.
691691
692692+++
693693
694- ## The Ergodic Property
694+ ## Ergodic property
695695
696696Below we examine cross-sectional unemployment.
697697
698698In particular, we will look at the unemployment rate in a cross-sectional
699699simulation and compare it to the time-average unemployment rate, which is the
700700fraction of time an agent spends unemployed over a long time series.
701701
702- We will see that these two values are approximately equal -- if fact they are
702+ We will see that these two values are approximately equal -- in fact they are
703703exactly equal in the limit.
704704
705705The reason is that the process $(S_t, W_t)$, where
@@ -744,15 +744,15 @@ Often the second approach is better for our purposes, since it's easier to paral
744744
745745+++
746746
747- ## Cross-Sectional Analysis
747+ ## Cross-sectional analysis
748748
749749Now let's simulate many agents simultaneously to examine the cross-sectional unemployment rate.
750750
751751We first create a vectorized version of ` update_agent ` to efficiently update all agents in parallel:
752752
753753``` {code-cell} ipython3
754- # Create vectorized version of update_agent
755- # The last parameter is now w_bar (scalar) instead of σ (array)
754+ # Create vectorized version of update_agent.
755+ # Vectorize over key, status, wage_idx
756756update_agents_vmap = jax.vmap(
757757 update_agent, in_axes=(0, 0, 0, None, None)
758758)
@@ -761,76 +761,72 @@ update_agents_vmap = jax.vmap(
761761Next we define the core simulation function, which uses ` lax.fori_loop ` to efficiently iterate many agents forward in time:
762762
763763``` {code-cell} ipython3
764- @partial( jax.jit, static_argnums=(3, 4))
764+ @jax.jit
765765def _simulate_cross_section_compiled(
766766 key: jnp.ndarray,
767767 model: Model,
768768 w_bar: float,
769- n_agents: int,
769+ initial_wage_indices: jnp.ndarray,
770+ initial_status_vec: jnp.ndarray,
770771 T: int
771772 ):
772- """JIT-compiled core simulation loop using lax.fori_loop.
773- Returns only the final employment state to save memory."""
773+ """
774+ JIT-compiled core simulation loop for shifting the cross section
775+ using lax.fori_loop. Returns the final employment employment status
776+ cross-section.
777+
778+ """
774779 n, w_vals, P, P_cumsum, β, c, α, γ = model
780+ n_agents = len(initial_wage_indices)
775781
776- # Initialize arrays
777- wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
778- status = jnp.zeros(n_agents, dtype=jnp.int32)
779782
780783 def update(t, loop_state):
781- key, status, wage_indices = loop_state
782-
783- # Shift loop state forwards
784- key, subkey = jax.random.split(key)
785- agent_keys = jax.random.split(subkey, n_agents)
786-
784+ " Shift loop state forwards "
785+ status, wage_indices = loop_state
786+ step_key = jax.random.fold_in(key, t)
787+ agent_keys = jax.random.split(step_key, n_agents)
787788 status, wage_indices = update_agents_vmap(
788789 agent_keys, status, wage_indices, model, w_bar
789790 )
790-
791- return key, status, wage_indices
791+ return status, wage_indices
792792
793793 # Run simulation using fori_loop
794- initial_loop_state = (key, status, wage_indices )
794+ initial_loop_state = (initial_status_vec, initial_wage_indices )
795795 final_loop_state = lax.fori_loop(0, T, update, initial_loop_state)
796796
797797 # Return only final employment state
798- _, final_is_employed, _ = final_loop_state
798+ final_is_employed, _ = final_loop_state
799799 return final_is_employed
800800
801801
802802def simulate_cross_section(
803- model: Model,
804- n_agents: int = 100_000,
805- T: int = 200,
806- seed: int = 42
803+ model: Model, # Model instance with parameters
804+ n_agents: int = 100_000, # Number of agents to simulate
805+ T: int = 200, # Length of burn-in
806+ seed: int = 42 # For reproducibility
807807 ) -> float:
808808 """
809- Simulate employment paths for many agents and return final unemployment rate .
809+ Wrapper function for _simulate_cross_section_compiled .
810810
811- Parameters:
812- - model: Model instance with parameters
813- - n_agents: Number of agents to simulate
814- - T: Number of periods to simulate
815- - seed: Random seed for reproducibility
811+ Push forward a cross-section for T periods and return the final
812+ cross-sectional unemployment rate.
816813
817- Returns:
818- - unemployment_rate: Fraction of agents unemployed at time T
819814 """
820815 key = jax.random.PRNGKey(seed)
821816
822817 # Solve for optimal reservation wage
823818 v_u = vfi(model)
824819 w_bar = get_reservation_wage(v_u, model)
825820
826- # Run JIT-compiled simulation
821+ # Initialize arrays
822+ initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
823+ initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
824+
827825 final_status = _simulate_cross_section_compiled(
828- key, model, w_bar, n_agents , T
826+ key, model, w_bar, initial_wage_indices, initial_status_vec , T
829827 )
830828
831- # Calculate unemployment rate at final period
832829 unemployment_rate = 1 - jnp.mean(final_status)
833-
834830 return unemployment_rate
835831```
836832
@@ -850,8 +846,13 @@ def plot_cross_sectional_unemployment(
850846 key = jax.random.PRNGKey(42)
851847 v_u = vfi(model)
852848 w_bar = get_reservation_wage(v_u, model)
849+
850+ # Initialize arrays
851+ initial_wage_indices = jnp.zeros(n_agents, dtype=jnp.int32)
852+ initial_status_vec = jnp.zeros(n_agents, dtype=jnp.int32)
853+
853854 final_status = _simulate_cross_section_compiled(
854- key, model, w_bar, n_agents , t_snapshot
855+ key, model, w_bar, initial_wage_indices, initial_status_vec , t_snapshot
855856 )
856857
857858 # Calculate unemployment rate
@@ -906,7 +907,7 @@ Now let's visualize the cross-sectional distribution:
906907plot_cross_sectional_unemployment(model)
907908```
908909
909- ## Lower Unemployment Compensation (c=0.5)
910+ ## Lower unemployment compensation (c=0.5)
910911
911912What happens to the cross-sectional unemployment rate with lower unemployment compensation?
912913
0 commit comments