
Commit bfe2b34

Tidying up Newton's method lecture (#113)
* misc
* misc
* misc
1 parent ccf6fdd commit bfe2b34

File tree

1 file changed: +58 −83 lines changed

lectures/newtons_method.md

Lines changed: 58 additions & 83 deletions
@@ -19,20 +19,28 @@ kernelspec:
 
 ## Overview
 
-In this lecture we highlight some of the capabilities of JAX, including JIT
-compilation and automatic differentiation.
+One of the key features of JAX is automatic differentiation.
 
-The application is computing equilibria via Newton's method, which we discussed
-in [a more elementary QuantEcon lecture](https://python.quantecon.org/newton_method.html)
+While other software packages also offer this feature, the JAX version is
+particularly powerful because it integrates so closely with other core
+components of JAX, such as accelerated linear algebra, JIT compilation and
+parallelization.
 
-Here our focus is on how to apply JAX to this problem.
+The application of automatic differentiation we consider is computing economic equilibria via Newton's method.
+
+Newton's method is a relatively simple root and fixed point solution algorithm, which we discussed
+in [a more elementary QuantEcon lecture](https://python.quantecon.org/newton_method.html).
+
+JAX is almost ideally suited to implementing Newton's method efficiently, even
+in high dimensions.
 
 We use the following imports in this lecture
 
 ```{code-cell} ipython3
 import jax
 import jax.numpy as jnp
 from scipy.optimize import root
+import matplotlib.pyplot as plt
 ```
 
 Let's check the GPU we are running
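The revised overview leans on automatic differentiation, so a quick illustration may help. Here is a minimal sketch, not part of the commit, of how `jax.grad` composes with `jax.jit`; the test function is borrowed from a later hunk:

```python
import jax
import jax.numpy as jnp

# The test function used later in the lecture
f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1

# jax.grad builds the derivative of f; jax.jit compiles it
f_prime = jax.jit(jax.grad(f))
print(f_prime(0.2))
```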
@@ -48,14 +56,19 @@ Let's check the GPU we are running
 As a warm up, let's implement Newton's method in JAX for a simple
 one-dimensional root-finding problem.
 
+Let $f$ be a function from $\mathbb R$ to itself.
+
+A **root** of $f$ is an $x \in \mathbb R$ such that $f(x)=0$.
+
 [Recall](https://python.quantecon.org/newton_method.html) that Newton's method for solving for the root of $f$ involves iterating with the map $q$ defined by
 
 $$
 q(x) = x - \frac{f(x)}{f'(x)}
 $$
 
 
-Here is a function called `newton` that takes a function $f$ plus a guess $x_0$, iterates with $q$ starting from $x_0$, and returns an approximate fixed point.
+Here is a function called `newton` that takes a function $f$ plus a scalar value $x_0$,
+iterates with $q$ starting from $x_0$, and returns an approximate fixed point.
 
 
 ```{code-cell} ipython3
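The contents of the `newton` code cell fall outside this hunk. A minimal sketch of what a routine matching this description could look like, assuming a simple successive-approximation loop (an illustration, not the lecture's actual cell):

```python
import jax
import jax.numpy as jnp

def newton(f, x_0, tol=1e-5):
    f_prime = jax.grad(f)
    def q(x):
        # The Newton map: q(x) = x - f(x) / f'(x)
        return x - f(x) / f_prime(x)
    x = x_0
    error = tol + 1
    while error > tol:
        y = q(x)
        error = jnp.abs(x - y)
        x = y
    return x
```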
@@ -82,7 +95,6 @@ Let's test our `newton` routine on the function shown below.
 f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1
 x = jnp.linspace(0, 1, 100)
 
-import matplotlib.pyplot as plt
 fig, ax = plt.subplots()
 ax.plot(x, f(x), label='$f(x)$')
 ax.axhline(ls='--', c='k')
@@ -98,7 +110,7 @@ Here we go
 newton(f, 0.2)
 ```
 
-This number looks good, given the figure.
+This number looks to be close to the root, given the figure.
 
 
 
@@ -108,87 +120,44 @@ Now let's move up to higher dimensions.
 
 First we describe a market equilibrium problem we will solve with JAX via root-finding.
 
-We begin with a two good case,
-which is borrowed from [an earlier lecture](https://python.quantecon.org/newton_method.html).
+The market is for $n$ goods.
 
-Then we shift to higher dimensions.
+(We are extending a two-good version of the market from [an earlier lecture](https://python.quantecon.org/newton_method.html).)
 
-
-### The Two Goods Market Equilibrium
-
-Assume we have a market for two complementary goods where demand depends on the
-price of both components.
-
-We label them good 0 and good 1, with price vector $p = (p_0, p_1)$.
-
-Then the supply of good $i$ at price $p$ is,
+The supply function for the $i$-th good is
 
 $$
-q^s_i (p) = b_i \sqrt{p_i}
+q^s_i (p) = b_i \sqrt{p_i}
 $$
 
-and the demand of good $i$ at price $p$ is,
+which we write in vector form as
 
 $$
-q^d_i (p) = \text{exp}(-(a_{i0} p_0 + a_{i1} p_1)) + c_i
+q^s (p) =b \sqrt{p}
 $$
 
-Here $a_{ij}$, $b_i$ and $c_i$ are parameters for $n \times n$ square matrix $A$ and $n \times 1$ parameter vectors $b$ and $c$.
+(Here $\sqrt{p}$ is the square root of each $p_i$ and $b \sqrt{p}$ is the vector
+formed by taking the pointwise product $b_i \sqrt{p_i}$ at each $i$.)
 
-The excess demand function is,
+The demand function is
 
 $$
-e_i(p) = q_i^d(p) - q_i^s(p), \quad i = 0, 1
+q^d (p) = \exp(- A p) + c
 $$
 
-An equilibrium price vector $p^*$ satisfies $e_i(p^*) = 0$.
+(Here $A$ is an $n \times n$ matrix containing parameters, $c$ is an $n \times
+1$ vector and the $\exp$ function acts pointwise (element-by-element) on the
+vector $- A p$.)
 
-We set
+The excess demand function is
 
 $$
-A = \begin{pmatrix}
-        a_{00} & a_{01} \\
-        a_{10} & a_{11}
-    \end{pmatrix},
-\qquad
-b = \begin{pmatrix}
-        b_0 \\
-        b_1
-    \end{pmatrix}
-\qquad \text{and} \qquad
-c = \begin{pmatrix}
-        c_0 \\
-        c_1
-    \end{pmatrix}
+e(p) = \exp(- A p) + c - b \sqrt{p}
 $$
 
-for this particular question.
-
-
-### A High-Dimensional Version
+An **equilibrium price** vector is an $n$-vector $p$ such that $e(p) = 0$.
 
-Let's now shift to a linear algebra formulation, which allows us to handle
-arbitrarily many goods.
-
-The supply function remains unchanged,
-
-$$
-q^s (p) =b \sqrt{p}
-$$
-
-The demand function becomes
-
-$$
-q^d (p) = \text{exp}(- A \cdot p) + c
-$$
-
-Our new excess demand function is
-
-$$
-e(p) = \text{exp}(- A \cdot p) + c - b \sqrt{p}
-$$
-
-The function below calculates the excess demand for the given parameters
+The function below calculates the excess demand for given parameters
 
 ```{code-cell} ipython3
 def e(p, A, b, c):
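The body of `e` lies outside the hunk. Given the formula just above, a plausible implementation (an assumption, not taken from the commit) is:

```python
def e(p, A, b, c):
    # Excess demand e(p) = exp(-A p) + c - b * sqrt(p), pointwise operations
    return jnp.exp(- A @ p) + c - b * jnp.sqrt(p)
```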
@@ -206,7 +175,7 @@ In this section we describe and then implement the solution method.
 
 We use a multivariate version of Newton's method to compute the equilibrium price.
 
-The rule for updating a guess $p_n$ of the price vector is
+The rule for updating a guess $p_n$ of the equilibrium price vector is
 
 ```{math}
 :label: multi-newton
@@ -217,13 +186,20 @@ Here $J_e(p_n)$ is the Jacobian of $e$ evaluated at $p_n$.
 
 Iteration starts from initial guess $p_0$.
 
-Instead of coding the Jacobian by hand, we use `jax.jacobian()`.
+Instead of coding the Jacobian by hand, we use automatic differentiation via `jax.jacobian()`.
 
 ```{code-cell} ipython3
 def newton(f, x_0, tol=1e-5, max_iter=15):
+    """
+    A multivariate Newton root-finding routine.
+
+    """
     x = x_0
     f_jac = jax.jacobian(f)
-    q = jax.jit(lambda x: x - jnp.linalg.solve(f_jac(x), f(x)))
+    @jax.jit
+    def q(x):
+        " Updates the current guess. "
+        return x - jnp.linalg.solve(f_jac(x), f(x))
     error = tol + 1
     n = 0
     while error > tol:
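The hunk cuts off inside the `while` loop. For reference, a self-contained completion of the routine under the usual successive-approximation pattern (the loop body is an assumption, not shown in the commit):

```python
import jax
import jax.numpy as jnp

def newton(f, x_0, tol=1e-5, max_iter=15):
    """
    A multivariate Newton root-finding routine.
    """
    x = x_0
    f_jac = jax.jacobian(f)

    @jax.jit
    def q(x):
        " Updates the current guess. "
        return x - jnp.linalg.solve(f_jac(x), f(x))

    error = tol + 1
    n = 0
    while error > tol:
        n += 1
        if n > max_iter:
            raise Exception('Max iteration reached without convergence')
        y = q(x)
        error = jnp.linalg.norm(x - y)  # distance between successive guesses
        x = y
    return x
```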
@@ -245,17 +221,15 @@ def newton(f, x_0, tol=1e-5, max_iter=15):
 
 Let's now apply the method just described to investigate a large market with 5,000 goods.
 
-We randomly generate the matrix $A$ and set the parameter vectors $b \text{ and } c$ to $1$.
+We randomly generate the matrix $A$ and set the parameter vectors $b, c$ to $1$.
 
 ```{code-cell} ipython3
 dim = 5_000
 seed = 32
 
 # Create a random matrix A and normalize the rows to sum to one
 key = jax.random.PRNGKey(seed)
-
 A = jax.random.uniform(key, [dim, dim])
-
 s = jnp.sum(A, axis=0)
 A = A / s
 
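The definitions of `b` and `c` fall outside this hunk; given the prose above, they are presumably something like (an assumption):

```python
# Parameter vectors set to 1, as stated in the text (not shown in the hunk)
b = jnp.ones(dim)
c = jnp.ones(dim)
```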
@@ -271,16 +245,18 @@ Here's our initial condition $p_0$
 init_p = jnp.ones(dim)
 ```
 
-By leveraging the power of Newton's method, JAX accelerated linear algebra,
+By combining the power of Newton's method, JAX accelerated linear algebra,
 automatic differentiation, and a GPU, we obtain a relatively small error for
-this very large problem in just a few seconds:
+this high-dimensional problem in just a few seconds:
 
 ```{code-cell} ipython3
 %%time
 
 p = newton(lambda p: e(p, A, b, c), init_p).block_until_ready()
 ```
 
+Here's the size of the error:
+
 ```{code-cell} ipython3
 jnp.max(jnp.abs(e(p, A, b, c)))
 ```
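One detail worth noting: JAX dispatches work to the device asynchronously, so without `block_until_ready()` the `%%time` magic would measure only dispatch time, not the computation itself. A minimal illustration of the pattern:

```python
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
y = x @ x               # returns immediately (asynchronous dispatch)
y.block_until_ready()   # waits until the device computation finishes
```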
@@ -298,13 +274,14 @@ solution = root(lambda p: e(p, A, b, c),
                 tol=1e-5)
 ```
 
+The result is also slightly less accurate:
+
 ```{code-cell} ipython3
 p = solution.x
 jnp.max(jnp.abs(e(p, A, b, c)))
 ```
 
 
-The result is also less accurate.
 
 
 
@@ -387,7 +364,7 @@ initLs = [jnp.ones(3),
 ```
 
 
-Then define the multivariate version of the formula for the [law of motion of capital](https://python.quantecon.org/newton_method.html#solow)
+Then we define the multivariate version of the formula for the [law of motion of capital](https://python.quantecon.org/newton_method.html#solow)
 
 ```{code-cell} ipython3
 def multivariate_solow(k, A=A, s=s, α=α, δ=δ):
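The function body is truncated by the hunk. Based on the law of motion in the linked lecture, $k_{t+1} = s A k_t^\alpha + (1 - \delta) k_t$, it is presumably along these lines (an assumption, not taken from the commit):

```python
def multivariate_solow(k, A=A, s=s, α=α, δ=δ):
    # Law of motion: k_{t+1} = s A k^α + (1 - δ) k, with k**α pointwise
    return s * A @ k**α + (1 - δ) * k
```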
@@ -408,17 +385,16 @@ for init in initLs:
 ```
 
 
-We find that the results are invariant to the starting values given the well-defined property of this question.
+We find that the results are invariant to the starting values.
 
 But the number of iterations it takes to converge is dependent on the starting values.
 
-Let substitute the output back to the formulate to check our last result
+Let's substitute the output back into the formula to check our last result
 
 ```{code-cell} ipython3
 multivariate_solow(k) - k
 ```
 
-
 Note the error is very small.
 
 We can also test our results on the known solution
@@ -435,8 +411,7 @@ init = jnp.repeat(1.0, 3)
                 init).block_until_ready()
 ```
 
-
-The result is very close to the ground truth but still slightly different.
+The result is very close to the true solution but still slightly different.
 
 We can increase the precision of the floating point numbers and restrict the tolerance to obtain a more accurate approximation (see detailed discussion in the [lecture on JAX](https://python-programming.quantecon.org/jax_intro.html#differences))
 
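On the precision point, the standard way to switch JAX to double precision looks like this (an illustration, not code from the commit):

```python
import jax

# Switch JAX from its default float32 to float64
jax.config.update("jax_enable_x64", True)
```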
442417
