<a href="https://colab.research.google.com/github/AntoineChapel/vfi_project/blob/main/vfi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Value Function Iteration: an application of Google JAX
The goal of this notebook is to demonstrate how effective Google JAX is at solving high-dimensional economics problems. Every implementation below follows exactly the same algorithmic structure: value function iteration applied to the Neoclassical Growth Model.

As a reminder, the objective is to find an optimal policy function $g$ such that, if capital at time $t$ is $k_t$, the optimal capital stock for the next period is $k_{t+1} = g(k_t)$.
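Reading the model off `fill_va_np` and `fill_va_jax` below (log utility, output $Ak^\alpha$, full depreciation, so consumption is $c = Ak^\alpha - k'$), the Bellman equation that the iteration solves can be written as follows, where $\mathcal{K}$ denotes the capital grid and choices with $c \le 0.01$ are treated as infeasible (utility $-\infty$):

$$
V(k) = \max_{k' \in \mathcal{K}} \Big\{ \log\big(A k^{\alpha} - k'\big) + \beta V(k') \Big\},
\qquad
g(k) = \arg\max_{k' \in \mathcal{K}} \Big\{ \log\big(A k^{\alpha} - k'\big) + \beta V(k') \Big\}.
$$

Without the grid restriction, this log utility, Cobb-Douglas, full-depreciation case has the well-known closed-form policy $g(k) = \alpha\beta A k^{\alpha}$, so the steady state is $k^* = (\alpha\beta A)^{1/(1-\alpha)}$. With $A=10$, $\alpha=0.5$, $\beta=0.9$ as in the runs below, $k^* = 4.5^2 = 20.25$, which gives a useful sanity check on the grid solutions.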

The algorithm relies on the contraction mapping theorem to guarantee convergence. It is a fixed-point iteration that repeats two operations. Given $n$ possible values for $k$ ($n$ states):
- Add an $n \times n$ matrix of immediate utility (rows index today's capital $k$, columns index next period's capital $k'$) and the discounted value function $\beta V_k$, a $1 \times n$ vector broadcast to an $n \times n$ matrix. The result is called the value array (analogous to a $Q$ table in the $Q$-learning literature).
- Take the $\max$ of each row of the value array, which gives the updated value function $V_k'$.
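As a minimal sketch of one pass through these two operations, assuming the log-utility model used throughout this notebook (the function name `bellman_step` and the tiny 5-point grid are illustrative choices, not part of the original code):

```python
import jax.numpy as jnp


def bellman_step(kgrid, A, alpha, beta, Vk):
    # Immediate utility: rows index today's capital k, columns index next period's k'
    C = A * kgrid[:, None] ** alpha - kgrid[None, :]
    U = jnp.where(C > 0.01, jnp.log(C), -jnp.inf)
    # Operation 1: add beta*V(k'), broadcast from a (1, n) row to an (n, n) value array
    Q = U + beta * Vk[None, :]
    # Operation 2: the row-wise max gives V'(k); argmax gives the index of the optimal k'
    return Q.max(axis=1), Q.argmax(axis=1)


kgrid = jnp.linspace(1.0, 25.0, 5)
Vk = jnp.zeros(5)
V_new, policy_idx = bellman_step(kgrid, 10.0, 0.5, 0.9, Vk)
```

The returned indices are the discrete policy: `kgrid[policy_idx]` is $g(k)$ restricted to the grid.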

For low-dimensional problems ($n=50$), JAX yields only modest speed gains over NumPy. We also observe that, at this size, the problem solves faster on CPUs than on GPUs.
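One way to run such a CPU versus GPU comparison is to commit the inputs of a jitted function to a specific device, since JAX executes a computation where its committed inputs live. The sketch below is an assumption about how one might set this up, not the setup used for the timings in this notebook; `bellman_max` is a hypothetical helper covering only the row-wise max step.

```python
import jax
import jax.numpy as jnp

# Backend JAX uses by default on this machine, e.g. 'cpu' or 'gpu'
print(jax.default_backend())


@jax.jit
def bellman_max(U, V, beta):
    # Row-wise max over k' of U + beta*V (operation 2 of the algorithm)
    return jnp.max(U + beta * V[None, :], axis=1)


cpu = jax.devices("cpu")[0]
U = jnp.zeros((50, 50))
V = jnp.ones(50)
# Inputs committed to the CPU device make the jitted call execute on the CPU,
# even when a GPU is the default backend
V_next_cpu = bellman_max(jax.device_put(U, cpu), jax.device_put(V, cpu), 0.9)
```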

With $n=10000$, however, JAX is dramatically faster: the pure NumPy solver takes about 3 min 9 s, while both JAX variants finish in roughly 0.65 to 0.68 s, a speedup of roughly 280 to 290 times. Rewriting the entire code with jax.numpy arrays brings no clear advantage over jit-compiling only the computation-intensive component (filling the value array) and calling it from an otherwise NumPy loop.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

In [2]:
#PRNG key: both solvers below reuse `key`, so they start from the same random initial guess
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
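JAX random number generation is explicit and stateless: a given key always produces the same draws, and fresh randomness comes from splitting the key. A small sketch of that behaviour (the variable names are illustrative):

```python
import jax

key = jax.random.PRNGKey(0)
# The same key always yields the same draws, so every solver seeded with `key`
# starts from an identical initial guess
a = jax.random.normal(key, shape=(5,))
b = jax.random.normal(key, shape=(5,))
# Fresh randomness requires splitting the key and drawing with the new subkey
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(5,))
```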

In [3]:
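#numpy-based function that fills the value array with log(A*k^alpha - k') + beta*V(k'),
#using -inf wherever consumption A*k^alpha - k' is at or below 0.01 (infeasible choices)

def fill_va_np(kgrid, A, alpha, beta, Vk):
  # Consumption matrix: rows index today's capital k, columns index next period's capital k'
  C_mat = A*(kgrid.reshape(-1, 1)**alpha) - kgrid.reshape(1, -1)
  log_C_mat = np.where(C_mat > 0.01, np.log(C_mat), -np.inf)
  # Vk as a (1, n) row broadcasts across rows: column j receives beta*V(k'_j)
  value_array = log_C_mat + beta*Vk.reshape(1, -1)
  return value_array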
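Because `np.where` evaluates both branches, `np.log` is applied to every entry of `C_mat`, including negative consumptions, and NumPy emits a `RuntimeWarning` for those. A sketch of a warning-free alternative, assuming the same 0.01 floor (`log_utility_np` is a hypothetical helper, not part of the original notebook):

```python
import warnings

import numpy as np


def log_utility_np(C_mat, cmin=0.01):
    # Clamp before taking the log so np.log never sees a non-positive value;
    # entries at or below cmin are then overwritten with -inf (infeasible)
    safe_log = np.log(np.maximum(C_mat, cmin))
    return np.where(C_mat > cmin, safe_log, -np.inf)


kgrid = np.linspace(1, 25, 5)
C_mat = 10 * kgrid.reshape(-1, 1) ** 0.5 - kgrid.reshape(1, -1)
with warnings.catch_warnings():
    warnings.simplefilter("error")  # any RuntimeWarning would now raise an exception
    U = log_utility_np(C_mat)
```

The result matches the `np.where(C_mat > 0.01, np.log(C_mat), -np.inf)` expression entry for entry; only the spurious warnings disappear.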

In [4]:
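#jax-based function that fills the value array with log(A*k^alpha - k') + beta*V(k'),
#using -inf wherever consumption A*k^alpha - k' is at or below 0.01 (infeasible choices)

def fill_va_jax(kgrid, A, alpha, beta, Vk):
  # vmap builds the consumption matrix row by row: row i is A*k_i^alpha - kgrid
  C_mat = jax.vmap(lambda k: A*(k**alpha) - kgrid)(kgrid)
  log_C_mat = jnp.where(C_mat > 0.01, jnp.log(C_mat), -jnp.inf)
  # Vk as a (1, n) row broadcasts across rows: column j receives beta*V(k'_j)
  value_array = log_C_mat + beta*Vk.reshape(1, -1)
  return value_array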
fill_va_jax_compiled = jit(fill_va_jax) #jit-compile the function
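The `jax.vmap` call in `fill_va_jax` is equivalent to the explicit outer-style broadcasting used in `fill_va_np`. A quick sketch checking that on a small grid (the grid size and parameter values are illustrative):

```python
import jax
import jax.numpy as jnp

kgrid = jnp.linspace(1.0, 25.0, 6)
A, alpha = 10.0, 0.5

# vmap maps the per-k row function over kgrid and stacks the rows into an (n, n) matrix
C_vmap = jax.vmap(lambda k: A * k**alpha - kgrid)(kgrid)
# Explicit broadcasting: an (n, 1) column of outputs minus a (1, n) row of choices
C_bcast = A * kgrid[:, None] ** alpha - kgrid[None, :]
```

Both forms compile to essentially the same array program under `jit`; `vmap` mainly lets the per-state computation be written as if for a single $k$.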

In [5]:
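def vfi_np(kmin, kmax, A, alpha, beta, precision, maxiter, tol, va_func="va_np", verbose=2):
  kgrid = np.linspace(kmin, kmax, precision)
  # Random initial guess drawn with the shared JAX key, converted to a NumPy array
  # so that the "va_np" path runs entirely in NumPy
  Vk = np.asarray(jax.random.normal(key, shape=(1, precision)).flatten())
  norm = np.inf
  n_iter = 0
  while n_iter < maxiter and norm > tol:
    if va_func == "va_jax":
      # jit-compiled JAX kernel inside an otherwise NumPy loop
      value_array = fill_va_jax_compiled(kgrid, A, alpha, beta, Vk)
    else:
      value_array = fill_va_np(kgrid, A, alpha, beta, Vk)

    # Bellman update: the row-wise max over k' gives the new value function
    Vkprim = np.max(value_array, axis=1)
    # Sup-norm distance between successive iterates
    norm = np.max(np.abs(Vkprim - Vk))
    Vk = Vkprim

    n_iter += 1
    if verbose > 1:
      print("iteration: ", n_iter, " norm: ", norm)
  # Policy function: the k' that maximizes each row of the value array
  gk = kgrid[np.argmax(value_array, axis=1)]
  # Steady state: the grid point where the policy is closest to a fixed point g(k) = k
  kstar = kgrid[np.argmin(np.abs(gk - kgrid))]
  if verbose > 0:
    print(f"The steady-state value of capital is {kstar}")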
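The last two steps of `vfi_np` turn the value array into a policy and then a steady state. To see the steady-state rule in isolation, the sketch below applies it to the closed-form policy $g(k) = \alpha\beta A k^{\alpha}$ of this log utility, Cobb-Douglas, full-depreciation model, snapped onto the grid. The snapping step stands in for the `argmax` over the value array and is an illustrative assumption, as are the variable names.

```python
import numpy as np

A, alpha, beta = 10.0, 0.5, 0.9
kgrid = np.linspace(1, 25, 241)  # grid spacing of 0.1

# Closed-form policy snapped to the nearest grid point (stands in for
# kgrid[np.argmax(value_array, axis=1)] in vfi_np)
g_exact = alpha * beta * A * kgrid**alpha
gk = kgrid[np.argmin(np.abs(g_exact[:, None] - kgrid[None, :]), axis=1)]

# Steady-state rule from vfi_np: the grid point where g(k) is closest to k
kstar = kgrid[np.argmin(np.abs(gk - kgrid))]
print(kstar)  # close to the analytic steady state (alpha*beta*A)**(1/(1-alpha)) = 20.25
```

If several grid points satisfy $g(k) = k$ exactly, `np.argmin` returns the first one, so the reported steady state is the smallest of them.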

In [6]:
def vfi_jax(kmin, kmax, A, alpha, beta, prec, maxiter, tol, verbose=2):
  kgrid = jnp.linspace(kmin, kmax, prec)
  # Same random initial guess as vfi_np (same key, same grid size)
  Vk = jax.random.normal(key, shape=(1, prec)).flatten()
  norm = jnp.inf
  n_iter = 0

  # Python while loop driving the jit-compiled kernel; all arrays stay jax.numpy arrays
  while n_iter < maxiter and norm > tol:
    value_array = fill_va_jax_compiled(kgrid, A, alpha, beta, Vk)
    Vkprim = jnp.max(value_array, axis=1).block_until_ready()
    norm = jnp.max(jnp.abs(Vkprim - Vk)).block_until_ready()
    Vk = Vkprim
    n_iter += 1
    if verbose > 1:
      print(f"Iteration: {n_iter} Norm: {norm}")
  # Policy function and its approximate fixed point on the grid
  gk = kgrid[jnp.argmax(value_array, axis=1).block_until_ready()]
  kstar = kgrid[jnp.argmin(jnp.abs(gk - kgrid))]
  if verbose > 0:
    print(f"The steady-state value of capital is {kstar}")
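A natural further step, not benchmarked in this notebook, is to move the outer loop itself under `jit` with `jax.lax.while_loop`, so that no Python-level iteration or host synchronization happens between Bellman updates. The sketch below assumes the same model but a deterministic zero initial guess; `vfi_while` and its argument layout are our own choices. `prec` fixes array shapes, so it must be static; `maxiter` is made static here only for simplicity. With the notebook's parameters, the grid steady state should land near the closed-form value $(\alpha\beta A)^{1/(1-\alpha)} = 20.25$.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("prec", "maxiter"))
def vfi_while(kmin, kmax, A, alpha, beta, prec, maxiter, tol):
    kgrid = jnp.linspace(kmin, kmax, prec)
    # Immediate utility does not depend on V, so it is built once, outside the loop
    C = A * kgrid[:, None] ** alpha - kgrid[None, :]
    U = jnp.where(C > 0.01, jnp.log(C), -jnp.inf)

    def cond(state):
        _, norm, it = state
        return (it < maxiter) & (norm > tol)

    def body(state):
        V, _, it = state
        V_new = jnp.max(U + beta * V[None, :], axis=1)
        return V_new, jnp.max(jnp.abs(V_new - V)), it + 1

    V0 = jnp.zeros(prec)  # deterministic start, unlike the notebook's random guess
    init = (V0, jnp.asarray(jnp.inf, dtype=V0.dtype), jnp.asarray(0, dtype=jnp.int32))
    V, _, n_iter = jax.lax.while_loop(cond, body, init)

    # Policy on the grid and its approximate fixed point, as in vfi_jax
    gk = kgrid[jnp.argmax(U + beta * V[None, :], axis=1)]
    kstar = kgrid[jnp.argmin(jnp.abs(gk - kgrid))]
    return kstar, V, n_iter


kstar, V, n_iter = vfi_while(1.0, 25.0, 10.0, 0.5, 0.9, prec=241, maxiter=500, tol=1e-4)
```

The loop carry must keep identical shapes and dtypes across iterations, which is why the initial norm and counter are created with explicit dtypes.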

In [7]:
%%timeit
#small dimensionality: pure numpy runs in 196 ms
vfi_np(1, 25, 10, 0.5, 0.9, 50, 180, 1e-6, "va_np", 0)

196 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
#small dimensionality: numpy outer loop with jax inner kernel runs in 178 ms
vfi_np(1, 25, 10, 0.5, 0.9, 50, 180, 1e-6, "va_jax", 0)

178 ms ± 36.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
#small dimensionality: pure jax runs in 137 ms
vfi_jax(1, 25, 10, 0.5, 0.9, 50, 180, 1e-6, 0)

137 ms ± 6.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
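When timing JAX code by hand, two details matter: the first call to a jitted function includes tracing and XLA compilation, and JAX dispatches work asynchronously, so the clock should only stop once the result is ready. A sketch of that methodology, where `time_jitted` and `bellman_max` are hypothetical helpers and the array sizes are illustrative:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def bellman_max(U, V, beta):
    # Row-wise max over k' of U + beta*V
    return jnp.max(U + beta * V[None, :], axis=1)


def time_jitted(fn, *args, repeats=10):
    # Warm-up call: triggers tracing and compilation, excluded from the timing
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the asynchronously dispatched result before moving on
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


U = jnp.zeros((1000, 1000))
V = jnp.ones(1000)
avg_seconds = time_jitted(bellman_max, U, V, 0.9)
```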


In [10]:
%%time
#large dimensionality: pure numpy solution found in 3 min 9 s
vfi_np(1, 25, 10, 0.5, 0.9, 10000, 180, 1e-6, "va_np", 0)

CPU times: user 1min 45s, sys: 1min 24s, total: 3min 9s
Wall time: 3min 9s


In [11]:
%%timeit
#large dimensionality: numpy outer loop with jax inner kernel runs in 651 ms
vfi_np(1, 25, 10, 0.5, 0.9, 10000, 180, 1e-6, "va_jax", 0)

651 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%timeit
#large dimensionality: pure jax runs in 678 ms
vfi_jax(1, 25, 10, 0.5, 0.9, 10000, 180, 1e-6, 0)

678 ms ± 28.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
