@@ -25,8 +25,7 @@ We require the following library to be installed.
 !pip install --upgrade quantecon
 ```
 
-A monopolist faces inverse demand
-curve
+We study a monopolist who faces the inverse demand curve
 
 $$
 P_t = a_0 - a_1 Y_t + Z_t,
 
 * $Y_t$ is output and
 * $Z_t$ is a demand shock.
 
-We assume that $Z_t$ is a discretized AR(1) process.
+We assume that $Z_t$ is a discretized AR(1) process, specified below.
 
 Current profits are
@@ -116,10 +115,10 @@ def create_investment_model(
 
 
 Let's re-write the vectorized version of the right-hand side of the
-Bellman equation (before maximization), which is a 3D array representing:
+Bellman equation (before maximization), which is a 3D array representing
 
 $$
-B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
+B(y, z, y') = r(y, z, y') + \beta \sum_{z'} v(y', z') Q(z, z')
 $$
 
 for all $(y, z, y')$.
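The broadcasting behind this 3D array can be sketched in a few lines (a toy NumPy example with hypothetical sizes; the lecture's `B` builds the same object with `jax.numpy`, where the broadcasting rules are identical):

```python
import numpy as np

# hypothetical toy sizes
y_size, z_size, β = 3, 2, 0.95
r = np.zeros((y_size, z_size, y_size))   # r[i, j, ip] = r(y_i, z_j, y'_ip)
v = np.ones((y_size, z_size))            # current guess of the value function
Q = np.full((z_size, z_size), 0.5)       # Markov matrix for Z, rows sum to 1

# EV[0, j, ip] = Σ_jp v[ip, jp] * Q[j, jp], the expected continuation value
EV = np.sum(v[None, None, :, :] * Q[None, :, None, :], axis=3)

# broadcasting lifts EV over the leading (y) axis, giving B[i, j, ip]
B = r + β * EV
```

With `r` set to zero and `v` identically one, every entry of `B` is just `β`, which makes the broadcasting easy to verify by hand.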
@@ -154,8 +153,10 @@ def B(v, constants, sizes, arrays):
 B = jax.jit(B, static_argnums=(2,))
 ```
 
+We define a function to compute the current rewards $r_\sigma$ given policy
+$\sigma$, which is defined as the vector
 
-Define a function to compute the current rewards given policy $\sigma$.
+$$ r_\sigma(y, z) := r(y, z, \sigma(y, z)) $$
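The selection $r_\sigma[i, j] = r[i, j, \sigma[i, j]]$ can be sketched as follows (a toy NumPy example with hypothetical sizes; `compute_r_σ` below performs the analogous lookup with `jax.numpy`):

```python
import numpy as np

y_size, z_size = 3, 2
# a recognizable reward array: r[i, j, ip] enumerates all entries
r = np.arange(y_size * z_size * y_size, dtype=float).reshape(y_size, z_size, y_size)
# a trivial policy: always choose next-period output index 1
σ = np.ones((y_size, z_size), dtype=int)

# select r[i, j, σ[i, j]] for every state (i, j)
r_σ = np.take_along_axis(r, σ[:, :, None], axis=2)[:, :, 0]
```

The result is a `(y_size, z_size)` array of current rewards under the policy.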
 
 ```{code-cell} ipython3
 def compute_r_σ(σ, constants, sizes, arrays):
@@ -238,47 +239,32 @@ T_σ = jax.jit(T_σ, static_argnums=(3,))
 
 Next, we want to compute the lifetime value of following policy $\sigma$.
 
-The basic problem is to solve the linear system
+This lifetime value is a function $v_\sigma$ that satisfies
 
-$$ v(y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z') $$
+$$ v_\sigma(y, z) = r_\sigma(y, z) + \beta \sum_{z'} v_\sigma(\sigma(y, z), z') Q(z, z') $$
 
-for $v$.
+We wish to solve this equation for $v_\sigma$.
 
-It turns out to be helpful to rewrite this as
+Suppose we define the linear operator $L_\sigma$ by
 
-$$ v(y, z) = r(y, z, \sigma(y, z)) + \beta \sum_{y', z'} v(y', z') P_\sigma(y, z, y', z') $$
+$$ (L_\sigma v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z') $$
 
-where $P_\sigma(y, z, y', z') = 1\{y' = \sigma(y, z)\} Q(z, z')$.
-
-We want to write this as $v = r_\sigma + \beta P_\sigma v$ and then solve for $v$.
-
-Note, however, that $v$ is a multi-index array, rather than a vector.
-
-
-The value $v_{\sigma}$ of a policy $\sigma$ is defined as
+With this notation, the problem is to solve for $v$ via
 
 $$
-v_{\sigma} = (I - \beta P_{\sigma})^{-1} r_{\sigma}
+(L_{\sigma} v)(y, z) = r_\sigma(y, z)
 $$
 
-Here we set up the linear map $v \mapsto R_{\sigma} v$,
-
-where $R_{\sigma} := I - \beta P_{\sigma}$.
-
-In the investment problem, this map can be expressed as
-
-$$
-(R_{\sigma} v)(y, z) = v(y, z) - \beta \sum_{z'} v(\sigma(y, z), z') Q(z, z')
-$$
+In vector form this is $L_\sigma v = r_\sigma$, which tells us that the function
+we seek is
 
-Defining the map as above works in a more intuitive multi-index setting
-(e.g. working with $v[i, j]$ rather than flattening $v$ to a one-dimensional
-array) and avoids instantiating the large matrix $P_{\sigma}$.
+$$ v_\sigma = L_\sigma^{-1} r_\sigma $$
 
-Let's define the function $R_{\sigma}$.
+JAX allows us to solve linear systems defined in terms of operators; the first
+step is to define the function $L_{\sigma}$.
 
 ```{code-cell} ipython3
-def R_σ(v, σ, constants, sizes, arrays):
+def L_σ(v, σ, constants, sizes, arrays):
 
     β, a_0, a_1, γ, c = constants
     y_size, z_size = sizes
@@ -296,12 +282,11 @@ def R_σ(v, σ, constants, sizes, arrays):
 
     # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
     return v - β * jnp.sum(V * Q, axis=2)
 
-R_σ = jax.jit(R_σ, static_argnums=(3,))
+L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```
 
+Now we can define a function to compute $v_{\sigma}$.
 
-Define a function to get the value $v_{\sigma}$ of policy
-$\sigma$ by inverting the linear map $R_{\sigma}$.
 
 ```{code-cell} ipython3
 def get_value(σ, constants, sizes, arrays):
@@ -313,16 +298,16 @@ def get_value(σ, constants, sizes, arrays):
 
     r_σ = compute_r_σ(σ, constants, sizes, arrays)
 
-    # Reduce R_σ to a function in v
-    partial_R_σ = lambda v: R_σ(v, σ, constants, sizes, arrays)
+    # Reduce L_σ to a function in v
+    partial_L_σ = lambda v: L_σ(v, σ, constants, sizes, arrays)
 
-    return jax.scipy.sparse.linalg.bicgstab(partial_R_σ, r_σ)[0]
+    return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
 
 get_value = jax.jit(get_value, static_argnums=(2,))
 ```
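To see the matrix-free idea in isolation, here is a small sketch using SciPy with hypothetical sizes (`jax.scipy.sparse.linalg.bicgstab`, used above, accepts the operator as a plain callable, whereas SciPy wants it wrapped in a `LinearOperator`):

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, bicgstab

n, β = 4, 0.9
rng = np.random.default_rng(0)
P = rng.random((n, n))
P /= P.sum(axis=1, keepdims=True)        # a stochastic matrix
r = np.ones(n)

# the operator v ↦ v - β P v, applied without ever forming I - β P
L = LinearOperator((n, n), matvec=lambda v: v - β * (P @ v))

v, info = bicgstab(L, r, atol=1e-12)

# since r ≡ 1 and P is stochastic, the solution is constant: 1 / (1 - β) = 10
```

The iterative solver only ever calls `matvec`, which is why the same approach scales to the multi-index setting above without instantiating a large transition matrix.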
 
 
-Now we define the solvers, which implement VFI, HPI and OPI.
+Finally, we introduce the solvers that implement VFI, HPI and OPI.
 
 ```{code-cell} ipython3
 :load: _static/lecture_specific/vfi.py
@@ -396,7 +381,7 @@ plt.show()
 Let's plot the time taken by each of the solvers and compare them.
 
 ```{code-cell} ipython3
-m_vals = range(5, 3000, 100)
+m_vals = range(5, 600, 40)
 ```
 
 ```{code-cell} ipython3