lectures/newtons_method.md

## Overview

One of the key features of JAX is automatic differentiation.

While other software packages also offer this feature, the JAX version is
particularly powerful because it integrates so closely with other core
components of JAX, such as accelerated linear algebra, JIT compilation and
parallelization.
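
For example, `jax.grad` takes a function and returns its derivative as a new function, which composes directly with JIT compilation. Here is a minimal sketch (the function `g` below is just an illustration, not part of the lecture's application):

```{code-cell} ipython3
import jax

# g is an illustrative scalar function; jax.grad returns its derivative,
# and jax.jit compiles the resulting derivative function
g = lambda x: x**3 + 2 * x
g_prime = jax.jit(jax.grad(g))
g_prime(1.0)   # 3 * 1.0**2 + 2.0 = 5.0
```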

The application of automatic differentiation we consider is computing economic equilibria via Newton's method.

Newton's method is a relatively simple root-finding and fixed-point algorithm, which we discussed
in [a more elementary QuantEcon lecture](https://python.quantecon.org/newton_method.html).

JAX is almost ideally suited to implementing Newton's method efficiently, even
in high dimensions.

We use the following imports in this lecture

```{code-cell} ipython3
import jax
import jax.numpy as jnp
from scipy.optimize import root
import matplotlib.pyplot as plt
```

Let's check the GPU we are running
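
One way to do this (a sketch; the lecture's own check may use a different command) is to list the devices JAX can see:

```{code-cell} ipython3
# List the accelerator devices available to JAX
jax.devices()
```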

As a warm-up, let's implement Newton's method in JAX for a simple
one-dimensional root-finding problem.

Let $f$ be a function from $\mathbb R$ to itself.

A **root** of $f$ is an $x \in \mathbb R$ such that $f(x)=0$.

[Recall](https://python.quantecon.org/newton_method.html) that Newton's method for solving for the root of $f$ involves iterating with the map $q$ defined by

$$
q(x) = x - \frac{f(x)}{f'(x)}
$$

Here is a function called `newton` that takes a function $f$ plus a scalar value $x_0$,
iterates with $q$ starting from $x_0$, and returns an approximate fixed point.

```{code-cell} ipython3
def newton(f, x_0, tol=1e-5):
    # Obtain the derivative f' via automatic differentiation
    f_prime = jax.grad(f)
    def q(x):
        # The Newton map
        return x - f(x) / f_prime(x)

    error = tol + 1
    x = x_0
    while error > tol:
        y = q(x)
        error = abs(x - y)
        x = y
    return x
```

Let's test our `newton` routine on the function shown below.

```{code-cell} ipython3
f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1
x = jnp.linspace(0, 1, 100)

fig, ax = plt.subplots()
ax.plot(x, f(x), label='$f(x)$')
ax.axhline(ls='--', c='k')
ax.legend()
plt.show()
```

Here we go

```{code-cell} ipython3
newton(f, 0.2)
```

This number looks to be close to the root, given the figure.
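
As a quick sanity check, we can evaluate $f$ at the approximate root; the result should be close to zero.

```{code-cell} ipython3
# f evaluated at the approximate root should be near zero
f(newton(f, 0.2))
```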

Now let's move up to higher dimensions.

First we describe a market equilibrium problem we will solve with JAX via root-finding.

The market is for $n$ goods.

(We are extending a two-good version of the market from [an earlier lecture](https://python.quantecon.org/newton_method.html).)

Then we define the multivariate version of the formula for the [law of motion of capital](https://python.quantecon.org/newton_method.html#solow)

```{code-cell} ipython3
def multivariate_solow(k, A=A, s=s, α=α, δ=δ):
    # Law of motion for capital: g(k) = s A k^α + (1 - δ) k,
    # applied elementwise across sectors
    return A * s * k**α + (1 - δ) * k
```
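
The multivariate Newton routine used in the remainder of this section is not visible in this excerpt, but it follows the same logic as the scalar `newton` above, with the derivative replaced by the Jacobian and division replaced by a linear solve. Here is a minimal sketch (`newton_nd` is an illustrative name, not the lecture's own implementation), using `jax.jacobian`:

```{code-cell} ipython3
def newton_nd(f, x_0, tol=1e-5, max_iter=50):
    # Jacobian of f via automatic differentiation
    jac = jax.jacobian(f)
    x = x_0
    for _ in range(max_iter):
        # Newton step: x <- x - J(x)^{-1} f(x)
        x_new = x - jnp.linalg.solve(jac(x), f(x))
        if jnp.max(jnp.abs(x_new - x)) < tol:
            break
        x = x_new
    return x_new
```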

We find that the results are invariant to the starting values.

But the number of iterations it takes to converge depends on the starting values.

Let's substitute the output back into the formula to check our last result

```{code-cell} ipython3
multivariate_solow(k) - k
```

Note the error is very small.

We can also test our results against the known solution

```{code-cell} ipython3
init = jnp.repeat(1.0, 3)
k = newton(lambda k: multivariate_solow(k) - k,
           init).block_until_ready()
```

The result is very close to the true solution but still slightly different.

We can increase the precision of the floating-point numbers and tighten the tolerance to obtain a more accurate approximation (see the detailed discussion in the [lecture on JAX](https://python-programming.quantecon.org/jax_intro.html#differences)).
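
For example, 64-bit precision can be switched on with a configuration flag (a sketch; this should be set before arrays are created, since by default JAX uses 32-bit floats):

```{code-cell} ipython3
# Switch JAX to double precision
jax.config.update("jax_enable_x64", True)
jnp.ones(3).dtype   # now float64
```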