<a href="https://colab.research.google.com/github/AntoineChapel/vfi_project/blob/main/vfi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

In [2]:
# Global JAX PRNG key: seed 0, split once. The VFI functions draw their random
# initial value-function guess from `key`; `subkey` is the other half of the split.
key, subkey = jax.random.split(jax.random.PRNGKey(0))

In [3]:
#numpy-based function that fills the value_array with log(max(0.01, Ak^alpha - kprim)) + beta*Vkprim

def fill_va_np(kgrid, A, alpha, beta, Vk):
  precision = kgrid.shape[0]
  value_array = np.empty((precision, precision))
  for iprim, kprim in enumerate(kgrid):
    for i, k in enumerate(kgrid):
      c = A*(k**alpha) - kprim
      value_array[i, iprim] = np.where(c > 0.01, np.log(c) + beta*Vk[iprim], -np.inf)
  return value_array

In [4]:
#jax-based function that fills the value array with log(max(0.01, Ak^alpha - kprim)) + beta*Vkprim

def fill_va_jax(kgrid, A, alpha, beta, Vk):
  prec = kgrid.shape[0]
  C_mat = jax.vmap(lambda k: A*(k**alpha) - kgrid)(kgrid)
  log_C_mat = jnp.where(C_mat > 0.0, jnp.log(C_mat), -jnp.inf)
  Vk = Vk.reshape(-1, 1)
  value_array = log_C_mat + beta*Vk.T
  return value_array
fill_va_jax_compiled = jit(fill_va_jax) #jit-compile the function

In [5]:
def vfi_np(kmin, kmax, A, alpha, beta, precision, maxiter, tol, va_func="va_np", verbose=2):
  kgrid = np.linspace(kmin, kmax, precision)
  gk = np.linspace(kmin, kmax, precision)
  Vk0 = jax.random.normal(key, shape=(1, precision)).flatten()
  norm = 1000
  n_iter = 0
  Vk = Vk0
  while n_iter < maxiter and norm > tol:
      if va_func == "va_jax":
        value_array = fill_va_jax_compiled(kgrid, A, alpha, beta, Vk)
      else:
        value_array = fill_va_np(kgrid, A, alpha, beta, Vk)

      Vkprim = np.max(value_array, axis=1)
      norm = np.max(np.abs(Vkprim - Vk))
      Vk = Vkprim.copy()

      n_iter += 1
      if verbose > 1:
        print("iteration: ", n_iter, " norm: ", norm)
  gk = kgrid[np.argmax(value_array, axis=1)]
  kstar = kgrid[np.argmin(np.abs(gk - kgrid))]
  if verbose > 0:
    print(f"The steady-state value of capital is {kstar}")


In [6]:
def vfi_jax(kmin, kmax, A, alpha, beta, prec, maxiter, tol, verbose=2):
  kgrid = jnp.linspace(kmin, kmax, prec)
  Vk0 = jax.random.normal(key, shape=(1, prec)).flatten()
  norm = 1000
  n_iter = 0

  Vk = Vk0
  while n_iter < maxiter and norm > tol:
    value_array = fill_va_jax_compiled(kgrid, A, alpha, beta, Vk)
    Vkprim = jnp.max(value_array, axis=1).block_until_ready()
    norm = jnp.max(jnp.abs(Vkprim - Vk)).block_until_ready()
    Vk = Vkprim
    n_iter += 1
    if verbose > 1:
      print(f"Iteration: {n_iter} Norm: {norm}")
  gk = kgrid[jnp.argmax(value_array, axis=1).block_until_ready()]
  kstar = kgrid[jnp.argmin(jnp.abs(gk - kgrid))]
  if verbose > 0:
    print(f"The Steady-state value of capital is:{kstar}")

In [9]:
#small dimensionality: pure numpy converges in 5.21s

%%timeit
vfi_np(1, 25, 10, 0.5, 0.9, 50, 180, 1e-6, "va_np", 0)

  value_array[i, iprim] = np.where(c > 0.01, np.log(c) + beta*Vk[iprim], -np.inf)


5.62 s ± 472 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
#small dimensionality: np outer with jax inner converges in 175ms
%%timeit
vfi_np(1, 25, 10, 0.5, 0.9, 50, 180, 1e-6, "va_jax", 0)

175 ms ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
#small dimensionality: pure jax converges in 172ms
%%timeit
vfi_jax(1, 25, 10, 0.5, 0.9, 50, 180, 1e-6, 0)

172 ms ± 10.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
#large dimensionality: np outer with jax inner converges in 636ms
%%timeit
vfi_np(1, 25, 10, 0.5, 0.9, 10000, 180, 1e-6, "va_jax", 0)

636 ms ± 3.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
#large dimensionality: pure jax converges in 679ms
%%timeit
vfi_jax(1, 25, 10, 0.5, 0.9, 10000, 180, 1e-6, 0)

679 ms ± 8.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
