# Implement shooting method to solve for time-path of variables

In [1]:
import numpy as np
from scipy.optimize import fsolve

In [2]:
# Define parameters
rho = 0.01
r = 0.05
a = 1
delta = 0.05
alpha = 0.5
gamma = 0.5
sigma = 0.5
params = (rho, r, a, delta, alpha, gamma, sigma)

# Initial conditions
K_0 = 1
A_0 = 100
mu_0 = 1
p_0 = 1
state_0 = (K_0, A_0, mu_0, p_0)
T = 20

In [3]:
# Define g(x) and g'(x) (labor market equilibrium tradeoff) function
def g(x):
    k = 5/4
    s = 1/2
    a = 1/(k**0.5 - s)
    y = k - (x/a + s)**2
    
    if y <= 0:
        return 0
    
    return y

def g_prime(x):
    k = 5/4
    s = 1/2
    a = 1/(k**0.5 - s)
    y = (-2*(x/a + s))/a
    
    return y

In [4]:
# Define utility functions
def u(c, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    
    if c <= 0:
        return -1e999
    
    return (c**(1-alpha))/(1-alpha)

def u_prime(c, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    
    if c <= 0:
        return 1e999
    
    return c**(-alpha)

def v(l, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    
    if l <= 0:
        return -1e999
    
    return (l**(1-gamma))/(1-gamma)

def v_prime(l, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    
    if l <= 0:
        return 1e999
    
    return l**(-gamma)

def B(A, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    
    if A <= 0:
        return -1e999
    
    return 100*(A**(1-sigma))/(1-sigma)

def B_prime(A, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    
    if A <= 0:
        return 1e999
    
    return 100*A**(-sigma)

In [5]:
# Define functions to compute FOCs
# Consumption FOC
def c_foc(state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    
    return mu**(-1/alpha)

# Leisure FOC
def l_rhs(x, state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    
    rhs = mu*K*g(x) + p*a*K*x
    
    return rhs

# Job choice FOC
def x_lhs(l, x, state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    
    h = 1-l
    lhs = h*(a*p*K + mu*K*g_prime(x))
    
    return lhs

In [6]:
# Define function that checks FOCs
def l_foc(l, x, state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    
    lhs = v_prime(l, params)
    rhs = l_rhs(x, state, params)
    diff = lhs - rhs
    
    return diff

def x_foc(l, x, state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    
    diff = x_lhs(l, x, state, params) - 0
    
    return diff

In [7]:
# Combine both FOC checkers within one function
def lx_foc(guess, state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    l, x = guess
    
    res1 = l_foc(l, x, state, params) # Check FOC for l
    res2 = x_foc(l, x, state, params) # Check FOC for x
    
    error = np.array([res1, res2])
    
    return error

In [8]:
# Define a function that solves for l and x for a given mu and p from FOCs
def foc_solve(state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    
    # Find c first
    c = c_foc(state, params)
    
    # Try looking for interior solutions first
    [l, x] = fsolve(lx_foc, [0.5, 0.5], args = (state, params))
    
    # Check if solutions are interior
    if 0 < l < 1 and 0 < x < 1:
        #print("Interior solution obtained")
        return np.array([l, x, c])
    
    # If not interior, try boundary solutions for x
    # Try x = 0
    x = 0
    [l] = fsolve(l_foc, [0.5], args = (x, state, params))
    # Check if l is valid
    if 0 < l < 1:
        # Check FOC for x holds
        check = x_lhs(l, x, state, params)
        if check <= 0:
            #print("x = 0 is a valid boundary solution")
            return np.array([l, x, c])
        
    # If l is not valid, try x = 1
    x = 1
    [l] = fsolve(l_foc, [0.5], args = (x, state, params))
    if 0 < l < 1:
        # Check FOC for x holds
        check = x_lhs(l, x, state, params)
        if check >= 0:
            #print("x = 1 is a valid boundary solution")
            return np.array([l, x, c])
    
    # If neither x = 0 nor x = 1 are valid, check the solution l = 1 (x can be anything)
    l = 1
    x = 1
    check = l_foc(l, x, state, params)
    if check >= 0:
        #print("l = 1 is a valid boundary solution")
        return np.array([l, x, c])
    else:
        print("No valid solution???")

In [9]:
# Define functions that produce state tomorrow given today's solution
def state_iterate(choice, state, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    K, A, mu, p = state # Unpack state vector
    l, x, c = choice
    
    h = 1 - l
    mu_1 = (1 + rho - r)*mu
    p_1 = p*(1 + rho + delta - a*x*h) - g(x)*h*mu
    A_1 = (1 + r)*A + g(x)*h*K - c
    K_1 = (1 + a*x*h - delta)*K
    
    return np.array([K_1, A_1, mu_1, p_1])

In [10]:
def iterate(state_0, T, params):
    # Define arrays to store results
    choice_path = np.zeros((3, T))
    state_path = np.zeros((4, T+1))

    state_path[:, 0] = state_0

    # Try iterating forward in time given initial conditions
    for t in range(T):
        choice_path[:, t] = foc_solve(state_path[:, t], params) # Find choice today given state today
        state_path[:, t+1] = state_iterate(choice_path[:, t], state_path[:, t], params) # Given choice today, find next state
        
    return (choice_path, state_path)

In [11]:
def terminal_condition(guess, K_0, A_0, T, params):
    rho, r, a, delta, alpha, gamma, sigma = params # Unpack parameters
    mu, p = guess # Unpack guess
    state_0 = [K_0, A_0, mu, p]
    choice_path, state_path = iterate(state_0, T, params) # Iterate forward to get terminal state
    
    K_T, A_T, mu_T, p_T = state_path[:, -2] # Unpack terminal state
    
    # Evaluate terminal condition
    err1 = K_T*p_T
    err2 = mu_T - B_prime(A_T, params)
    
    print(f"err1 = {err1}, err2 = {err2}")
    
    return np.array([err1, err2])

In [12]:
[mu_sol, p_sol] = fsolve(terminal_condition, [mu_0, p_0], args=(K_0, A_0, T, params), maxfev=999999, xtol = 1e-16)

err1 = 0.9601512402143622, err2 = -6.782379622711732
err1 = 0.9601512402143622, err2 = -6.782379622711732
err1 = 0.9601512402143622, err2 = -6.782379622711732
err1 = 0.960151209615314, err2 = -6.782379579406219
err1 = 0.9601512377788669, err2 = -6.782379621994645
err1 = -37.96629168963486, err2 = -4.544880762506606
err1 = -7.691292362989523, err2 = -5.737706565973599
err1 = 0.960151209615314, err2 = -6.782379579406219
err1 = 0.9601512377788669, err2 = -6.782379621994645
err1 = -37.96629168963486, err2 = -4.544880762506606
err1 = -7.691292362989523, err2 = -5.737706565973599
err1 = -6.330482223745463, err2 = -6.372854727133106
err1 = -6.20706508788248, err2 = -6.1878215555799025
err1 = -3.3313339946195235, err2 = -6.630596645048826
err1 = -0.8338698306609915, err2 = -6.517011217992589
err1 = -8.859060600112448, err2 = -6.1268164471028665
err1 = 0.3031069005404361, err2 = -6.475795144022659
err1 = -3.0485412601627826, err2 = -6.22833781412556
err1 = -1.6133813355181599, err2 = -6.3656573

In [13]:
state_sol = [K_0, A_0, mu_sol, p_sol]

In [14]:
[choice_path, state_path] = iterate(state_sol, T, params)

In [15]:
choice_path[0,:]

array([0.5       , 1.        , 1.        , 0.27490923, 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.99830117, 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ])

In [16]:
choice_path[1,:]

array([0.5       , 0.8835886 , 0.86985741, 0.        , 0.38839046,
       0.33365981, 0.27494991, 0.21221733, 0.14549367, 0.07490632,
       0.        , 1.        , 1.        , 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ])

In [17]:
choice_path[2,:]

array([0.25239086, 0.27386161, 0.29715887, 0.32243801, 0.34986763,
       0.37963067, 0.41192564, 0.44696793, 0.48499124, 0.52624918,
       0.57101691, 0.619593  , 0.67230143, 0.72949374, 0.79155136,
       0.8588882 , 0.93195334, 1.01123409, 1.09725922, 1.19060245])

In [18]:
state_path[0,:]

array([1.        , 1.2       , 1.14      , 1.083     , 1.02885   ,
       0.9774075 , 0.92853713, 0.88211027, 0.83800476, 0.79610452,
       0.75629929, 0.71848433, 0.68256011, 0.64843211, 0.6160105 ,
       0.58520997, 0.55594948, 0.528152  , 0.5017444 , 0.47665718,
       0.45282432])

In [19]:
state_path[1,:]

array([100.        , 105.04535489, 110.02376102, 115.22779021,
       121.45201502, 127.17474814, 133.15385487, 139.39962197,
       145.92263514, 152.73377565, 159.84421525, 167.26669394,
       175.01043564, 183.08865599, 191.51359505, 200.29772344,
       209.45372142, 218.99445415, 228.93294276, 239.28233068,
       250.05584477])

# Perform grid search instead
n = 50
mu_grid = np.linspace(-100, 100, num = n)
p_grid = np.linspace(-100, 100, num = n)
err1_grid = np.zeros((n, n))
err2_grid = np.zeros((n, n))

for i, mu in np.ndenumerate(mu_grid):
    for j, p in np.ndenumerate(p_grid):
        err1_grid[i, j], err2_grid[i, j] = terminal_condition([mu, p], K_0, A_0, T, params)
        print(f"i = {i}, j = {j}")

err1_grid_abs = np.absolute(err1_grid)
err2_grid_abs = np.absolute(err2_grid)
err1_ind = np.unravel_index(np.argmin(err1_grid_abs), err1_grid.shape) 
err2_ind = np.unravel_index(np.argmin(err2_grid_abs), err2_grid.shape) 