# Using Jax to increase computation speed

[Jax](https://jax.readthedocs.io/en/latest/quickstart.html) is a library focused on high-performance computing. It offers a simple API with a NumPy-like interface and the following features:
- Computations on CPU, GPU, and TPU
- Just-In-Time compilation (JIT)
- Automatic differentiation (autograd)
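
As a quick illustration of these three features, here is a minimal sketch. The function `sum_of_squares` and the input `x` are hypothetical names chosen for this example only; they are not part of the McCall model below.

```python
import jax
import jax.numpy as jnp

# NumPy-like array creation through jax.numpy
x = jnp.linspace(0.0, 1.0, 5)

# Hypothetical example function (not part of the McCall model)
def sum_of_squares(v):
    return jnp.sum(v ** 2)

# Just-in-time compilation of the function
sum_of_squares_jit = jax.jit(sum_of_squares)

# Automatic differentiation: the gradient of sum(v**2) is 2*v
grad_sum_of_squares = jax.grad(sum_of_squares)

print(jax.devices())           # devices Jax can use on this machine
print(sum_of_squares_jit(x))   # same value as sum_of_squares(x)
print(grad_sum_of_squares(x))  # elementwise 2*x
```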

Below, we illustrate how Jax can be applied to a dynamic programming problem, specifically the McCall job search model.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import jax
import jax.numpy as jnp
import time

## 1. Baseline Solution to the McCall model

We begin with our baseline solution to the McCall model, found by iterating on the value function and using Numpy operations.

In [2]:
# Bellman operator
def Tv(V, b, beta, p, wages):
    EV = np.sum(V * p)
    v_search = b + beta * EV
    v_accept = wages / (1 - beta)
    TV = np.maximum(v_accept, v_search)
    return TV

# Define function to solve the job search problem
def McCall_VFI(theta=0.3, beta=0.99, mu=40, sigma=0.5, w_min=10, w_max=100, n=1_000_000, v_tol=1e-6, max_iter=1000):
    # default distribution wages
    wages = np.linspace(w_min, w_max, n+1)
    # vector of probabilities of each wage being drawn from log normal distribution
    p = stats.lognorm.pdf(wages, s=sigma, scale=mu)
    p = p / p.sum() # normalize for truncated distribution
    # set other model parameters
    expected_wage = np.sum(wages * p)
    b = theta * expected_wage  # theta is the replacement rate, b is the unemployment benefit
    V = np.zeros(wages.size)  # value function guess, could have better one
    phi = np.zeros(wages.size) # policy function
    v_dist = 10
    iter = 0
    while (v_dist > v_tol) & (iter < max_iter):
        TV = Tv(V, b, beta, p, wages)
        v_dist = np.abs(V-TV).max()
        V = TV
        iter += 1
    return V

In [3]:
%time McCall_VFI()

CPU times: user 546 ms, sys: 698 ms, total: 1.24 s
Wall time: 1.25 s


array([ 7643.33969033,  7643.33969033,  7643.33969033, ...,
        9999.982     ,  9999.991     , 10000.        ])

### Check whether a GPU is available

NOTE: I don't know how to do this on Windows in general. From the command line, you can use `nvidia-smi.exe` to check whether a GPU is available, but I think you need to be in the directory where your GPU drivers are installed (see [this Stack Overflow question](https://stackoverflow.com/questions/57100015/how-do-i-run-nvidia-smi-on-windows)).


In [4]:
# check if GPU is available
# !nvidia-smi  # linux
!system_profiler SPDisplaysDataType  # macOS (Apple silicon)

Graphics/Displays:

    Apple M1 Max:

      Chipset Model: Apple M1 Max
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 32
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
      Displays:
        Color LCD:
          Display Type: Built-in Liquid Retina XDR Display
          Resolution: 3024 x 1964 Retina
          Main Display: Yes
          Mirror: Off
          Online: Yes
          Automatically Adjust Brightness: Yes
          Connection Type: Internal
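
The commands above report what the operating system sees. A complementary check, which should work the same way on Linux, macOS, and Windows, is to ask Jax which devices it can actually use. A GPU listed by the operating system is not necessarily visible to Jax, since that depends on which Jax backend is installed. A minimal sketch:

```python
import jax

# Devices Jax can use; this depends on the installed Jax backend,
# not only on the hardware present in the machine.
devices = jax.devices()

# Name of the default backend, for example "cpu" or "gpu"
backend = jax.default_backend()

print(devices)
print(backend)
```

If the backend is reported as `cpu`, calls such as `jax.device_put` will place arrays on the CPU rather than on a GPU.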



## 2. Solution with Jax

Now we solve the same model, but replace our `np.` calls with `jnp.` calls. We also use the `device_put` function to move our arrays to Jax's default device, which is the GPU when Jax has GPU support installed.

In [5]:
# Define the Bellman operator with Jax functions
def Tv_jax(V, b, beta, p, wages):
    EV = jnp.sum(V * p)
    v_search = b + beta * EV
    v_accept = wages / (1 - beta)
    TV = jnp.maximum(v_accept, v_search)
    return TV


# Define function to solve the job search problem
def McCall_VFI_jax(theta=0.3, beta=0.99, mu=40, sigma=0.5, w_min=10, w_max=100, n=1_000_000, v_tol=1e-6, max_iter=1000):
    # default distribution wages
    wages = np.linspace(w_min, w_max, n+1)
    # vector of probabilities of each wage being drawn from log normal distribution
    p = stats.lognorm.pdf(wages, s=sigma, scale=mu)
    p = p / p.sum() # normalize for truncated distribution
    # set other model parameters
    expected_wage = np.sum(wages * p)
    b = theta * expected_wage  # theta is the replacement rate, b is the unemployment benefit
    V = np.zeros(wages.size)  # value function guess, could have better one
    phi = np.zeros(wages.size) # policy function

    # Move all NumPy arrays to Jax's default device (the GPU, if Jax can use one)
    wages = jax.device_put(wages)
    p = jax.device_put(p)
    V = jax.device_put(V)
    phi = jax.device_put(phi)

    v_dist = 10
    iter = 0
    while (v_dist > v_tol) & (iter < max_iter):
        TV = Tv_jax(V, b, beta, p, wages)
        v_dist = jnp.abs(V-TV).max()
        V = TV
        iter += 1
    return V

In [6]:
%time McCall_VFI_jax().block_until_ready()  # using block_until_ready to get accurate timing

CPU times: user 775 ms, sys: 449 ms, total: 1.22 s
Wall time: 653 ms


Array([ 7643.3257,  7643.3257,  7643.3257, ...,  9999.981 ,  9999.991 ,
       10000.    ], dtype=float32)

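Notice that Jax **cuts the wall time roughly in half** (from 1.25 s to 653 ms). Jax also computes in single precision (`float32`) by default, which is why the values differ slightly from the NumPy results.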
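
If double precision is needed to match the NumPy baseline more closely, Jax can be switched to 64-bit mode with the `jax_enable_x64` configuration flag. The sketch below assumes the flag is set at startup, before any Jax arrays are created. Not every accelerator backend supports 64-bit floats, and 64-bit arithmetic is often slower, so treat this as an option to test rather than a recommendation.

```python
import jax

# Enable 64-bit mode; this should run before any Jax arrays are created
jax.config.update("jax_enable_x64", True)

import numpy as np

# A float64 NumPy array now keeps its dtype when moved to the device
V = jax.device_put(np.zeros(3))
print(V.dtype)
```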

## 3. Solution with Jax GPU + JIT

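Finally, we use the `jit` function to compile the Bellman operator. Below we time a single application of the operator on a larger grid (`n=100_000_000`): the jitted version takes about 40 ms, compared with about 74 ms for plain Jax and 336 ms for NumPy.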

In [7]:
# "Jitting" the Bellman operator
Tv_jax_jit = jax.jit(Tv_jax)
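
One thing to keep in mind when timing jitted functions: the first call traces and compiles the function, so it is typically slower than later calls, which reuse the compiled code. The sketch below illustrates this with small, made-up inputs (a uniform `p` and `b = 1.0` are assumptions for the example, not the model's calibration) and checks that the jitted and non-jitted operators agree.

```python
import time
import jax
import jax.numpy as jnp

def Tv_jax(V, b, beta, p, wages):
    EV = jnp.sum(V * p)
    v_search = b + beta * EV
    v_accept = wages / (1 - beta)
    return jnp.maximum(v_accept, v_search)

Tv_jax_jit = jax.jit(Tv_jax)

# Small illustrative inputs (assumed values, not the model's calibration)
wages = jnp.linspace(10.0, 100.0, 1001)
p = jnp.ones_like(wages) / wages.size
V = jnp.zeros_like(wages)

t0 = time.perf_counter()
first = Tv_jax_jit(V, 1.0, 0.99, p, wages).block_until_ready()   # traces and compiles
t1 = time.perf_counter()
second = Tv_jax_jit(V, 1.0, 0.99, p, wages).block_until_ready()  # reuses compiled code
t2 = time.perf_counter()

print(f"first call: {t1 - t0:.4f} s, second call: {t2 - t1:.4f} s")
```

`%timeit` runs the statement many times, so the one-off compilation cost has little effect on the averages it reports.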

In [13]:
theta=0.3
beta=0.99
mu=40
sigma=0.5
w_min=10
w_max=100
n=100_000_000
wages = np.linspace(w_min, w_max, n+1)
p = stats.lognorm.pdf(wages, s=sigma, scale=mu)
p = p / p.sum() # normalize for truncated distribution
mean_wage = np.sum(wages * p)
b = theta * mean_wage
V = np.zeros(wages.size)

# For the Jax calls, move the inputs to Jax's default device
beta_j = jax.device_put(beta)
b_j = jax.device_put(b)
wages_j = jax.device_put(wages)
p_j = jax.device_put(p)
V_j = jax.device_put(V)

%timeit Tv(V, b, beta, p, wages)
%timeit Tv_jax(V_j, b_j, beta_j, p_j, wages_j).block_until_ready()
%timeit Tv_jax_jit(V_j, b_j, beta_j, p_j, wages_j).block_until_ready()

336 ms ± 9.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
74.4 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
40.5 ms ± 609 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
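
The timings above cover a single application of the Bellman operator. The solver in Section 2 still drives the iteration with a Python `while` loop, so each iteration returns control to Python. One possible next step, sketched below, is to move the whole loop into compiled code with `jax.lax.while_loop`. The names `vfi_loop` and `McCall_VFI_jit` are hypothetical and introduced only for this sketch. It has not been timed here, so no speedup is claimed.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.stats as stats


def Tv_jax(V, b, beta, p, wages):
    EV = jnp.sum(V * p)
    v_search = b + beta * EV
    v_accept = wages / (1 - beta)
    return jnp.maximum(v_accept, v_search)


@jax.jit
def vfi_loop(V0, b, beta, p, wages, v_tol, max_iter):
    # Loop state: (value function, sup-norm distance, iteration count)
    def cond(state):
        _, v_dist, it = state
        return jnp.logical_and(v_dist > v_tol, it < max_iter)

    def body(state):
        V, _, it = state
        TV = Tv_jax(V, b, beta, p, wages)
        return TV, jnp.max(jnp.abs(TV - V)), it + 1

    # Explicit dtypes keep the loop state consistent across iterations
    init = (V0, jnp.array(jnp.inf, dtype=V0.dtype), jnp.array(0, dtype=jnp.int32))
    V, _, n_iter = jax.lax.while_loop(cond, body, init)
    return V, n_iter


def McCall_VFI_jit(theta=0.3, beta=0.99, mu=40, sigma=0.5, w_min=10, w_max=100,
                   n=1_000_000, v_tol=1e-6, max_iter=1000):
    # Same setup as McCall_VFI_jax: wage grid and normalized lognormal weights
    wages = np.linspace(w_min, w_max, n + 1)
    p = stats.lognorm.pdf(wages, s=sigma, scale=mu)
    p = p / p.sum()
    b = theta * np.sum(wages * p)  # unemployment benefit
    wages_j, p_j = jnp.asarray(wages), jnp.asarray(p)
    V0 = jnp.zeros_like(wages_j)
    return vfi_loop(V0, b, beta, p_j, wages_j, v_tol, max_iter)


# Small grid so the example runs quickly
V, n_iter = McCall_VFI_jit(n=1_000)
print(V.shape, int(n_iter))
```

In single precision the distance between iterates may never fall below `v_tol=1e-6`, in which case the loop stops at `max_iter`. A looser tolerance or 64-bit mode may be more appropriate.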
