diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 7b83ac96..9b695bc8 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -585,85 +585,16 @@ And we produce the correct answer: jnp.allclose(z_vmap, z_mesh) ``` -## Exercises - - - -```{exercise-start} -:label: jax_intro_ex1 -``` - -Recall that Newton's method for solving for the root of $f$ involves iterating on - - -$$ - q(x) = x - \frac{f(x)}{f'(x)} -$$ - -Write a function called `newton` that takes a function $f$ plus a guess $x_0$ and returns an approximate fixed point. - -Your `newton` implementation should use automatic differentiation to calculate $f'$. - -Test your `newton` method on the function shown below. - -```{code-cell} ipython3 -f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1 -x = jnp.linspace(0, 1, 100) - -fig, ax = plt.subplots() -ax.plot(x, f(x), label='$f(x)$') -ax.axhline(ls='--', c='k') -ax.set_xlabel('$x$', fontsize=12) -ax.set_ylabel('$f(x)$', fontsize=12) -ax.legend(fontsize=12) -plt.show() -``` - -```{exercise-end} -``` - -```{solution-start} jax_intro_ex1 -:class: dropdown -``` - -Here's a suitable function: - -```{code-cell} ipython3 -def newton(f, x_0, tol=1e-5): - f_prime = jax.grad(f) - def q(x): - return x - f(x) / f_prime(x) - - error = tol + 1 - x = x_0 - while error > tol: - y = q(x) - error = abs(x - y) - x = y - - return x -``` - -Let's try it: - -```{code-cell} ipython3 -newton(f, 0.2) -``` -This number looks good, given the figure. - - -```{solution-end} -``` +## Exercises ```{exercise-start} :label: jax_intro_ex2 ``` -In {ref}`an earlier exercise on parallelization `, we used Monte -Carlo to price a European call option. +In the Exercise section of [a lecture on Numba and parallelization](https://python-programming.quantecon.org/parallelization.html), we used Monte Carlo to price a European call option. The code was accelerated by Numba-based multithreading. diff --git a/lectures/newtons_method.md b/lectures/newtons_method.md index 10d7e49f..afe52aef 100644 --- a/lectures/newtons_method.md +++ b/lectures/newtons_method.md @@ -39,12 +39,74 @@ Let's check the GPU we are running ```{code-cell} ipython3 !nvidia-smi + +``` + + +## Newton in one dimension + +As a warm up, let's implement Newton's method in JAX for a simple +one-dimensional root-finding problem. + +[Recall](https://python.quantecon.org/newton_method.html) that Newton's method for solving for the root of $f$ involves iterating with the map $q$ defined by + +$$ + q(x) = x - \frac{f(x)}{f'(x)} +$$ + + +Here is a function called `newton` that takes a function $f$ plus a guess $x_0$, iterates with $q$ starting from $x0$, and returns an approximate fixed point. + + +```{code-cell} ipython3 +def newton(f, x_0, tol=1e-5): + f_prime = jax.grad(f) + def q(x): + return x - f(x) / f_prime(x) + + error = tol + 1 + x = x_0 + while error > tol: + y = q(x) + error = abs(x - y) + x = y + + return x +``` + +The code above uses automatic differentiation to calculate $f'$ via the call to `jax.grad`. + +Let's test our `newton` routine on the function shown below. + +```{code-cell} ipython3 +f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1 +x = jnp.linspace(0, 1, 100) + +import matplotlib.pyplot as plt +fig, ax = plt.subplots() +ax.plot(x, f(x), label='$f(x)$') +ax.axhline(ls='--', c='k') +ax.set_xlabel('$x$', fontsize=12) +ax.set_ylabel('$f(x)$', fontsize=12) +ax.legend(fontsize=12) +plt.show() ``` -## The Equilibrium Problem +Here we go + +```{code-cell} ipython3 +newton(f, 0.2) +``` + +This number looks good, given the figure. + + + +## An Equilibrium Problem + +Now let's move up to higher dimensions. -In this section we describe the market equilibrium problem we will solve with -JAX. +First we describe a market equilibrium problem we will solve with JAX via root-finding. We begin with a two good case, which is borrowed from [an earlier lecture](https://python.quantecon.org/newton_method.html).