# Finite Markov Chains -- JAX Versions

### Prepared for the CBC Quantitative Economics Workshop (September 2022)

#### John Stachurski

In this notebook we develop some functions for manipulating finite Markov chains with JAX.

In [None]:
#!pip install quantecon

We will use the following imports.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import quantecon as qe
import jax
import jax.numpy as jnp

In [None]:
def update_scalar(P_cs, x, u):
    """
    Draws the next state of a single chain by inverting the conditional CDF.

    * `P_cs` is a stochastic matrix converted to cumulative sums on the rows
    * `x` is the current state, used as a row index into `P_cs`
    * `u` is a draw from the uniform distribution on [0, 1)

    Returns the smallest index `j` such that `u < P_cs[x, j]`, capped at the
    last column index so that the result is always a valid state.
    """
    # side='right' picks the first j with P_cs[x, j] > u.  With the default
    # side='left', a draw of exactly u = 0 would select state 0 even when
    # that state has zero probability.
    j = jnp.searchsorted(P_cs[x, :], u, side='right')
    # Rounding in the cumulative sum can leave the final entry slightly below
    # one, in which case the search may run off the end of the row.
    return jnp.minimum(j, P_cs.shape[1] - 1)

# Batched version of `update_scalar`: the cumulative-sum matrix is shared by
# the whole batch, while the states and the uniform draws are mapped along
# their leading axis (the output is stacked along axis 0, the vmap default).
update_vectorized = jax.vmap(update_scalar, in_axes=(None, 0, 0))

@jax.jit
def update(P_cs, x, key):
    """
    Advances a population of independent P-Markov chains by one step.

    * `P_cs` is a stochastic matrix converted to cumulative sums on the rows
    * `x` is a flat integer-valued array holding the current state of each
        member of the population, with `x[i]` in {0, ..., n-1} for all `i`
    * `key` is an instance of `jax.random.PRNGKey`

    Returns the array of updated states, one per member of the population.
    """
    # One independent uniform draw per member of the population.
    draws = jax.random.uniform(key, shape=(len(x),))
    return update_vectorized(P_cs, x, draws)


def simulate_mc(P, num_steps, pop_size, 
                init_vals=None, 
                seed=1234):
    """
    Pushes forward in time a population of size `pop_size`, all of which
    update independently via a P-Markov chain on the
    integers {0, ..., n-1}, where `n = len(P)`.

    * `P` is a square stochastic matrix (anything `jnp.array` accepts)
    * `num_steps` is the number of updates applied to the population
    * `pop_size` is the number of chains; it is used only when `init_vals`
        is not supplied
    * `init_vals` is an optional array of integers of length `pop_size`,
        each element of which takes values in {0, ..., n-1}
    * `seed` is the integer used to build the initial `jax.random.PRNGKey`

    If no initial conditions are supplied then they are chosen as IID
    draws from a uniform distribution on {0, ..., n-1}.

    The function returns an array `x` where `x[i]` is the state of the 
    i-th element of the population after `num_steps` updates.
    """

    P = jnp.array(P)

    assert (len(P.shape) == 2), "P must be two-dimensional."
    n, k = P.shape
    assert (n == k), "P must be a square matrix."

    # Row-wise cumulative sums, as expected by `update`.
    P_cs = jnp.cumsum(P, axis=1)

    key = jax.random.PRNGKey(seed)

    if init_vals is None:
        # Use a dedicated subkey so the root key is never consumed twice.
        key, subkey = jax.random.split(key)
        # `maxval` is exclusive, so `n` is needed to cover {0, ..., n-1}.
        init_vals = jax.random.randint(subkey, (pop_size,), minval=0, maxval=n)

    # Accept lists / NumPy arrays as well as JAX arrays.
    x = jnp.asarray(init_vals)
    for _ in range(num_steps):
        key, subkey = jax.random.split(key)
        # Feed the *current* state and the fresh subkey into each step.
        x = update(P_cs, x, subkey)

    return x

### Test Case: Business Cycles

As a test case, we use the stochastic matrix 

$$
P_H =
\left(
  \begin{array}{ccc}
     0.971 & 0.029 & 0 \\
     0.145 & 0.778 & 0.077 \\
     0 & 0.508 & 0.492
  \end{array}
\right)
$$

This matrix was estimated from US unemployment data by Hamilton
[[Ham05](https://python.quantecon.org/zreferences.html#id164)].

In [None]:
# Hamilton's estimated stochastic matrix, entered as a nested Python list
# with one inner list per row.
P_H = [
    [0.971, 0.029, 0    ],
    [0.145, 0.778, 0.077],
    [0,     0.508, 0.492]
]

In [None]:
type(P_H)   # P_H was entered as a nested list, so this is `list`

In [None]:
P_H = np.array(P_H)   # Convert to a NumPy array

In [None]:
type(P_H)   # Inspect the type of P_H again for comparison

In [None]:
P_H[0, 0]   # Element in the first row and first column

In [None]:
P_H[0, :]   # All columns of the first row

Now that $P_H$ is a NumPy array, we can compute powers by matrix multiplication:

In [None]:
# Matrix product of P_H with itself: P2_H[i, j] is the probability of moving
# from state i to state j in two steps.
P2_H = P_H @ P_H   # Two step transition probabilities