<a href="https://colab.research.google.com/github/QuantEcon/qe_lunch_workshops/blob/w2/week_02/LunchWorkshop21042023.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Today's topic**

The aim of today's workshop is to introduce a two-sided matching model with transferable utility and to discuss numerical methods for solving systems of equations. The discussion focuses on how these methods can be implemented in JAX, a Python package that allows for just-in-time (jit) compilation and automatic differentiation.
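
As a minimal sketch of the two JAX features mentioned above, the snippet below compiles a small function with `jit` and obtains its Jacobian by automatic differentiation. The function `h` is a hypothetical example chosen only for illustration; it is not part of the matching model.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision

def h(x):
    # Hypothetical smooth map from R^2 to R^2, used only for illustration
    return jnp.array([x[0] ** 2 + x[1], jnp.sin(x[1])])

h_jit = jax.jit(h)             # compiled on the first call, reused afterwards
dh = jax.jit(jax.jacobian(h))  # Jacobian obtained by automatic differentiation

x = jnp.array([1.0, 0.0])
value = h_jit(x)
jac = dh(x)
print(value)
print(jac)
```

The same pattern, `jit(jacobian(...))`, is what the solvers later in the notebook rely on to get derivatives without coding them by hand.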

**Matching model**

Consider a matching market that consists of $X$ worker types and $Y$ firm types. There is assumed to be a continuum of each type, and the marginal distributions of worker and firm types are denoted by $n_x$ and $n_y$, respectively.

Each worker of type $x$ faces the discrete choice of working for one of the $Y$ types of firms or remaining unemployed

\begin{align}
\max_{y} \tilde{u}_{xy} + \epsilon_{xy},
\end{align}

where $\tilde{u}_{xy}$ is the deterministic utility term

\begin{align}
\tilde{u}_{xy} &= u_{xy} + t_{xy}, \quad \text{for} \quad y=1,...,Y,\\
\tilde{u}_{x0} &= 0.
\end{align}

Similarly, each firm of type $y$ faces the discrete choice of hiring one of the $X$ types of workers or not hiring anyone

\begin{align}
\max_{x} \tilde{v}_{xy} + \eta_{xy},
\end{align}
where $\tilde{v}_{xy}$ is the deterministic productivity term

\begin{align}
\tilde{v}_{xy} &= v_{xy} - t_{xy}, \quad \text{for} \quad x=1,...,X,\\
\tilde{v}_{0y} &= 0.
\end{align}

If the taste shocks $(\epsilon_{xy},\eta_{xy})$ are assumed to be iid type-I extreme value distributed, the choice probabilities of the workers and firms, $(p_{xy},q_{xy})$, are given by the logit formulas

\begin{align}
p_{xy} &= \frac{\exp(u_{xy} + t_{xy})}{1+\sum_{y'=1}^Y \exp(u_{xy'} + t_{xy'}) },\quad \forall (x,y), \\
q_{xy} &= \frac{\exp(v_{xy} - t_{xy})}{1+\sum_{x'=1}^X \exp(v_{x'y} - t_{x'y}) },\quad \forall (x,y).
\end{align}
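
The worker-side formula can be checked on a toy example. The sketch below uses hypothetical payoffs $u_{xy}+t_{xy}$ for two worker types and three firm types (the numbers are made up for illustration) and confirms that the match probabilities and the probability of staying unmatched sum to one for each worker type.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical systematic payoffs u_xy + t_xy for X=2 worker types and Y=3 firm types
payoff = jnp.array([[0.0, 0.0, 0.0],
                    [1.0, 0.0, -1.0]])

expo = jnp.exp(payoff)
# The outside option has payoff 0, so it contributes exp(0) = 1 to the denominator
p = expo / (1.0 + jnp.sum(expo, axis=1, keepdims=True))
p0 = 1.0 - jnp.sum(p, axis=1)  # probability of remaining unmatched

print(p)
print(p0)
```

For the first worker type all four alternatives have the same payoff, so each is chosen with probability one quarter.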

The wages, $t_{xy}$, are determined by a set of market clearing conditions that equate supply and demand

\begin{equation}
p_{xy} \cdot n_x = q_{xy} \cdot n_y,\quad \forall (x,y).
\end{equation}

Rearranging the market clearing conditions, the wages can be expressed as the solution to a set of fixed point equations

\begin{equation}
t_{xy} = t_{xy} + \tfrac{1}{2} \cdot \log \bigg\{ \tfrac{q_{xy} \cdot n_y}{p_{xy} \cdot n_x} \bigg\},\quad \forall (x,y),
\end{equation}
which can be shown to be a contraction mapping.
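
One way to see where the factor $\tfrac{1}{2}$ comes from is the following sketch, which is offered as an illustration rather than as the derivation used in the references. Write the logit denominators as $A_x = 1+\sum_{y'} \exp(u_{xy'} + t_{xy'})$ and $B_y = 1+\sum_{x'} \exp(v_{x'y} - t_{x'y})$ and hold them fixed at the current wages. The wage $t'_{xy}$ that clears the $(x,y)$ market then satisfies

\begin{align}
\frac{\exp(u_{xy} + t'_{xy})}{A_x} \cdot n_x &= \frac{\exp(v_{xy} - t'_{xy})}{B_y} \cdot n_y \\
\Leftrightarrow \quad 2 t'_{xy} &= \log \bigg\{ \frac{\exp(v_{xy}) \cdot n_y / B_y}{\exp(u_{xy}) \cdot n_x / A_x} \bigg\} = 2 t_{xy} + \log \bigg\{ \frac{q_{xy} \cdot n_y}{p_{xy} \cdot n_x} \bigg\}.
\end{align}

Dividing by two gives the update $t'_{xy} = t_{xy} + \tfrac{1}{2} \log \{ q_{xy} n_y / (p_{xy} n_x) \}$. At a fixed point, $t'_{xy} = t_{xy}$, the logarithm is zero, which is exactly the market clearing condition.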

**Solving the matching model numerically**

The previous section suggests at least two methods for finding the equilibrium wages of our matching market:
<ol>
  <li>Solve the market clearing conditions.</li>
  <li>Solve the fixed point equation.</li>
</ol>

**References**

Eugene Choo, and Aloysius Siow. “Who Marries Whom and Why.” Journal of Political Economy 114, no. 1 (2006): 175–201. https://doi.org/10.1086/498585.

Arnaud Dupuy, and Alfred Galichon. "A note on the estimation of job amenities and labor productivity." Quantitative Economics 13 (2022): 153-177. https://doi.org/10.3982/QE928.



In [None]:
import jax
import jax.numpy as jnp
from jax import random, jit, jacobian

jax.config.update("jax_enable_x64", True) # use double precision (64-bit floats)

In [None]:
!pip install jax_dataclasses
import jax_dataclasses as jdc

In [None]:
@jdc.pytree_dataclass
class Exo:
  # Set dimensions of the matching market
  X = 3 #number of worker types
  Y = 2 #number of firm types
  
  vectorLength = (X*Y)
  matrixDimensions = (X, Y)

  dimensionsToSumWorkers = 1
  dimensionsToSumFirms = 0

  # Simulate match-specific payoffs
  u = -random.uniform(random.PRNGKey(111), [X,Y]) # match-specific utilities
  v = random.uniform(random.PRNGKey(222), [X,Y]) # match-specific productivities

  # Simulate marginal distribution of types
  n_x = random.uniform(random.PRNGKey(333), [X,1]) # marginal distribution of worker types
  n_y = 1.0 + random.uniform(random.PRNGKey(444), [1,Y]) # marginal distribution of firm types

exo = Exo()

print('u_xy')
print(exo.u)
print('v_xy')
print(exo.v)
print('n_x')
print(exo.n_x)
print('n_y')
print(exo.n_y)

# initial guess for wages
t0_matrix = jnp.zeros((exo.X,exo.Y)) # initial guess for wages (matrix form)
t0_vector = jnp.reshape(t0_matrix, exo.vectorLength, order='F') # initial guess for wages (vector form)

In [None]:
def Logit(Vmatch,dimensionsToSum):
    """Returns the logit choice probabilities of the associated tensor of match-specific payoffs, Vmatch."""

    numerator = jnp.exp(Vmatch)

    denominator = 1.0 + jnp.sum(numerator, axis=dimensionsToSum, keepdims=True) # Note that exp(Vunmatch) = 1.0 as Vunmatch = 0.0
    return numerator/denominator

def Supply(t0,exo):
    """Calculates workers' supply for every match, S, given a set of match-specific wages, t0."""

    p_xy = Logit(exo.u + t0, exo.dimensionsToSumWorkers) #workers' choice probabilities

    S = p_xy * exo.n_x # workers' supply
    return S, p_xy

def Demand(t0,exo):
    """Calculates firms' demand for every match, D, given a set of match-specific wages, t0."""
        
    q_xy = Logit(exo.v - t0, exo.dimensionsToSumFirms) #firms' choice probabilities

    D = q_xy * exo.n_y # firms' demand

    return D, q_xy

def Vector2Matrix(var,matrixDimensions):
    return jnp.reshape(var, matrixDimensions, order='F')

def Matrix2Vector(var,vectorLength):
    return jnp.reshape(var, vectorLength, order='F')

def UpdateTransfer(t0_vector,exo):
    t0_matrix = Vector2Matrix(t0_vector, exo.matrixDimensions) # Initial guess for wages (matrix form)
    
    S = Supply(t0_matrix,exo)[0] # Calculate workers' labor supply
    D = Demand(t0_matrix,exo)[0] # Calculate firms' labor demand

    Z = Matrix2Vector(D - S, exo.vectorLength) # Calculate excess demand, market clearing condition

    t1_matrix = t0_matrix + (1/2) * jnp.log(D/S) # Evaluate the fixed point equation

    t1_vector = Matrix2Vector(t1_matrix, exo.vectorLength)
    return t1_vector, Z

def ChoiceProbabilities(t0,exo):
    t0_tensor = Vector2Matrix(t0,exo.matrixDimensions)

    prob_M_x = Supply(t0_tensor,exo)[1] #workers' choice probabilities for any match
    prob_M_y = Demand(t0_tensor,exo)[1] #firms' choice probabilities for any match

    prob_S_x = 1 - jnp.sum(prob_M_x, axis=exo.dimensionsToSumWorkers) #workers' choice probabilities of being unmatched
    prob_S_y = 1 - jnp.sum(prob_M_y, axis=exo.dimensionsToSumFirms) #firms' choice probabilities of being unmatched
    
    return prob_M_x, prob_M_y, prob_S_x, prob_S_y



In [None]:
t1_vector, Z0 = UpdateTransfer(t0_vector, exo)
t1_matrix = Vector2Matrix(t1_vector, exo.matrixDimensions)

print('Initial wages, t0:')
print(t0_matrix)
print('Excess demand given the initial wages, Z(t0) = D(t0) - S(t0):')
print(Vector2Matrix(Z0, exo.matrixDimensions))
print('Updated wages, t1 = t0 + (1/2)*log{ D(t0) / S(t0) }:')
print(t1_matrix)

In [None]:
#class containing simple implementations of numerical optimizers

class solver():
    def __init__(self,**kwargs):
        self.setup(**kwargs)

    def setup(self,**kwargs): # **kwargs lets the caller optionally override any of the defaults below
        self.output = 'iter'

        self.step_tol = 1.0e-10 # step tolerance
        self.root_tol = 1.0e-10 # root tolerance (only used in the solver NewtonRaphson)
        self.iter_max = 10000 # Maximum number of iterations

        # Poly
        self.sa_max = 20 # Maximum number of successive approximation steps
        self.nk_max = 5 # Maximum number of newton-kantorovich steps
        self.max_fxpiter = 5 # Maximum number of times to switch between Newton-Kantorovich iterations and contraction iterations.
        
        # If kwargs exist
        for key,val in kwargs.items():
            setattr(self,key,val)

    def SuccessiveApproximation(self, fun, x0):
        # Solve for fixed point using successive approximations
        class out: pass
        out.converged = 'None'
        out.iter = 0

        if self.output != 'none':
            print('Successive approximation (SA)')

        fun_jit = jit(fun) # compile once; wrapping an already jitted function is harmless

        for i in range(self.iter_max):
            x1 = fun_jit(x0)
            step_norm = jnp.linalg.norm(x1-x0)

            # Stopping criteria 1
            if step_norm<self.step_tol:
                out.converged = 'converged'
                break

            # Stopping criteria 2
            if jnp.any(jnp.isnan(x1)):
                out.converged = 'NaN'
                break

            if self.output == 'iter':
              print('SA: iterations='+str(i)+', norm='+str(step_norm))

            x0 = x1.copy()

        out.iter = i
        out.norm = step_norm

        if self.output != 'none':
            print('SA: iterations='+str(out.iter)+', norm='+str(step_norm)+', solver: '+out.converged)
        
        return x1, out

    def NewtonRaphson(self, f, df, x0):
        class out: pass
        out.converged = 'None'
        out.iter = 0

        if self.output != 'none':
            print('Newton-Raphson (NR)')

        x1 = x0 # returned unchanged if x0 already satisfies the root tolerance

        for i in range(self.iter_max):
            fx0 = f(x0)

            root_norm = jnp.linalg.norm(fx0)

            # Stopping criteria 1
            if root_norm<self.root_tol:
                out.converged = 'converged'
                break

            dx = jnp.linalg.solve(df(x0), fx0)
            x1 = x0 - dx

            step_norm = jnp.linalg.norm(dx)

            # Stopping criteria 2
            if step_norm<self.step_tol:
                out.converged = 'no change in x'
                break

            if jnp.any(jnp.isnan(x1)):
                out.converged = 'NaN'
                break

            if self.output == 'iter':
                print('NR: iterations='+str(i)+', norm='+str(root_norm))

            x0 = x1.copy()

        out.iter = i
        out.norm = root_norm

        if self.output != 'none':
            print('NR: iterations='+str(out.iter)+', norm='+str(root_norm)+', solver: '+out.converged)
        
        return x1, out
    
    def NewtonKantorovich(self, f, df, x0):
        class out: pass
        out.converged = 'None'
        out.iter = 0

        if self.output != 'none':
            print('Newton-Kantorovich (NK)')

        I = jnp.eye(jnp.size(x0))

        for i in range(self.iter_max):
            dx = jnp.linalg.solve((I - df(x0)), (x0 - f(x0)))
            x_nk = x0 - dx
            x1 = f(x_nk)
            
            step_norm = jnp.linalg.norm(x0 - x1)
            
            # Stopping criteria
            if step_norm<self.step_tol:
                out.converged = 'converged'
                break

            if jnp.any(jnp.isnan(x1)):
                out.converged = 'NaN'
                break

            if self.output == 'iter':
                print('NK: iterations='+str(i)+', norm='+str(step_norm))

            x0 = x1.copy()

        out.iter = i
        out.norm = step_norm

        if self.output != 'none':
            print('NK: iterations='+str(out.iter)+', norm='+str(step_norm)+', solver: '+out.converged)
        
        return x1, out
    
    def Poly(self, f, df, x0):
        # Alternate between blocks of SA and NK iterations

        if self.output != 'none':
            print('Combine SA and NK')

        iter_max = self.iter_max # store the user's setting, which is overwritten below

        for s in range(self.max_fxpiter):
            self.iter_max = self.sa_max
            x_sa, out_sa = self.SuccessiveApproximation(f, x0)

            if out_sa.norm<self.step_tol:
                x1, out = x_sa, out_sa
                break

            if jnp.any(jnp.isnan(out_sa.norm)):
                x1, out = x_sa, out_sa
                out.converged = 'NaN'
                break

            self.iter_max = self.nk_max
            x_nk, out_nk = self.NewtonKantorovich(f, df, x_sa)

            if out_nk.norm<self.step_tol:
                x1, out = x_nk, out_nk
                break

            if jnp.any(jnp.isnan(out_nk.norm)):
                x1, out = x_nk, out_nk
                out.converged = 'NaN'
                break

            x0 = x_nk.copy()
        else:
            # No break: the switching budget is exhausted, return the last NK iterate
            x1, out = x_nk, out_nk

        self.iter_max = iter_max # restore the user's setting
        out.s = s

        return x1, out

**Newton-Raphson**

Newton-Raphson is a gradient-based numerical method for root finding, $g(x)=0$, relying on a first-order Taylor approximation

\begin{equation}
g(x_{n+1}) \approx g(x_n) + \nabla g(x_n)(x_{n+1}-x_n)=0
\end{equation}

\begin{equation}
\Leftrightarrow x_{n+1} = x_n - \nabla g(x_n)^{-1} g(x_n)
\end{equation}

        for i in range(self.iter_max):
            gx0 = g(x0)

            root_norm = jnp.linalg.norm(gx0)

            # Stopping criteria 1
            if root_norm<self.root_tol:
                out.converged = 'converged'
                break

            dx = jnp.linalg.solve(dg(x0), gx0)
            x1 = x0 - dx

            step_norm = jnp.linalg.norm(dx)

            # Stopping criteria 2
            if step_norm<self.step_tol:
                out.converged = 'no change in x'
                break

            x0 = x1.copy()
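
To see the iteration in isolation, here is a minimal self-contained sketch on a hypothetical two-equation system (a unit circle intersected with the line $x_0 = x_1$), which is unrelated to the matching model and chosen only because its root, $(\sqrt{1/2},\sqrt{1/2})$, is known. The Jacobian is obtained with `jax.jacobian`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def g(x):
    # Hypothetical test system: unit circle intersected with the line x0 = x1
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 1.0, x[0] - x[1]])

dg = jax.jacobian(g)  # Jacobian by automatic differentiation

x = jnp.array([1.0, 0.5])  # starting value
for i in range(50):
    gx = g(x)
    if jnp.linalg.norm(gx) < 1e-12:  # root tolerance
        break
    x = x - jnp.linalg.solve(dg(x), gx)  # Newton step

print(i, x)
```

Solving the linear system with `jnp.linalg.solve` avoids forming the inverse Jacobian explicitly, which is also what the `solver` class does.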

In [None]:
s = solver()

g = lambda t: UpdateTransfer(t,exo)[1] # used to find root of excess demand, g(t) = 0

g_jit = jit(g)
dg_jit = jit(jacobian(g))

t_nr, out_nr = s.NewtonRaphson(g_jit,dg_jit,t0_vector)

print('Found solution, t, Newton-Raphson:')
print(Vector2Matrix(t_nr,exo.matrixDimensions))
print('Z(t)')
print(Vector2Matrix(UpdateTransfer(t_nr,exo)[1],exo.matrixDimensions))

**Successive approximation**

Successive approximation is a gradient free numerical method for solving fixed point equations, $x=f(x)$.

For a given $x_0$, iteratively update $x_n$
\begin{equation}
x_{n+1} = f(x_n)
\end{equation}

and stop when $\|x_{n+1} - x_n\|<\varepsilon$.

        for i in range(self.iter_max):
            x1 = f(x0)
            step_norm = jnp.linalg.norm(x1-x0)

            # Stopping criteria
            if step_norm<self.step_tol:
                out.converged = 'converged'
                break

            x0 = x1.copy()
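
A minimal self-contained sketch of the same loop, applied to the hypothetical fixed point problem $x = \cos(x)$ (chosen only because its solution, roughly $0.739085$, is well known):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.jit
def f(x):
    # Hypothetical map used for illustration; a contraction near its fixed point
    return jnp.cos(x)

x0 = jnp.array([1.0, 0.5])  # two starting values, iterated element-wise
for i in range(1000):
    x1 = f(x0)
    if jnp.linalg.norm(x1 - x0) < 1e-10:  # step tolerance
        break
    x0 = x1

print(i, x1)
```

Convergence is linear, so a tight tolerance takes dozens of iterations even on this small problem.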

In [None]:
s = solver()

f = lambda t: UpdateTransfer(t,exo)[0] # t1 = f(t0), used to find fixed point for transfers, t = f(t)
f_jit = jit(f)

t_sa, out_sa = s.SuccessiveApproximation(f_jit,t0_vector)
print('Found solution, t, Successive approximation:')
print(Vector2Matrix(t_sa,exo.matrixDimensions))
print('Z(t)')
print(Vector2Matrix(UpdateTransfer(t_sa,exo)[1],exo.matrixDimensions))

**Newton-Kantorovich**

Newton-Kantorovich is a gradient-based numerical method for solving fixed point equations, $x=f(x)$, relying on a first-order Taylor approximation

\begin{equation}
f(x_{n+1}) \approx f(x_n) + \nabla f(x_n)(x_{n+1}-x_n) = x_{n+1}
\end{equation}

\begin{equation}
\Leftrightarrow x_{n+1} = x_n - (I - \nabla f(x_n))^{-1} (x_n - f(x_n))
\end{equation}

        I = jnp.eye(jnp.size(x0))

        for i in range(self.iter_max):
            dx = jnp.linalg.solve((I - df(x0)), (x0 - f(x0)))
            x_nk = x0 - dx
            x1 = f(x_nk)
            
            step_norm = jnp.linalg.norm(x0 - x1)
            
            # Stopping criteria
            if step_norm<self.step_tol:
                out.converged = 'converged'
                break

            x0 = x1.copy()
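
The sketch below applies the plain Newton-Kantorovich step to a hypothetical two-dimensional contraction, chosen only for illustration. Unlike the excerpt above, it does not apply $f$ once more after each step; that simplification is an assumption of this example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
    # Hypothetical contraction on R^2, chosen only for illustration
    return jnp.array([jnp.cos(x[1]), 0.5 * jnp.sin(x[0])])

df = jax.jacobian(f)
I = jnp.eye(2)

x = jnp.zeros(2)  # starting value
for i in range(50):
    dx = jnp.linalg.solve(I - df(x), x - f(x))  # Newton step on x - f(x) = 0
    x = x - dx
    if jnp.linalg.norm(dx) < 1e-10:  # step tolerance
        break

residual = jnp.linalg.norm(x - f(x))
print(i, x, residual)
```

Each step costs a Jacobian evaluation and a linear solve, but far fewer steps are needed than with successive approximations.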

In [None]:
s = solver()

df_jit = jit(jacobian(f))

t_nk, out_nk = s.NewtonKantorovich(f_jit,df_jit,t0_vector)

**Combine successive approximations and Newton-Kantorovich**

Rust (1987) proposed combining successive approximations and Newton-Kantorovich iterations.

Rust, John. “Optimal Replacement of GMC Bus Engines: An Empirical Model of Harold Zurcher.” Econometrica 55, no. 5 (1987): 999–1033. https://doi.org/10.2307/1911259.
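
The idea can be sketched as two phases: a few cheap successive approximation steps to get close to the fixed point, followed by Newton-Kantorovich steps that converge quickly from there. The example below uses a hypothetical contraction and fixed step counts of five per phase; both are assumptions made for illustration, and the sketch omits the switching criteria of the full algorithm.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(x):
    # Hypothetical contraction on R^2, chosen only for illustration
    return jnp.array([jnp.cos(x[1]), 0.5 * jnp.sin(x[0])])

df = jax.jacobian(f)
I = jnp.eye(2)

x = jnp.zeros(2)

# Phase 1: a few successive approximation steps
for _ in range(5):
    x = f(x)
gap_after_sa = jnp.linalg.norm(x - f(x))

# Phase 2: Newton-Kantorovich steps started from the SA iterate
for _ in range(5):
    x = x - jnp.linalg.solve(I - df(x), x - f(x))
gap_after_nk = jnp.linalg.norm(x - f(x))

print(gap_after_sa, gap_after_nk)
```

The printed gaps show how much of the remaining distance to the fixed point the second phase removes.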

In [None]:
s = solver()

t_poly, out_poly = s.Poly(f_jit,df_jit,t0_vector)

In [None]:
s = solver(output='none')

print('time Newton-Raphson')
%time t, out = s.NewtonRaphson(g_jit,dg_jit,t0_vector)
print('time Successive approximations (SA)')
%time t, out = s.SuccessiveApproximation(f_jit,t0_vector)
print('time Newton-Kantorovich (NK)')
%time t, out = s.NewtonKantorovich(f_jit,df_jit,t0_vector)
print('time combining SA and NK')
%time t, out = s.Poly(f_jit,df_jit,t0_vector)

**Recap**

In today's workshop we discussed
<ol>
  <li>Two-sided matching models with transferable utility.</li>
  <li>Implementations of numerical methods for solving non-linear systems of demand and supply equations or fixed point equations.</li>
</ol>

More precisely, we considered the following three numerical methods
<ol>
  <li>Newton-Raphson iterations.</li>
  <li>Successive approximations.</li>
  <li>Newton-Kantorovich iterations.</li>
</ol>