In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

## Evaluating a simple function

Our aim is to evaluate the cosine function at many points:

In [None]:
fig, ax = plt.subplots()
x = np.linspace(0, 10, 20)
ax.plot(x, np.cos(x))
ax.scatter(x, np.cos(x))

Now let's try with a large array.

### With NumPy

In [None]:
n = 100_000_000
x = np.linspace(0, 10, n)

In [None]:
%time np.cos(x)

### With JAX

In [None]:
x = jnp.linspace(0, 10, n)

In [None]:
%time jnp.cos(x).block_until_ready()

In [None]:
%time jnp.cos(x).block_until_ready()

### Changing size triggers recompilation

In [None]:
x = jnp.linspace(0, 10, n + 1)

In [None]:
%time jnp.cos(x).block_until_ready()

In [None]:
%time jnp.cos(x).block_until_ready()
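
The recompilation can also be observed directly rather than inferred from timings. The sketch below relies on the fact that Python side effects inside a `jax.jit`-decorated function run only while JAX traces the function, which happens once for each new input shape. The names `traced_cos` and `trace_log` are illustrative and not part of the lecture.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time JAX traces the function

@jax.jit
def traced_cos(x):
    trace_log.append(x.shape)  # Python side effect: executes at trace time only
    return jnp.cos(x)

traced_cos(jnp.ones(3))  # first call with shape (3,): traced and compiled
traced_cos(jnp.ones(3))  # same shape: the cached compiled code is reused
traced_cos(jnp.ones(4))  # new shape (4,): traced and compiled again

print(trace_log)
```

Only two shapes are logged for three calls, since the second call reuses the code compiled for the first.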

## Evaluating a more complicated function

In [None]:
def f(x):
    y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - 0.1 * x**2
    return y

In [None]:
fig, ax = plt.subplots()
x = np.linspace(0, 10, 50)
ax.plot(x, f(x))
ax.scatter(x, f(x))

Now let's try with a large array.

### With NumPy

In [None]:
x = np.linspace(0, 10, n)

In [None]:
%time f(x)

In [None]:
%time f(x)

### With JAX

In [None]:
def f(x):
    y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - 0.1 * x**2
    return y

In [None]:
x = jnp.linspace(0, 10, n)

In [None]:
%time f(x).block_until_ready()

In [None]:
%time f(x).block_until_ready()

### Compiling the Whole Function

In [None]:
f = jax.jit(f)

In [None]:
%time f(x).block_until_ready()

In [None]:
%time f(x).block_until_ready()

## Solving Linear Systems

In [None]:
np.random.seed(1234)
n = 10_000
A = np.random.randn(n, n)
b = np.ones(n)

In [None]:
%time np.linalg.solve(A, b)

In [None]:
A, b = [jax.device_put(v) for v in (A, b)]

In [None]:
%time jnp.linalg.solve(A, b).block_until_ready()

In [None]:
%time jnp.linalg.solve(A, b).block_until_ready()

## Nonlinear equations: Newton’s method 

Let’s suppose we want to find an $ x $ such that $ f(x)=0 $ for some smooth
one-dimensional function $ f $.

Suppose we have a guess $ x_0 $ and we want to update it to a new point $ x_1 $.

As a first step, we take the first-order approximation of $ f $ around $ x_0 $:

$$
\hat f(x) = f\left(x_0\right)+f^{\prime}\left(x_0\right)\left(x-x_0\right)
$$

Now we solve for the zero of $ \hat f $.

In particular, we set $ \hat{f}(x_1) = 0 $ and solve for $ x_1 $ to get

$$
x_1 = x_0 - \frac{ f(x_0) }{ f'(x_0) },
\quad x_0 \text{ given}
$$
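
To spell out the algebra behind this step (assuming $ f'(x_0) \neq 0 $, a condition left implicit here), setting the linear approximation to zero at $ x_1 $ gives

$$
0 = f(x_0) + f'(x_0)(x_1 - x_0)
\quad \Longrightarrow \quad
x_1 - x_0 = - \frac{ f(x_0) }{ f'(x_0) }
$$

and adding $ x_0 $ to both sides yields the update for $ x_1 $.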

Generalizing the formula above, for one-dimensional zero-finding problems, Newton’s method iterates on


<a id='equation-oned-newton'></a>
$$
x_{t+1} = x_t - \frac{ f(x_t) }{ f'(x_t) },
\quad x_0 \text{ given} \tag{9.5}
$$

The following code implements the iteration [(9.5)](#equation-oned-newton)

<a id='first-newton-attempt'></a>

In [None]:
def newton(f, Df, x_0, tol=1e-7, max_iter=100_000):
    x = x_0

    # Implement the zero-finding formula
    def q(x):
        return x - f(x) / Df(x)

    error = tol + 1
    n = 0
    while error > tol:
        n += 1
        if(n > max_iter):
            raise Exception('Max iteration reached without convergence')
        y = q(x)
        error = np.abs(x - y)
        x = y
        print(f'iteration {n}, error = {error:.5f}')
    return x
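
As a quick check of the iteration, the sketch below applies the same update rule to $ f(x) = x^2 - 2 $, whose positive root is $ \sqrt{2} $. The helper `newton_1d` is a compact, non-printing variant written for this illustration only; the test function and starting point are assumptions.

```python
import math

def newton_1d(f, Df, x_0, tol=1e-7, max_iter=100):
    """Iterate x <- x - f(x)/Df(x) until successive guesses differ by at most tol."""
    x = x_0
    for _ in range(max_iter):
        y = x - f(x) / Df(x)
        if abs(y - x) <= tol:
            return y
        x = y
    raise RuntimeError("Max iteration reached without convergence")

# f(x) = x^2 - 2 has derivative 2x; start from x_0 = 1
root_estimate = newton_1d(lambda x: x**2 - 2, lambda x: 2 * x, x_0=1.0)
print(root_estimate, math.sqrt(2))
```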

### A Two Goods Market Equilibrium

Let’s start by computing the market equilibrium of a two-good problem.

We consider a market for two related products, good 0 and good 1, with
price vector $ p = (p_0, p_1) $

The supply of good $ i $ at price $ p $ is

$$
q^s_i (p) = b_i \sqrt{p_i}
$$

The demand for good $ i $ at price $ p $ is

$$
q^d_i (p) = \exp(-(a_{i0} p_0 + a_{i1} p_1)) + c_i
$$

Here $ c_i $, $ b_i $ and $ a_{ij} $ are parameters.

For example, the two goods might be computer components that are typically used together, in which case they are complements. Hence demand depends on the price of both components.

The excess demand function is

$$
e_i(p) = q^d_i(p) - q^s_i(p), \quad i = 0, 1
$$

An equilibrium price vector $ p^* $ satisfies $ e_i(p^*) = 0 $.

We set

$$
A = \begin{pmatrix}
            a_{00} & a_{01} \\
            a_{10} & a_{11}
        \end{pmatrix},
            \qquad 
    b = \begin{pmatrix}
            b_0 \\
            b_1
        \end{pmatrix}
    \qquad \text{and} \qquad
    c = \begin{pmatrix}
            c_0 \\
            c_1
        \end{pmatrix}
$$

for this particular question.

#### A Graphical Exploration

Since our problem is only two-dimensional, we can use graphical analysis to visualize and help understand the problem.

Our first step is to define the excess demand function

$$
e(p) = 
    \begin{pmatrix}
    e_0(p) \\
    e_1(p)
    \end{pmatrix}
$$

The function below calculates the excess demand for given parameters

In [None]:
def e(p, A, b, c):
    return np.exp(- A @ p) + c - b * np.sqrt(p)

Our default parameter values will be

$$
A = \begin{pmatrix}
            0.5 & 0.4 \\
            0.8 & 0.2
        \end{pmatrix},
            \qquad 
    b = \begin{pmatrix}
            1 \\
            1
        \end{pmatrix}
    \qquad \text{and} \qquad
    c = \begin{pmatrix}
            1 \\
            1
        \end{pmatrix}
$$

In [None]:
A = np.array([
    [0.5, 0.4],
    [0.8, 0.2]
])
b = np.ones(2)
c = np.ones(2)

At a price level of $ p = (1, 0.5) $, the excess demand is

In [None]:
ex_demand = e((1.0, 0.5), A, b, c)

print(f'The excess demand for good 0 is {ex_demand[0]:.3f} \n'
      f'The excess demand for good 1 is {ex_demand[1]:.3f}')

Next we plot the two functions $ e_0 $ and $ e_1 $ on a grid of $ (p_0, p_1) $ values, using contour surfaces and lines.

We will use the following function to build the contour plots

In [None]:
def plot_excess_demand(ax, good=0, grid_size=100, grid_max=4, surface=True):

    # Create a grid_size x grid_size grid
    p_grid = np.linspace(0, grid_max, grid_size)
    z = np.empty((grid_size, grid_size))

    for i, p_1 in enumerate(p_grid):
        for j, p_2 in enumerate(p_grid):
            z[i, j] = e((p_1, p_2), A, b, c)[good]

    if surface:
        cs1 = ax.contourf(p_grid, p_grid, z.T, alpha=0.5)
        plt.colorbar(cs1, ax=ax, format="%.6f")

    ctr1 = ax.contour(p_grid, p_grid, z.T, levels=[0.0])
    ax.set_xlabel("$p_0$")
    ax.set_ylabel("$p_1$")
    ax.set_title(f'Excess Demand for Good {good}')
    plt.clabel(ctr1, inline=1, fontsize=13)
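
The nested loop that fills `z` can also be written in vectorized form. The following is a minimal sketch, assuming the default parameter values given above; `excess_demand_grid` is a hypothetical helper, not part of the lecture, and it is compared against a loop-based evaluation on a small grid.

```python
import numpy as np

A = np.array([[0.5, 0.4],
              [0.8, 0.2]])
b = np.ones(2)
c = np.ones(2)

def excess_demand_grid(good, grid_size=100, grid_max=4):
    """Evaluate e_good on a grid_size x grid_size price grid without Python loops."""
    p_grid = np.linspace(0, grid_max, grid_size)
    # P[0][i, j] = p_grid[i] (price of good 0), P[1][i, j] = p_grid[j] (price of good 1)
    P = np.stack(np.meshgrid(p_grid, p_grid, indexing="ij"))
    # Contract row `good` of A with the price axis of P
    demand = np.exp(-np.tensordot(A[good], P, axes=1)) + c[good]
    supply = b[good] * np.sqrt(P[good])
    return p_grid, demand - supply

# Compare with a loop-based evaluation on a small grid
p_grid, z_vec = excess_demand_grid(good=1, grid_size=5)
z_loop = np.array([[np.exp(-A[1] @ np.array([p0, p1])) + c[1] - b[1] * np.sqrt(p1)
                    for p1 in p_grid] for p0 in p_grid])
print(np.allclose(z_vec, z_loop))
```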

Here’s our plot of $ e_0 $:

In [None]:
fig, ax = plt.subplots()
plot_excess_demand(ax, good=0)
plt.show()

Here’s our plot of $ e_1 $:

In [None]:
fig, ax = plt.subplots()
plot_excess_demand(ax, good=1)
plt.show()

We see the black contour line of zero, which tells us when $ e_i(p)=0 $.

For a price vector $ p $ such that $ e_i(p)=0 $ we know that good $ i $ is in equilibrium (demand equals supply).

If these two contour lines cross at some price vector $ p^* $, then $ p^* $ is an equilibrium price vector.

In [None]:
fig, ax = plt.subplots(figsize=(10, 5.7))
for good in (0, 1):
    plot_excess_demand(ax, good=good, surface=False)
plt.show()

It seems there is an equilibrium close to $ p = (1.6, 1.5) $.

#### Using a Multidimensional Root Finder

To solve for $ p^* $ more precisely, we use a zero-finding algorithm from `scipy.optimize`.

We supply $ p = (1, 1) $ as our initial guess.

In [None]:
from scipy.optimize import root

init_p = np.ones(2)

This uses the [modified Powell method](https://docs.scipy.org/doc/scipy/reference/optimize.root-hybr.html#optimize-root-hybr) to find the zero

In [None]:
%%time
solution = root(lambda p: e(p, A, b, c), init_p, method='hybr')

Here’s the resulting value:

In [None]:
p = solution.x
p

This looks close to our guess from observing the figure. We can plug it back into $ e $ to test that $ e(p) \approx 0 $:

In [None]:
np.max(np.abs(e(p, A, b, c)))

This is indeed a very small error.

#### Adding Gradient Information

In many cases, for zero-finding algorithms applied to smooth functions, supplying the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of the function leads to better convergence properties.

Here we manually calculate the elements of the Jacobian

$$
J(p) = 
    \begin{pmatrix}
        \frac{\partial e_0}{\partial p_0}(p) & \frac{\partial e_0}{\partial p_1}(p) \\
        \frac{\partial e_1}{\partial p_0}(p) & \frac{\partial e_1}{\partial p_1}(p)
    \end{pmatrix}
$$

In [None]:
def jacobian_e(p, A, b, c):
    p_0, p_1 = p
    a_00, a_01 = A[0, :]
    a_10, a_11 = A[1, :]
    # Each demand term depends on both prices through exp(-(a_i0 p_0 + a_i1 p_1))
    d_0 = np.exp(-(a_00 * p_0 + a_01 * p_1))
    d_1 = np.exp(-(a_10 * p_0 + a_11 * p_1))
    j_00 = -a_00 * d_0 - (b[0]/2) * p_0**(-1/2)
    j_01 = -a_01 * d_0
    j_10 = -a_10 * d_1
    j_11 = -a_11 * d_1 - (b[1]/2) * p_1**(-1/2)
    J = [[j_00, j_01],
         [j_10, j_11]]
    return np.array(J)
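
A hand-coded Jacobian is easy to get wrong, so it can be worth comparing it against a numerical approximation before passing it to a solver. The sketch below does this with central finite differences for the two-good excess demand function. The names `jacobian_analytic` and `jacobian_fd`, the step size `h=1e-6`, and the test point are all assumptions made for illustration.

```python
import numpy as np

A = np.array([[0.5, 0.4],
              [0.8, 0.2]])
b = np.ones(2)
c = np.ones(2)

def e(p):
    return np.exp(-A @ p) + c - b * np.sqrt(p)

def jacobian_analytic(p):
    # d/dp_j exp(-(A p)_i) = -a_ij * exp(-(A p)_i); supply affects only the diagonal
    demand_exp = np.exp(-A @ p)
    return -A * demand_exp[:, None] - np.diag(b / (2 * np.sqrt(p)))

def jacobian_fd(func, p, h=1e-6):
    """Approximate the Jacobian column by column with central differences."""
    p = np.asarray(p, dtype=float)
    J = np.empty((p.size, p.size))
    for j in range(p.size):
        step = np.zeros_like(p)
        step[j] = h
        J[:, j] = (func(p + step) - func(p - step)) / (2 * h)
    return J

p_test = np.array([1.0, 0.5])
max_gap = np.max(np.abs(jacobian_analytic(p_test) - jacobian_fd(e, p_test)))
print(max_gap)
```

A gap far above the finite-difference error would point to a mistake in the analytic derivatives.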

In [None]:
%%time
solution = root(lambda p: e(p, A, b, c),
                init_p, 
                jac=lambda p: jacobian_e(p, A, b, c), 
                method='hybr')

Now the solution is even more accurate (although, in this low-dimensional problem, the difference is quite small):

In [None]:
p = solution.x
np.max(np.abs(e(p, A, b, c)))

#### Using Newton’s Method

Now let’s compute the equilibrium price using the multivariate version of Newton’s method


<a id='equation-multi-newton'></a>
$$
p_{n+1} = p_n - J_e(p_n)^{-1} e(p_n) \tag{9.6}
$$

This is a multivariate version of [(9.5)](#equation-oned-newton)

(Here $ J_e(p_n) $ is the Jacobian of $ e $ evaluated at $ p_n $.)

The iteration starts from some initial guess of the price vector $ p_0 $.
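
In practice the update is computed by solving the linear system $ J_e(p_n) d = e(p_n) $ for the step $ d $ rather than by forming the inverse. The sketch below illustrates this on the two-good market with a hand-written Jacobian; the helper `jac_e`, the iteration cap, and the stopping tolerance are assumptions for this illustration.

```python
import numpy as np

A = np.array([[0.5, 0.4],
              [0.8, 0.2]])
b = np.ones(2)
c = np.ones(2)

def e(p):
    return np.exp(-A @ p) + c - b * np.sqrt(p)

def jac_e(p):
    # Row i holds the partial derivatives of e_i with respect to p_0 and p_1
    return -A * np.exp(-A @ p)[:, None] - np.diag(b / (2 * np.sqrt(p)))

p = np.ones(2)  # initial guess
for _ in range(20):
    # Solve J(p) d = e(p) for the step d instead of inverting J(p)
    step = np.linalg.solve(jac_e(p), e(p))
    p = p - step
    if np.linalg.norm(step) < 1e-10:
        break

residual = np.max(np.abs(e(p)))
print(p, residual)
```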

Here, instead of coding the Jacobian by hand, we use the `jacobian()` function in the `autograd` library to auto-differentiate and calculate the Jacobian.

With only slight modification, we can generalize [our previous attempt](#first-newton-attempt) to multi-dimensional problems

In [None]:
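import autograd.numpy as np      # autograd's wrapper around NumPy, needed for differentiation
from autograd import jacobian    # builds a function that returns the Jacobian

def newton(f, x_0, tol=1e-5, max_iter=10):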
    x = x_0
    q = lambda x: x - np.linalg.solve(jacobian(f)(x), f(x))
    error = tol + 1
    n = 0
    while error > tol:
        n+=1
        if(n > max_iter):
            raise Exception('Max iteration reached without convergence')
        y = q(x)
        if(any(np.isnan(y))):
            raise Exception('Solution not found with NaN generated')
        error = np.linalg.norm(x - y)
        x = y
        print(f'iteration {n}, error = {error:.5f}')
    print('\n' + f'Result = {x} \n')
    return x

In [None]:
def e(p, A, b, c):
    return np.exp(- np.dot(A, p)) + c - b * np.sqrt(p)

We find the algorithm terminates in 4 steps

In [None]:
%%time
p = newton(lambda p: e(p, A, b, c), init_p)

In [None]:
np.max(np.abs(e(p, A, b, c)))

The result is very accurate.

With the larger overhead, the speed is not better than the optimized `scipy` function.

### A High-Dimensional Problem

Our next step is to investigate a large market with 3,000 goods.

A JAX version of this section using GPU accelerated linear algebra and
automatic differentiation is available [here](https://jax.quantecon.org/newtons_method.html#application)

The excess demand function is essentially the same, but now the matrix $ A $ is $ 3000 \times 3000 $ and the parameter vectors $ b $ and $ c $ are $ 3000 \times 1 $.

In [None]:
dim = 3000
np.random.seed(123)

# Create a random matrix A and normalize the columns to sum to one
A = np.random.rand(dim, dim)
A = np.asarray(A)
s = np.sum(A, axis=0)
A = A / s

# Set up b and c
b = np.ones(dim)
c = np.ones(dim)
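
With broadcasting it is easy to normalize along the wrong axis. As a small check (the $ 3 \times 3 $ matrix `M` is an illustrative stand-in for `A`), dividing by `np.sum(M, axis=0)` makes each column sum to one, while dividing by `np.sum(M, axis=1, keepdims=True)` makes each row sum to one.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.random((3, 3))

# Column sums have shape (3,) and broadcast across rows: each column is rescaled
col_normalized = M / np.sum(M, axis=0)

# Row sums kept as shape (3, 1) broadcast across columns: each row is rescaled
row_normalized = M / np.sum(M, axis=1, keepdims=True)

print(col_normalized.sum(axis=0))
print(row_normalized.sum(axis=1))
```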

Here’s our initial condition

In [None]:
init_p = np.ones(dim)

In [None]:
%%time
p = newton(lambda p: e(p, A, b, c), init_p)

In [None]:
np.max(np.abs(e(p, A, b, c)))

With the same tolerance, we compare the runtime and accuracy of Newton’s method to SciPy’s `root` function

In [None]:
%%time
solution = root(lambda p: e(p, A, b, c),
                init_p, 
                jac=lambda p: jacobian(e)(p, A, b, c), 
                method='hybr',
                tol=1e-5)

In [None]:
p = solution.x
np.max(np.abs(e(p, A, b, c)))

Here is a function called `newton` that takes a function $ f $ plus a scalar value $ x_0 $,
iterates with the Newton update $ q(x) = x - f(x) / f'(x) $ starting from $ x_0 $, and returns an
approximate fixed point of $ q $, which is an approximate root of $ f $.

In [None]:
def newton(f, x_0, tol=1e-5):
    f_prime = jax.grad(f)
    def q(x):
        return x - f(x) / f_prime(x)

    error = tol + 1
    x = x_0
    while error > tol:
        y = q(x)
        error = abs(x - y)
        x = y
        
    return x

The code above uses automatic differentiation to calculate $ f' $ via the call to `jax.grad`.
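
To see what `jax.grad` returns, the sketch below compares it with a derivative worked out by hand at a single point. The test function `g`, the helper `g_prime_by_hand`, and the evaluation point are chosen here purely for illustration.

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(4 * (x - 1/4)) + x + x**20 - 1

g_prime = jax.grad(g)  # a new function that evaluates g'(x)

def g_prime_by_hand(x):
    # Derivative of g computed term by term
    return 4 * jnp.cos(4 * (x - 1/4)) + 1 + 20 * x**19

x0 = 0.2
gap = abs(float(g_prime(x0)) - float(g_prime_by_hand(x0)))
print(gap)
```

The two values should agree up to floating point error, which is larger than in NumPy when JAX runs in its default 32-bit precision.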

Let’s test our `newton` routine on the function shown below.

In [None]:
f = lambda x: jnp.sin(4 * (x - 1/4)) + x + x**20 - 1
x = jnp.linspace(0, 1, 100)

fig, ax = plt.subplots()
ax.plot(x, f(x), label='$f(x)$')
ax.axhline(ls='--', c='k')
ax.set_xlabel('$x$', fontsize=12)
ax.set_ylabel('$f(x)$', fontsize=12)
ax.legend(fontsize=12)
plt.show()

Here we go

In [None]:
newton(f, 0.2)

This number looks to be close to the root, given the figure.

## An Equilibrium Problem

Now let’s move up to higher dimensions.

First we describe a market equilibrium problem we will solve with JAX via root-finding.

The market is for $ n $ goods.

(We are extending a two-good version of the market from [an earlier lecture](https://python.quantecon.org/newton_method.html).)

The supply function for the $ i $-th good is

$$
q^s_i (p) = b_i \sqrt{p_i}
$$

which we write in vector form as

$$
q^s (p) =b \sqrt{p}
$$

(Here $ \sqrt{p} $ is the square root of each $ p_i $ and $ b \sqrt{p} $ is the vector
formed by taking the pointwise product $ b_i \sqrt{p_i} $ at each $ i $.)

The demand function is

$$
q^d (p) = \exp(- A p) + c
$$

(Here $ A $ is an $ n \times n $ matrix containing parameters, $ c $ is an $ n \times
1 $ vector and the $ \exp $ function acts pointwise (element-by-element) on the
vector $ - A p $.)

The excess demand function is

$$
e(p) = \exp(- A p) + c - b \sqrt{p}
$$

An **equilibrium price** vector is an $ n $-vector $ p $ such that $ e(p) = 0 $.

The function below calculates the excess demand for given parameters

In [None]:
def e(p, A, b, c):
    return jnp.exp(- A @ p) + c - b * jnp.sqrt(p)
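
As a check on the vector notation, the sketch below evaluates the excess demand both in vector form and good by good, and confirms that the two agree. It uses NumPy and small random parameters chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_goods = 4
A = rng.random((n_goods, n_goods))
b = np.ones(n_goods)
c = np.ones(n_goods)
p = rng.random(n_goods) + 0.5  # strictly positive prices

# Vector form: exp and sqrt act element by element, b * sqrt(p) is a pointwise product
e_vec = np.exp(-A @ p) + c - b * np.sqrt(p)

# Good-by-good form: e_i(p) = exp(-sum_j a_ij p_j) + c_i - b_i sqrt(p_i)
e_loop = np.array([
    np.exp(-sum(A[i, j] * p[j] for j in range(n_goods))) + c[i] - b[i] * np.sqrt(p[i])
    for i in range(n_goods)
])

print(np.allclose(e_vec, e_loop))
```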

## Computation

In this section we describe and then implement the solution method.

### Newton’s Method

We use a multivariate version of Newton’s method to compute the equilibrium price.

The rule for updating a guess $ p_n $ of the equilibrium price vector is


<a id='equation-multi-newton'></a>
$$
p_{n+1} = p_n - J_e(p_n)^{-1} e(p_n) \tag{3.1}
$$

Here $ J_e(p_n) $ is the Jacobian of $ e $ evaluated at $ p_n $.

Iteration starts from an initial guess $ p_0 $.

Instead of coding the Jacobian by hand, we use automatic differentiation via `jax.jacobian()`.

In [None]:
def newton(f, x_0, tol=1e-5, max_iter=15):
    """
    A multivariate Newton root-finding routine.

    """
    x = x_0
    f_jac = jax.jacobian(f)
    @jax.jit
    def q(x):
        " Updates the current guess. "
        return x - jnp.linalg.solve(f_jac(x), f(x))
    error = tol + 1
    n = 0
    while error > tol:
        n += 1
        if(n > max_iter):
            raise Exception('Max iteration reached without convergence')
        y = q(x)
        error = jnp.linalg.norm(x - y)
        x = y
        print(f'iteration {n}, error = {error}')
    return x

### Application

Let’s now apply the method just described to investigate a large market with 5,000 goods.

We randomly generate the matrix $ A $ and set the parameter vectors $ b, c $ to $ 1 $.

In [None]:
dim = 5_000
seed = 32

# Create a random matrix A and normalize the columns to sum to one
key = jax.random.PRNGKey(seed)
A = jax.random.uniform(key, [dim, dim])
s = jnp.sum(A, axis=0)
A = A / s

# Set up b and c
b = jnp.ones(dim)
c = jnp.ones(dim)

Here’s our initial condition $ p_0 $

In [None]:
init_p = jnp.ones(dim)

By combining the power of Newton’s method, JAX accelerated linear algebra,
automatic differentiation, and a GPU, we obtain a relatively small error for
this high-dimensional problem in just a few seconds:

In [None]:
%%time

p = newton(lambda p: e(p, A, b, c), init_p).block_until_ready()

Here’s the size of the error:

In [None]:
jnp.max(jnp.abs(e(p, A, b, c)))

With the same tolerance, SciPy’s `root` function takes much longer to run,
even with the Jacobian supplied.

In [None]:
%%time

solution = root(lambda p: e(p, A, b, c),
                init_p,
                jac=lambda p: jax.jacobian(e)(p, A, b, c),
                method='hybr',
                tol=1e-5)

The result is also slightly less accurate:

In [None]:
p = solution.x
jnp.max(jnp.abs(e(p, A, b, c)))