**Today's topic**

The aim of today's workshop is to discuss how we use JAX to solve and estimate dynamic life-cycle models by CCP inversion. The JAX API allows for, e.g., hardware acceleration and just-in-time (JIT) compilation.
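
As a minimal sketch of what JIT compilation looks like in practice, the toy function `scaled_sum` below (a hypothetical name, not part of the workshop code) is wrapped once with `jax.jit` and then called. The call to `block_until_ready()` waits for the asynchronous computation to finish, which matters when timing.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x):
    """Toy function (hypothetical): sum of 2*x + 1 over all elements."""
    return jnp.sum(2.0 * x + 1.0)

# Wrap once; the function is traced and compiled on the first call for a given input shape/dtype
scaled_sum_jit = jax.jit(scaled_sum)

x = jnp.arange(4.0)
out = scaled_sum_jit(x).block_until_ready()  # wait for the asynchronous computation to finish

print(jax.devices())  # devices JAX can run on (CPU, GPU or TPU)
```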

**Useful functions**

* einsum()
* while_loop()
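
Before the larger cases, a minimal sketch of the two functions on made-up toy inputs may help. The shapes and the counting example are assumptions chosen for illustration only.

```python
import jax.numpy as jnp
from jax import lax, random

# Hypothetical toy shapes: a batch of 3 matrix pairs of size 2x4 and 4x5
key_a, key_b = random.split(random.PRNGKey(0))
A = random.normal(key_a, (3, 2, 4))
B = random.normal(key_b, (3, 4, 5))

# einsum: 'i' is the batch index; 'k' is summed over because it is absent from the output
C = jnp.einsum('ijk,ikl->ijl', A, B)

# while_loop: add up t = 4, 3, ..., 0 by counting down, carrying the tuple (total, t)
def cond_fun(carry):
    total, t = carry
    return t > 0  # continue while there are periods left

def body_fun(carry):
    total, t = carry
    t = t - 1
    return total + t, t  # the carry must keep the same structure and dtypes

total, t_final = lax.while_loop(cond_fun, body_fun, (0, 5))
```

The carried tuple plays the same role as the tuples `(A, B, C, t)` and `(v, q, EV, t)` used in the loops further down.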





In [57]:
from jax import block_until_ready #use the public API rather than the private jax._src module
import jax.numpy as jnp
from jax import random

def rnd(*args):
  return random.normal(random.PRNGKey(123), (args))

I, J, K, L = 10, 4, 5, 6

A, B = rnd(I, J, K), rnd(I, K, L) #create the two ndarrays A and B

# Case 1: Vectorize the following calculation
C = jnp.empty((I,J,L))
for i in range(I):
  C = C.at[i,:,:].set(A[i,:,:] @ B[i,:,:])
  # C[i,:,:] = A[i,:,:] @ B[i,:,:] #unlike NumPy, in-place assignment is not possible in JAX because arrays are immutable
print('shape of C:'+str(C.shape))

C_matmul = A @ B
print('execution time for matmul: ')
%timeit (A @ B).block_until_ready()

C_einsum = jnp.einsum('ijk, ikl -> ijl', A, B)
print('execution time for einsum: ')
%timeit jnp.einsum('ijk, ikl -> ijl', A, B).block_until_ready()

print('check whether each element of C and C_matmul are identical: '+str(jnp.all(jnp.isclose(C,C_matmul))))
print('check whether each element of C and C_einsum are identical: '+str(jnp.all(jnp.isclose(C,C_einsum))))

shape of C:(10, 4, 6)
execution time for matmul: 
23.1 µs ± 13.9 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
execution time for einsum: 
286 µs ± 46.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
check whether each element of C and C_matmul are identical: True
check whether each element of C and C_einsum are identical: True


In [58]:
I, J, K = 10, 4, 5

D, E = rnd(I,K), rnd(J,K,I)

# Case 2: Vectorize the following calculation
F = jnp.empty((I,J))
for i in range(I):
  F = F.at[i,:].set(D[i,:] @ E[:,:,i].transpose(1,0) )
print('shape of F: '+str(F.shape))

F_matmul = jnp.squeeze(jnp.matmul(jnp.reshape(D, (I,1,K)), E.transpose(2, 1, 0)))
print('execution time for matmul: ')
%timeit jnp.squeeze(jnp.matmul(jnp.reshape(D, (I,1,K)), E.transpose(2, 1, 0))).block_until_ready()

F_einsum = jnp.einsum('ik, lki -> il', D, E)
print('execution time for einsum: ')
%timeit jnp.einsum('ik, lki -> il', D, E).block_until_ready()

print('check whether each element of F and F_matmul are identical: '+str(jnp.all(jnp.isclose(F,F_matmul))))
print('check whether each element of F and F_einsum are identical: '+str(jnp.all(jnp.isclose(F,F_einsum))))


shape of F: (10, 4)
execution time for matmul: 
321 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
execution time for einsum: 
92.7 µs ± 1.28 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
check whether each element of F and F_matmul are identical: True
check whether each element of F and F_einsum are identical: True


In [59]:
import jax.numpy as jnp
from jax import lax
from jax import jit
import jax

def ArrayAddMult(A,B,C,t):
  return C[t+1,:] + A[t,:] @ B[t,:]

def SimpleLoop(A,B,C):
  T = jnp.shape(A)[0]
  for t in reversed(range(T)):
    C = C.at[t,:,:].set(ArrayAddMult(A,B,C,t))
  return C

T, I, J, K, L = 200, 10, 4, 5, 6

A, B, C = rnd(T,I,J,K), rnd(T,I,K,L), jnp.zeros((T+1,I,J,L))

print('execution time for python for loop: ')
%timeit SimpleLoop(A,B,C).block_until_ready()
print('execution time for JIT compiled python for loop: ')
%timeit jit(SimpleLoop)(A,B,C).block_until_ready()

def Cond(tup): # tup = (A, B, C, t)
  return (tup[-1] - 1 >= 0)

def LaxLoop(tup):
  (A, B, C, t) = tup
  t = t - 1
  C = C.at[t,:,:].set(ArrayAddMult(A,B,C,t))
  return (A, B, C, t)

tup = (A, B, C, T)
print('execution time for lax.while_loop(): ')
%timeit jax.block_until_ready(lax.while_loop(cond_fun=Cond, body_fun=LaxLoop, init_val=tup))


execution time for python for loop: 
1.79 s ± 241 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
execution time for JIT compiled python for loop: 
6.5 ms ± 937 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
execution time for lax.while_loop(): 
1.9 ms ± 395 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


**Life-cycle model**

Let's consider a simple human capital model in which the agents have a finite time horizon and, in each time period, must decide on their labor supply, $\ell_{it}$, in order to maximize their expected lifetime utility

\begin{align}
\max_{\ell_{it}\in\{0,1\}} v_{it}(\ell_{it},h_{it}) + \epsilon_{it}(\ell_{it}),
\end{align}
subject to the law of motion for human capital, $h_{it}$, which is governed by the transition probabilities, $p(h_{it+1}|\ell_{it},h_{it})$.

$v_{it}(\ell_{it},h_{it})$ is the choice-specific value function
\begin{align}
v_{it}(\ell_{it},h_{it}) &= u(\ell_{it},h_{it}) + \beta E_{t}V_{it+1}(\ell_{it},h_{it}),\\
E_{t}V_{it+1}(\ell_{it},h_{it}) &= \sum_{h_{it+1}}p(h_{it+1}|\ell_{it},h_{it})E_{t+1}V_{it+1}(h_{it+1}),
\end{align}
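
To make the sum over next-period human capital concrete, here is a minimal sketch with a hypothetical grid of three human capital levels and made-up next-period values. It assumes, as in the model code later in the notebook, that human capital stays put when not working and moves up one grid point when working.

```python
import jax.numpy as jnp

H = 3  # hypothetical number of human capital grid points
EV_next = jnp.array([1.0, 2.0, 4.0])  # made-up expected values of next period, one per grid point

# Transition probabilities p(h'|l,h) stored with shape (choice l, current h, next h')
P = jnp.zeros((2, H, H))
P = P.at[0].set(jnp.eye(H))                 # not working: h' = h
P = P.at[1, :-1, 1:].set(jnp.eye(H - 1))    # working: h' = h + 1 ...
P = P.at[1, -1, -1].set(1.0)                # ... capped at the top grid point

# Sum over next-period human capital k for every (choice l, state h)
EV_choice = jnp.einsum('lhk,k->lh', P, EV_next)
```

With these inputs the first row of `EV_choice` reproduces `EV_next`, while the second row is `EV_next` shifted up by one grid point.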

Here, $u(\ell_{it},h_{it})$ is the instantaneous utility, $p(h_{it+1}|\ell_{it},h_{it})$ is the transition probability for human capital, and $E_{t+1}V_{it+1}(h_{it+1})$ is the expected value function of the next period. The instantaneous utility is given by the sum of the utility of consumption, the (dis-)utility of working, and the structural error term

\begin{align}
u(0,h_{it}) &= u_c(b), \\
u(1,h_{it}) &= u_c(w(h_{it})) + \psi_0 + \psi_1 t + e_{t}(h_{it}).
\end{align}

As the saving motive is ignored, the agent simply consumes her entire income in each period, and the utility of consumption is assumed to be given by the CRRA utility function

\begin{align}
u_c(c) &= \tfrac{c^{1-\gamma}-1}{1-\gamma}.
\end{align}
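
As a quick numerical check of this functional form, the sketch below evaluates the CRRA utility on a few made-up consumption levels. The function name `crra` and the inputs are illustrative assumptions. For $\gamma = 2$ the formula reduces to $1 - 1/c$, and for $\gamma$ close to one it approaches $\log c$.

```python
import jax.numpy as jnp

def crra(c, gamma):
    """CRRA utility (c^(1-gamma) - 1) / (1 - gamma); this form is not defined at gamma = 1."""
    return (c ** (1.0 - gamma) - 1.0) / (1.0 - gamma)

c = jnp.array([0.5, 1.0, 2.0])   # made-up consumption levels

u_gamma2 = crra(c, 2.0)          # equals 1 - 1/c when gamma = 2
u_near_log = crra(c, 1.01)       # close to log(c) when gamma is close to 1
```
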
When unemployed, the agent receives an exogenous benefit, $b$. In contrast, when employed the agent earns a wage income, which is assumed to be a function of the agent's human capital, $w(h_{it})$. Finally, as we assume that the taste shock, $\epsilon_{it}(\ell_{it})$, is extreme value type-I distributed, the choice probabilities and the expected value function are given by the well-known logit choice probabilities and log-sum formula

\begin{align}
q_{it}(\ell_{it},h_{it}) &= \tfrac{\exp\{v_{it}(\ell_{it},h_{it})\}}{\exp\{v_{it}(0,h_{it})\} + \exp\{v_{it}(1,h_{it})\}}, \tag{1} \\
E_{t}V_{it}(h_{it}) &= \log[\exp\{v_{it}(0,h_{it})\} + \exp\{v_{it}(1,h_{it})\}]. \tag{2}
\end{align}
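
Equations $(1)$-$(2)$ can be sketched numerically as follows. The function name `logit` and the values in `v` are made up for illustration; the second column is deliberately large to show why subtracting the maximum before exponentiating is useful.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def logit(v, axis=0):
    """Logit choice probabilities and log-sum, centered at the maximum for numerical stability."""
    v_max = jnp.max(v, axis=axis, keepdims=True)
    expv = jnp.exp(v - v_max)
    denom = jnp.sum(expv, axis=axis, keepdims=True)
    return expv / denom, jnp.squeeze(jnp.log(denom) + v_max, axis=axis)

# Rows are the two choices, columns are two hypothetical states
v = jnp.array([[1.0, 800.0],
               [2.0, 799.0]])

q, EV = logit(v)

# Without centering, exp(800) overflows and the log-sum becomes inf
naive = jnp.log(jnp.sum(jnp.exp(v), axis=0))
```

The centered log-sum agrees with `jax.scipy.special.logsumexp(v, axis=0)`, which could be used instead.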

**CCP inversion**

Assume we know the nonlinear structural parameters ($\gamma$, $\beta$). By inverting the choice probabilities, we can isolate the linear utility terms:
\begin{equation}
 \log q_{it}(1,h_{it}) - \log q_{it}(0,h_{it}) - \big\{u_c(w(h_{it})) - u_c(b) \big\} - \beta \big\{ E_{t}V_{it+1}(1,h_{it}) - E_{t}V_{it+1}(0,h_{it}) \big\} = \psi_0 + \psi_1 t + e_t(h_{it}). \tag{3}
\end{equation}
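However, in order to evaluate the left-hand side of equation $(3)$, we need to calculate the choice-specific expected value function, $E_{t}V_{it+1}(\ell_{it},h_{it})$. We can use the fact that the log of the denominator of the choice probabilities equals the expected value function, see equations $(1)$-$(2)$. Taking logs of equation $(1)$ for the alternative $\ell_{it+1}=0$ and rearranging gives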
\begin{align}
E_{t+1}V_{it+1}(h_{it+1}) &= v_{it+1}(0,h_{it+1}) - \log q_{it+1}(0,h_{it+1}),\\
&= \big\{ u_c(b) + \beta \sum_{h_{it+2}}p(h_{it+2}|0,h_{it+1})E_{t+2}V_{it+2}(h_{it+2}) \big\} - \log q_{it+1}(0,h_{it+1}). \tag{4}
\end{align}
Hence, $E_{t}V_{it+1}(\ell_{it},h_{it})$ and $E_{t+1}V_{it+1}(h_{it+1})$ can be calculated by backward induction from the observed choice probabilities.
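
The recursion in equation $(4)$ can be checked on a toy model. The sketch below is an illustration under stated assumptions, not the workshop implementation: the dimensions, utilities and transitions are made up, and plain Python loops are used instead of `lax.while_loop` to keep it short. It first solves the toy model to obtain choice probabilities and then recovers the expected value function from those probabilities alone.

```python
import jax.numpy as jnp

T, H, beta = 4, 3, 0.9  # hypothetical toy dimensions and discount factor

# Made-up flow utilities u[l, t, h]; the utility of not working is a constant, like u_c(b)
u = jnp.zeros((2, T, H))
u = u.at[0].set(-1.0)
u = u.at[1].set(-0.5 + 0.3 * jnp.arange(H) - 0.2 * jnp.arange(T)[:, None])

# Transitions p[l, h, k]: human capital is unchanged when not working, rises one step when working
p = jnp.zeros((2, H, H))
p = p.at[0].set(jnp.eye(H))
p = p.at[1, :-1, 1:].set(jnp.eye(H - 1))
p = p.at[1, -1, -1].set(1.0)

# Step 1: solve by backward induction to obtain the "observed" choice probabilities
EV = jnp.zeros((T + 1, H))  # terminal value EV[T] = 0
q = jnp.zeros((2, T, H))
for t in reversed(range(T)):
    v_t = u[:, t, :] + beta * jnp.einsum('lhk,k->lh', p, EV[t + 1])
    EV_t = jnp.log(jnp.sum(jnp.exp(v_t), axis=0))   # log-sum, equation (2)
    q = q.at[:, t, :].set(jnp.exp(v_t - EV_t))      # logit probabilities, equation (1)
    EV = EV.at[t].set(EV_t)

# Step 2: recover the expected value function using only q and the utility of not working, equation (4)
EV_ccp = jnp.zeros((T + 1, H))
for t in reversed(range(T)):
    v0_t = u[0, t, :] + beta * (p[0] @ EV_ccp[t + 1])
    EV_ccp = EV_ccp.at[t].set(v0_t - jnp.log(q[0, t, :]))
```

Up to floating-point error, `EV_ccp` coincides with `EV`, even though the second loop never uses the utility of working.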

Based on equations $(3)$-$(4)$, we can then estimate the linear structural parameters ($\psi_0$, $\psi_1$) by a simple regression
\begin{align}
 y_{it}(h_{it})  = X_{it} \psi + e_t(h_{it}), \tag{5}
\end{align}
where $X_{it}$ contains an intercept and the age of the agent, $(1,t)$, and $\psi$ contains the linear parameters $(\psi_0,\psi_1)$. Finally, $y_{it}$ is the left-hand side of equation $(3)$.
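
A minimal sketch of regression $(5)$ on synthetic data is given below. The dimensions, the values in `psi_true` and the error scale are illustrative assumptions, and the dependent variable is generated directly from the linear model rather than from CCP inversion.

```python
import jax.numpy as jnp
from jax import random

T, H = 5, 4                          # hypothetical toy dimensions
psi_true = jnp.array([-0.5, -0.2])   # illustrative values for (psi_0, psi_1)

# One observation per (t, h) pair, stacked so that t varies fastest
age = jnp.tile(jnp.arange(T, dtype=jnp.float32), H)
X = jnp.column_stack([jnp.ones(T * H), age])

# Made-up structural errors with a small standard deviation
e = 0.01 * random.normal(random.PRNGKey(0), (T * H,))
y = X @ psi_true + e

# Least-squares estimate of psi via the normal equations
psi_hat = jnp.linalg.solve(X.T @ X, X.T @ y)
```

With such small errors, `psi_hat` lands close to `psi_true`; with larger errors the estimates move further away, as in the estimation output at the end of the notebook.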

In [60]:
import jax.numpy as jnp
from jax import lax
from jax import random

import jax
jax.config.update("jax_enable_x64", True) #enable double precision; importing config from jax.config is deprecated

import dataclasses as dc
import statsmodels.api as sm
import numpy as np

In [65]:
@dc.dataclass
class Par:
    # Number of alternatives
    L: int = 2 #number of alternatives for the labor supply decision (unemployment, employment)

    # Set dimensions of state variables
    T: int = 25 #number of time periods the agents live
    H: int = T #number of grid points for human capital
    TH: int = T*H #combinations of state variables

    # Set structural parameters
    beta: float = 0.90 #discount factor
    gamma: float = 2.00 #CRRA parameter
    psi0: float =-0.50 #(dis-)utility from working (baseline)
    psi1: float =-0.20 #(dis-)utility from working (linear in age)

    # Set parameters governing the income process
    benefit: float = 0.10 #benefit level (income when unemployed)
    wageBase: float = 0.10 #wage when human capital is zero
    wageGrowth: float = 2.00**(1/(H-1)) #wage growth as human capital increase

    # Set parameters governing the distribution of structural errors
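    std: float = 0.5 #standard deviation of the structural errors (annotated so that it is a proper dataclass field)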

par = Par() #set model parameters

def ExogenousVariables(par):
  """ Generate exogenous variables of the life-cycle model

      Inputs
        - par: data class containing model parameters

      Outputs
        - exog: data class containing the exogenous variables of the model
  """
  # Set income process
  income = jnp.empty((par.L,par.H))
  income = income.at[0,:].set( par.benefit )
  income = income.at[1,:].set( par.wageBase*(par.wageGrowth**jnp.arange(par.H)) )

  # Transition probabilities for human capital (HC)
  prob_HC = jnp.zeros((par.L,par.H,par.H))
  prob_HC = prob_HC.at[0,:,:].set( jnp.eye(par.H) )
  prob_HC = prob_HC.at[1,:-1,1:].set( jnp.eye(par.H-1) )
  prob_HC = prob_HC.at[1,-1,-1].set( 1.0 )

  # Draw structural errors for employment from the normal distribution
  error = jnp.zeros((par.L,par.T,par.H))
  error = error.at[1,:,:].set(par.std * random.normal(random.PRNGKey(345), shape=(par.T,par.H) ) )

  # Store exogenous variables in data class
  @dc.dataclass
  class Exog:
    income: jnp.ndarray #Income process
    prob_HC: jnp.ndarray #Transition probabilities for human capital
    error: jnp.ndarray #Structural errors

  exog = Exog(income=income, prob_HC=prob_HC, error=error)
  return exog

exog = ExogenousVariables(par)

if par.H<10:
  print('income when unemployed as a function of human capital, b:')
  print(round(exog.income[0,:],2))
  print('income when employed as a function of human capital, w(h_t):')
  print(round(exog.income[1,:],2))
  print('transition probabilities for human capital when unemployed, p(h_t+1|0,h_t):')
  print(exog.prob_HC[0,:])
  print('transition probabilities for human capital when employed, p(h_t+1|1,h_t):')
  print(exog.prob_HC[1,:])
  print('structural error as a function of age and human capital, e(h_t):')
  print(exog.error[1,:])

In [63]:
def Logit(V,axis=0):
    """ Returns the logit choice probabilities and the log-sum of the associated payoff matrix, V """
    maxV = jnp.max(V, axis=axis, keepdims=True) #used for centering to avoid overflow in exp()

    numerator = jnp.exp(V - maxV) #numerator of the logit choice probabilities
    denominator = jnp.sum(numerator, axis=axis, keepdims=True) #denominator of the logit choice probabilities
    return numerator / denominator, jnp.log(denominator) + maxV

def CRRA(x,gamma):
  """ Return the CRRA utility of consumption, x, given the parameter gamma """
  return (x**(1.0 - gamma) - 1.0)/(1.0 - gamma)

def Utility(income,error,age,gamma,psi0,psi1):
  """ Return the instantanous choice-specific utility function """
  laborSupply = jnp.arange(income.shape[0])[:,jnp.newaxis] #indicator function for unemployment and employment
  return CRRA(income,gamma) + laborSupply * (psi0 + psi1 * age) + error #instantanous choice-specific utilities

def Bellman(utility,prob_HC,EVnext,beta):
  """ Solve the Bellman equation

      Inputs
        - utility: instantaneous choice-specific utility
        - prob_HC: transition probabilities for human capital
        - EVnext: expected value function of next period
        - beta: discount factor

      Outputs
        - v: choice-specific value function
        - q: choice probabilities
        - EV: expected value function of the period
  """
  # explanation of the subscripts for jnp.einsum()
  #   l: labor supply alternatives
  #   h: current human capital level
  #   k: next period human capital level
  v = utility + beta * jnp.einsum('lhk, k -> lh', prob_HC, EVnext) #calculate choice-specific value functions

  q, EV = Logit(v) #calculate choice probabilities and the expected value function
  return v, q, EV

In [37]:
def Condition(inputTuple):
  """ Stopping criterium for the while loop """
  t = inputTuple[-1] #unpack t from the tuple (t has to be the last element of the tuple)
  return (t - 1 >= 0) #continue if t - 1 > 0

def BackwardRecursion(endogTuple,exog,par):
  """ Solve and store the solution of the Bellman equation for the time period t

      Inputs
        - endogTuple: tuple containing the endogenous variables of the model, (v, q, EV, t)
        - exog: data class containing the exogenous variables of the model, (prob_HC, income)
        - par: data class containing the structural parameters of the model

      Outputs
        - tuple containing the updated endogenous variables of the model
  """
  (v, q, EV, t) = endogTuple #unpack tuple

  t = t - 1 #recurse backward

  u_t = Utility(exog.income,exog.error[:,t,:],t,par.gamma,par.psi0,par.psi1) #calculate instantaneous utility
  v_t, q_t, EV_t = Bellman(u_t,exog.prob_HC,EV[t+1,:],par.beta) #solve bellman

  # store solution for period t
  v = v.at[:,t,:].set( v_t )
  q = q.at[:,t,:].set( q_t )
  EV = EV.at[t,:].set( jnp.squeeze(EV_t) )
  return (v, q, EV, t) #return tuple with updated values

def SolveLifeCycleModel(exog,par):
  """ Solve the life cycle model by backward induction

      Inputs
        - exog: data class containing the exogenous variables of the model
        - par: data class containing the structural parameters of the model

      Outputs
        - endog: data class containing the endogenous variables of the model
  """
  v = jnp.empty((par.L,par.T,par.H)) #initialize choice-specific value function
  q = jnp.empty((par.L,par.T,par.H)) #initialize choice probabilities
  EV = jnp.zeros((par.T+1,par.H)) #initialize expected value function; zeros make the terminal value EV[T,:] = 0 explicit

  t = par.T #initialize time

  endogTuple = (v, q, EV, t) #tuple with endogenous variables

  # Solve model by backward induction
  Fun = lambda x: BackwardRecursion(x,exog,par)
  (v, q, EV, t) = lax.while_loop(body_fun=Fun, cond_fun=Condition, init_val=endogTuple)

  # for t in reversed(range(par.T)):
  #   endogTuple = Fun(endogTuple)
  # (v, q, EV, t) = endogTuple

  # Store solution in a data class
  @dc.dataclass
  class Endog:
    v: jnp.ndarray #choice-specific value function
    q: jnp.ndarray #choice probabilities
    EV: jnp.ndarray #expected value function

  endog = Endog(v=v, q=q, EV=EV) #store solution in the data class
  return endog

In [42]:
def MyOLS(y,X):
    """ Estimate linear parameters by ordinary least squares (OLS) """
    return jnp.linalg.solve(jnp.matmul(X.T, X), jnp.matmul(X.T, y))

def BackwardRecursionCCP(evTuple,data,par):
  """ Calculate expected value function by CCP inversion

      Inputs
        - evTuple: tuple containing (EV, t)
        - data: data class containing observed data
        - par: data class containing model parameters

      Outputs
        - evTuple: tuple containing updated values of (EV, t)
  """
  (EV, t) = evTuple #unpack tuple

  t = t - 1 #recurse backward

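  # Calculate expected value functions by CCP inversion, cf. equation (4)
  # (prob_HC[0,:,:] is the identity matrix in this model, so the matrix product leaves EV[t+1,:] unchanged)
  EV_t = data.uC[0,:] - jnp.log(data.q[0,t,:]) + par.beta * (data.prob_HC[0,:,:] @ EV[t+1,:])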

  # Store solution
  EV = EV.at[t,:].set( EV_t )
  return (EV, t) #return tuple with updated values

def SetupRegression(data,par):
  """ CCP inversion

      Inputs
        - data: data class containing observed data
        - par: data class containing model parameters

      Outputs
        - y: vector containing the dependent variable
        - X: matrix containing the independent variables
  """
  evTuple = (jnp.zeros((par.T+1,par.H)), par.T) #initialize tuple, evTuple = (EV, t)

  data.uC = CRRA(data.income, par.gamma) #utility of consumption

  Fun = lambda x: BackwardRecursionCCP(x,data,par)
  (EV, t) = lax.while_loop(body_fun=Fun, cond_fun=Condition, init_val=evTuple) #calculate expected value function by CCP inversion

  # explanation of the subscripts for jnp.einsum()
  #   l: labor supply alternatives
  #   h: current human capital level
  #   k: next period human capital level
  #   t: current time period (current age)
  pEV = jnp.einsum('lhk, tk -> lth',data.prob_HC, EV) #calculate choice-specific expected value functions

  # Set up dependent variable and independent variables (y, X)
  y = jnp.reshape(jnp.log(data.q[1,:]) - jnp.log(data.q[0,:])
                  - (data.uC[1,:] - data.uC[0,:])
                  - par.beta * (pEV[1,1:,:] - pEV[0,1:,:]), (par.TH,1), order='F' )

  X = jnp.c_[jnp.ones((par.TH,1)), jnp.tile(jnp.arange(par.T), (par.H,) ) ]
  return y, X

def Estimation(data,par):
  """ Estimate linear parameters by CCP inversion

      Inputs
        - data: data class containing observed data
        - par: data class containing model parameters

      Outputs
        - pvec: vector of estimated linear parameters
  """
  y, X = SetupRegression(data,par) #set up variables for regression based on observed data
  #results = MyOLS(y, X)

  results = sm.OLS(np.array(y), np.array(X)).fit() #estimate linear model by OLS

  return results, y, X

In [66]:
endog = SolveLifeCycleModel(exog,par) #solve the model to obtain the observed choice probabilities

#store the observed data in data class
@dc.dataclass
class Data:
  income: jnp.ndarray #Observed income process
  prob_HC: jnp.ndarray #Observed transition probabilities for human capital
  q: jnp.ndarray #Observed choice probabilities

data = Data(income=exog.income, prob_HC=exog.prob_HC, q=endog.q) #observed data

results = Estimation(data,par)[0] #estimates linear parameters (psi0,psi1)

# The estimates differ from the true parameter values due to the structural errors
print(results.summary())
#Note that we have assumed that the nonlinear parameters (beta,gamma) are known

                            OLS Regression Results                            
Dep. Variable:                      y   R-squared:                       0.898
Model:                            OLS   Adj. R-squared:                  0.897
Method:                 Least Squares   F-statistic:                     5463.
Date:                Fri, 15 Sep 2023   Prob (F-statistic):          1.54e-310
Time:                        02:52:00   Log-Likelihood:                -452.51
No. Observations:                 625   AIC:                             909.0
Df Residuals:                     623   BIC:                             917.9
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         -0.4387      0.039    -11.300      0.0