73 changes: 2 additions & 71 deletions lectures/jax_intro.md
@@ -585,85 +585,16 @@ And we produce the correct answer:
jnp.allclose(z_vmap, z_mesh)
```

## Exercises



```{exercise-start}
:label: jax_intro_ex1
```

Recall that Newton's method for solving for the root of $f$ involves iterating on


$$
q(x) = x - \frac{f(x)}{f'(x)}
$$

Write a function called `newton` that takes a function $f$ plus a guess $x_0$ and returns an approximate fixed point.

Your `newton` implementation should use automatic differentiation to calculate $f'$.

Test your `newton` method on the function shown below.

```{code-cell} ipython3
f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1
x = jnp.linspace(0, 1, 100)

fig, ax = plt.subplots()
ax.plot(x, f(x), label='$f(x)$')
ax.axhline(ls='--', c='k')
ax.set_xlabel('$x$', fontsize=12)
ax.set_ylabel('$f(x)$', fontsize=12)
ax.legend(fontsize=12)
plt.show()
```

```{exercise-end}
```

```{solution-start} jax_intro_ex1
:class: dropdown
```

Here's a suitable function:

```{code-cell} ipython3
def newton(f, x_0, tol=1e-5):
f_prime = jax.grad(f)
def q(x):
return x - f(x) / f_prime(x)

error = tol + 1
x = x_0
while error > tol:
y = q(x)
error = abs(x - y)
x = y

return x
```

Let's try it:

```{code-cell} ipython3
newton(f, 0.2)
```

This number looks good, given the figure.


```{solution-end}
```


```{exercise-start}
:label: jax_intro_ex2
```

In the Exercise section of [a lecture on Numba and parallelization](https://python-programming.quantecon.org/parallelization.html), we used Monte Carlo to price a European call option.

The code was accelerated by Numba-based multithreading.
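For orientation, the Monte Carlo pricing idea can be sketched in JAX as follows; the parameter values and variable names below are illustrative assumptions, not the lecture's:

```python
import jax
import jax.numpy as jnp

# Illustrative parameters (assumptions): spot, strike, maturity, rate, volatility
S0, K, T, r, sigma = 100.0, 100.0, 1.0, 0.05, 0.2
n = 100_000  # number of Monte Carlo draws

key = jax.random.PRNGKey(42)
Z = jax.random.normal(key, (n,))

# Terminal price under risk-neutral geometric Brownian motion
S_T = S0 * jnp.exp((r - 0.5 * sigma**2) * T + sigma * jnp.sqrt(T) * Z)

# Discounted expected payoff of the call option
price = jnp.exp(-r * T) * jnp.mean(jnp.maximum(S_T - K, 0.0))
```

With these values the estimate lands near the Black–Scholes price of roughly 10.45.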

68 changes: 65 additions & 3 deletions lectures/newtons_method.md
@@ -39,12 +39,74 @@ Let's check the GPU we are running

```{code-cell} ipython3
!nvidia-smi
```


## Newton in one dimension

As a warm-up, let's implement Newton's method in JAX for a simple
one-dimensional root-finding problem.

[Recall](https://python.quantecon.org/newton_method.html) that Newton's method for solving for the root of $f$ involves iterating with the map $q$ defined by

$$
q(x) = x - \frac{f(x)}{f'(x)}
$$


Here is a function called `newton` that takes a function $f$ plus a guess $x_0$, iterates with $q$ starting from $x_0$, and returns an approximate fixed point.


```{code-cell} ipython3
def newton(f, x_0, tol=1e-5):
f_prime = jax.grad(f)
def q(x):
return x - f(x) / f_prime(x)

error = tol + 1
x = x_0
while error > tol:
y = q(x)
error = abs(x - y)
x = y

return x
```

The code above uses automatic differentiation to calculate $f'$ via the call to `jax.grad`.
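To see what `jax.grad` does in isolation, here is a minimal, self-contained example (the function `g` is an illustrative assumption, not part of the lecture's code):

```python
import jax

# jax.grad returns a new function computing the derivative g'(x) = 3 x**2
g = lambda x: x**3
g_prime = jax.grad(g)

print(g_prime(2.0))  # approximately 12.0
```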

Let's test our `newton` routine on the function shown below.

```{code-cell} ipython3
f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1
x = jnp.linspace(0, 1, 100)

import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(x, f(x), label='$f(x)$')
ax.axhline(ls='--', c='k')
ax.set_xlabel('$x$', fontsize=12)
ax.set_ylabel('$f(x)$', fontsize=12)
ax.legend(fontsize=12)
plt.show()
```

Let's try it:

```{code-cell} ipython3
newton(f, 0.2)
```

This number looks good, given the figure.



## An Equilibrium Problem

Now let's move up to higher dimensions.
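Before turning to the equilibrium problem, here is a sketch of how the one-dimensional routine extends to vector-valued $f$: `jax.grad` is replaced by `jax.jacobian`, and division by $f'$ becomes a linear solve. The function `multivariate_newton` and the test system below are illustrative assumptions, not the lecture's code.

```python
import jax
import jax.numpy as jnp

def multivariate_newton(f, x_0, tol=1e-5, max_iter=50):
    """Newton's method for vector-valued f: x <- x - J(x)^{-1} f(x)."""
    jac = jax.jacobian(f)   # J(x), computed by automatic differentiation
    x = x_0
    for _ in range(max_iter):
        # Solve J(x) step = f(x) rather than forming the inverse explicitly
        step = jnp.linalg.solve(jac(x), f(x))
        x_new = x - step
        if jnp.max(jnp.abs(x_new - x)) < tol:
            return x_new
        x = x_new
    return x

# Illustrative test system (an assumption, not the equilibrium problem below):
# f(x) = (x_0^2 - 2, x_0 x_1 - 1), with root (sqrt(2), 1/sqrt(2))
f_test = lambda x: jnp.array([x[0]**2 - 2.0, x[0] * x[1] - 1.0])
root = multivariate_newton(f_test, jnp.array([1.0, 1.0]))
```

Solving the linear system at each step avoids explicitly inverting the Jacobian, which is both cheaper and more numerically stable.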

First we describe a market equilibrium problem we will solve with JAX via root-finding.

We begin with a two-good case,
which is borrowed from [an earlier lecture](https://python.quantecon.org/newton_method.html).