In [None]:
import jax
import jax.numpy as jnp
import sys
import decimal
decimal.getcontext().prec = 5

In [None]:
# Model Parameters
# Each note below records how the constant is used by the code in this notebook.
R = 1.05  # multiplies outstanding debt b in the budget identity c = b' - R*b + y
BETA = 0.945  # multiplies R * E[c'**(-GAMMA_C)] on the right-hand side of the Euler equation
RHO = 0.9  # passed to rowenhorst() as rho (1st order autocorrelation)
STD_U = 0.010  # passed to rowenhorst() as sigma_eps (std. dev. of the error term)
M = 1  # scales income in the borrowing limit b' <= M * y and in the debt-grid bounds
GAMMA_C = 1  # exponent applied to consumption: c**(-GAMMA_C)

## Functions

In [None]:
def rowenhorst(rho, sigma_eps, n):
    '''
    Discretize an AR(1) process with the Rouwenhorst method.

    rho : 1st order autocorrelation
    sigma_eps : Standard Deviation of the error term
    n : Number of points in the discrete approximation (n >= 2)

    Returns
    z_grid : (1, n) array of evenly spaced nodes centred on mu_eps/(1-rho)
    P : (n, n) transition matrix whose rows each sum to one
    '''
    mu_eps = 0
    q = (rho+1) / 2
    nu = ((n-1)/(1-rho**2))**(1/2) * sigma_eps
    # Kept as a (1, n) row so the grid can be indexed as z_grid[0, i].
    centre = mu_eps/(1-rho)
    z_grid = jnp.linspace(centre-nu, centre+nu, n).reshape(1,-1)

    # Start from the 2-state chain and grow it one state at a time.
    P = jnp.array(((q, 1-q), (1-q, q)))
    for _ in range(2, n):
        P = q * jnp.pad(P, ((0,1),(0,1)), constant_values=0) + \
            q * jnp.pad(P, ((1,0),(1,0)), constant_values=0) + \
            (1-q) * jnp.pad(P, ((0,1),(1,0)), constant_values=0) + \
            (1-q) * jnp.pad(P, ((1,0),(0,1)), constant_values=0)
        # Every interior row receives mass from two overlapping copies of the
        # previous matrix and therefore sums to 2; halve ALL interior rows
        # (not only row 1) so each row is a probability distribution again.
        P = P.at[1:-1, :].divide(2)

    return z_grid, P

In [None]:
def all_inter_211(x1, x2, x1i, x2i, pf1):
    '''
    Bilinear interpolation/extrapolation (2 states, 1 policy, 1 stochastic comp)

    x1, x2 : grids of the two states, each a (1, n) array in increasing order
             (the spacing does not have to be uniform)
    x1i : scalar point of the first state at which to evaluate
    x2i : sequence of points of the second state at which to evaluate
    pf1 : policy function on the grid, indexed as pf1[i1, i2]

    Returns a (len(x2i), 1) array with the interpolated policy at each
    (x1i, x2i[k]). Points outside the grid are extrapolated linearly from the
    nearest boundary cell.
    '''
    # Grid lengths
    nx1 = x1.shape[1]
    nx2 = x2.shape[1]
    # Number of Stochastic Realizations
    x2i_pts = len(x2i)

    def _left_node(grid_row, point, n_pts):
        # 1-based index of the left node of the cell bracketing `point`,
        # clipped to [1, n_pts-1] so out-of-range points reuse the boundary
        # cell. searchsorted gives the same cell as floor((x-x0)/step)+1 on a
        # uniform grid and stays correct when the spacing is not uniform.
        idx = int(jnp.searchsorted(grid_row, point, side='right'))
        return min(n_pts - 1, max(1, idx))

    # The first state is the same for every realization: locate it once.
    loc1 = _left_node(x1[0], x1i, nx1)
    x1_lo = x1[0, loc1 - 1]
    x1_hi = x1[0, loc1]
    w1_hi = (x1i - x1_lo) / (x1_hi - x1_lo)
    w1 = (1 - w1_hi, w1_hi)

    # Preallocate Output
    o1 = jnp.zeros((x2i_pts, 1))
    for i2 in range(x2i_pts):
        loc2 = _left_node(x2[0], x2i[i2], nx2)
        x2_lo = x2[0, loc2 - 1]
        x2_hi = x2[0, loc2]
        w2_hi = (x2i[i2] - x2_lo) / (x2_hi - x2_lo)
        w2 = (1 - w2_hi, w2_hi)

        # Weighted sum over the four corners of the bracketing cell.
        value = 0.0
        for m2 in range(2):
            for m1 in range(2):
                value = value + w1[m1] * w2[m2] * pf1[loc1 + m1 - 1, loc2 + m2 - 1]
        o1 = o1.at[i2, 0].set(value)
    return o1

In [None]:
def solve_fiPIT(b_in, p_in, z_in, b_dec, n_in):
    '''
    One time-iteration update of the debt decision rule.

    b_in : debt grid, 1-D of length nb or a (1, nb) row
    p_in : (nz, nz) transition matrix; row iz is the distribution of next
           period's income state given current state iz
    z_in : income grid, 1-D of length nz or a (1, nz) row
    b_dec : (nz, nb) current guess of the debt decision rule b'(z, b)
    n_in : parameters ordered as (std_u, rho, gamma_c, r, m, bet)

    Returns the (nz, nb) updated decision rule, capped by the borrowing limit
    b' <= m * z. No damping is applied here; that is left to the caller.
    '''
    # n_in[0] (std_u) and n_in[1] (rho) are part of the parameter vector but
    # are not needed for the update itself.
    gamma_c = n_in[2]
    r = n_in[3]
    m = n_in[4]
    bet = n_in[5]

    # all_inter_211 expects (1, n) row grids; accept 1-D input as well.
    b_row = jnp.reshape(jnp.asarray(b_in), (1, -1))
    z_row = jnp.reshape(jnp.asarray(z_in), (1, -1))
    nz, nb = b_dec.shape

    # Consumption implied by the current guess, c = b' - r*b + z, shape (nz, nb).
    c_dec = b_dec - r * b_row + z_row.T
    # The interpolator indexes the policy as pf1[i_b, i_z].
    pf_c = c_dec.T
    # Next-period income nodes over which the expectation is taken.
    z_nodes = [z_row[0, iq] for iq in range(nz)]

    # JAX arrays are immutable: results are written with .at[...].set(...).
    b_dec_new = jnp.zeros((nz, nb))
    for iz in range(nz):
        z_use = z_row[0, iz]
        for ib in range(nb):
            b_use = b_row[0, ib]

            # Next-period consumption at the chosen debt b'(z, b) for every
            # income realization; shape (nz, 1).
            # NOTE(review): as in the original call, the interpolation point
            # is b_dec[iz, ib] itself (not floored at the grid minimum), so
            # choices below the grid are extrapolated — confirm intended.
            c_next = all_inter_211(
                x1 = b_row,
                x2 = z_row,
                x1i = b_dec[iz, ib],
                x2i = z_nodes,
                pf1 = pf_c,
            )
            # Marginal utility, with consumption floored to keep it finite.
            int_ct_prime = jnp.maximum(c_next, 1e-20) ** (-gamma_c)

            # Expected marginal utility given the current income state.
            sol_ct2 = jnp.dot(p_in[iz, :], int_ct_prime[:, 0])
            c_new = (bet * r * sol_ct2) ** (-1 / gamma_c)
            # Back out b' from the budget identity, then impose b' <= m * z.
            b_new = jnp.minimum(c_new - z_use + r * b_use, m * z_use)
            b_dec_new = b_dec_new.at[iz, ib].set(b_new)

    return b_dec_new

In [None]:
# Sizes of the state grids
nz = 11   # income nodes, in terms of z=log(Y)
nb = 200  # debt nodes

# Discretized income process: log-income nodes and their transition matrix
log_y, p = rowenhorst(rho=RHO, sigma_eps=STD_U, n=nz)
z_grid = log_y
y_grid = jnp.exp(log_y)

# Debt grid, stored as a (1, nb) row
b_min = 0.75 * M
b_max = M * y_grid[0,-1]
b_grid = jnp.linspace(b_min, b_max, nb).reshape(1,-1)

# Initial guess on the (nz, nb) state space.
# 'ij' indexing yields y_m[i, j] = y_i and b_m[i, j] = b_j directly.
y_m, b_m = jnp.meshgrid(jnp.ravel(y_grid), jnp.ravel(b_grid), indexing='ij')
c_dec_old = jnp.maximum(-R * b_m + (1 + M) * y_m, 1e-100)
b_dec_old = c_dec_old + R * b_m - y_m
b_dec_use = b_dec_old

In [None]:
# Time iteration on the debt decision rule b'(z, b), starting from b_dec_use.
count = 1
dist = 100
tol = 1e-10
# JAX computes in float32 unless x64 is enabled, so a 1e-10 tolerance on the
# norm may never be reached; cap the number of passes instead of spinning.
max_iter = 5000

# Everything below is constant across passes, so it is set up once.
b_in = b_grid
p_in = p
z_in = y_grid
n_in = jnp.array(((STD_U),(RHO),(GAMMA_C),(R),(M),(BETA)))
std_u = n_in[0]
rho = n_in[1]
gamma_c = n_in[2]
r = n_in[3]
m = n_in[4]
bet = n_in[5]

nz, nb = b_dec_use.shape
b_min = b_in[0,0]
b_max = b_in[0,nb-1]
# Next-period income nodes over which the expectation is taken.
z_nodes = [z_in[0,iq] for iq in range(nz)]

while dist > tol and count <= max_iter:
    b_dec = b_dec_use
    # Consumption implied by the current guess, c = b' - r*b + z, shape (nz, nb).
    c_dec = b_dec - r * b_in + z_in.T
    pf_c = c_dec.T  # the interpolator indexes the policy as pf1[i_b, i_z]

    b_dec_new = jnp.zeros((nz, nb))
    for iz in range(nz):
        z_use = z_in[0,iz]
        for ib in range(nb):
            b_use = b_in[0,ib]
            # Next-period consumption at the debt chosen in the CURRENT state
            # (iz, ib), for every income realization; shape (nz, 1).
            c_next = all_inter_211(
                x1 = b_in,
                x2 = z_in,
                x1i = b_dec[iz,ib],
                x2i = z_nodes,
                pf1 = pf_c,
            )
            # Marginal utility, with consumption floored to keep it finite.
            int_ct_prime = (jnp.maximum(c_next, 1e-20) ** (-gamma_c)).reshape(1,-1)
            # Expected marginal utility given the current income state.
            sol_ct2 = jnp.matmul(p_in[iz,:].reshape(1,-1), int_ct_prime.T)[0,0]
            c_new = (bet * r * sol_ct2) ** (-1/gamma_c)
            # Back out b' from the budget identity, then impose b' <= m * z.
            b_new = jnp.minimum(c_new - z_use + r * b_use, m * z_use)
            b_dec_new = b_dec_new.at[(iz,ib)].set(b_new)

    # Damped update: move a quarter of the way towards the new rule.
    b_dec_new = b_dec_new * 0.25 + b_dec_use * (1-0.25)
    dist = jnp.linalg.norm(abs(b_dec_new - b_dec_use))
    if count%10 == 0:
        print(f"\n {count}, dist = {dist:.4f}")
    b_dec_use = b_dec_new
    count = count + 1

if dist > tol:
    print(f"Stopped after {count - 1} passes without reaching tol = {tol}")

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# %%timeit -r 3 -n 3
# Solve Model Time Iteration

# Initial guess on the (nz, nb) state space. jnp.meshgrid only accepts 1-D
# inputs, so the (1, n) grids are flattened; 'ij' indexing gives
# y_m[i, j] = y_i and b_m[i, j] = b_j.
y_m, b_m = jnp.meshgrid(jnp.ravel(y_grid), jnp.ravel(b_grid), indexing='ij')
c_dec_old = jnp.maximum(1e-100, -R * b_m + (1 + M) * y_m)
b_dec_old = c_dec_old + R * b_m - y_m
b_dec_use = b_dec_old

count = 1
dist = 100
tol = 1e-10
# JAX computes in float32 unless x64 is enabled, so a 1e-10 tolerance on the
# norm may never be reached; cap the number of passes instead of spinning.
max_iter = 5000

# Everything below is constant across passes, so it is set up once.
b_in = b_grid
p_in = p
z_in = y_grid
n_in = jnp.array(((STD_U),(RHO),(GAMMA_C),(R),(M),(BETA)))
std_u = n_in[0]
rho = n_in[1]
gamma_c = n_in[2]
r = n_in[3]
m = n_in[4]
bet = n_in[5]

nz, nb = b_dec_use.shape
b_min = b_in[0,0]
b_max = b_in[0,nb-1]
# Next-period income nodes over which the expectation is taken.
z_nodes = [z_in[0,iq] for iq in range(nz)]

while dist > tol and count <= max_iter:
    b_dec = b_dec_use
    # Consumption implied by the current guess, c = b' - r*b + z, shape (nz, nb).
    c_dec = b_dec - r * b_m + y_m
    pf1 = c_dec.T  # the interpolator indexes the policy as pf1[i_b, i_z]

    # Undamped update of the decision rule.
    b_dec_up = jnp.zeros((nz, nb))
    for iz in range(nz):
        z_use = z_in[0,iz]
        for ib in range(nb):
            b_use = b_in[0,ib]
            # Next-period consumption at the debt chosen in the current state
            # (iz, ib), for every income realization; shape (nz, 1).
            o1 = all_inter_211(
                x1 = b_in,
                x2 = z_in,
                x1i = b_dec[iz,ib],
                x2i = z_nodes,
                pf1 = pf1,
            )
            # Marginal utility, with consumption floored to keep it finite.
            int_ct_prime = jnp.maximum(o1, 1e-20) ** (-gamma_c)
            # Expected marginal utility given the current income state.
            sol_ct2 = jnp.dot(p_in[iz,:], int_ct_prime[:,0])
            c_new = (bet * r * sol_ct2) ** (-1/gamma_c)
            # Back out b' from the budget identity, then impose b' <= m * z.
            b_new = jnp.minimum(c_new - z_use + r * b_use, m * z_use)
            b_dec_up = b_dec_up.at[(iz,ib)].set(b_new)

    # Damped step; the distance is measured on the undamped update.
    b_dec_new = b_dec_up * 0.25 + b_dec_use * (1-0.25)
    dist = jnp.linalg.norm(abs(b_dec_up - b_dec_use))
    if count%10 == 0:
        print(f"\n {count}, dist = {dist:.4f}")
    b_dec_use = b_dec_new
    count = count + 1

if dist > tol:
    print(f"Stopped after {count - 1} passes without reaching tol = {tol}")

In [None]:
# Debug: display whatever `o1` currently holds in the kernel namespace
# (it is assigned by an earlier cell; this cell does not define it).
o1

In [None]:
# Debug: print row `iz` of the transition matrix and the transposed
# int_ct_prime, then display their matrix product.
# Relies on `p_in`, `iz` and `int_ct_prime` left over from an earlier cell.
print(p_in[iz,:])
print(int_ct_prime.T)
jnp.matmul(p_in[iz,:], int_ct_prime)

In [None]:
# Debug: print the first element of each row of int_ct_prime with 10 decimals.
import decimal
# NOTE(review): this sets the precision of the decimal module's context; the
# format() call below does not appear to depend on it — confirm it is needed.
decimal.getcontext().prec = 10
for i in int_ct_prime:
    print(format(i[0], '.10f'))

In [None]:
# NOTE(review): scratch cell. The inner loop had no body, which made the whole
# cell a SyntaxError; a placeholder keeps the notebook parseable.
# `x2i_pts` must already exist in the kernel namespace.
o1 = jnp.empty((x2i_pts,1))
for m2 in range(2):
    for m1 in range(2):
        pass  # TODO: loop body was never written — presumably the corner accumulation into o1


In [None]:
# Store decision rules
kb_ti = b_dec_use
# Consumption rule from the budget identity c = b' - R*b + y. The (1, nb) debt
# row and the (nz, 1) income column broadcast to the (nz, nb) state space.
kc_ti = kb_ti - R * b_grid + y_grid.T

# Compute decision rule for multiplier
b_dec = kb_ti
c_dec = kc_ti
nz, nb = b_dec.shape

# Lowest debt node; chosen debt is floored here when computing current consumption.
b_floor = b_grid[0, 0]
# Next-period income nodes over which the expectation is taken.
y_nodes = [y_grid[0, iq] for iq in range(nz)]
pf_c = c_dec.T  # the interpolator indexes the policy as pf1[i_b, i_z]

# JAX arrays are immutable: results are written with .at[...].set(...).
l_dec = jnp.zeros((nz, nb))
for iz in range(nz):
    z_use = y_grid[0, iz]
    for ib in range(nb):
        b_use = b_grid[0, ib]

        # Current consumption, with the chosen debt floored at the grid minimum.
        b_prime = jnp.maximum(b_dec[iz,ib], b_floor)
        c_use = b_prime - R * b_use + z_use

        # Next-period consumption for every income realization; shape (nz, 1).
        # NOTE(review): as in the original call, the interpolation point is
        # b_dec[iz, ib] rather than the floored b_prime — confirm intended.
        c_next = all_inter_211(
            x1 = b_grid,
            x2 = y_grid,
            x1i = b_dec[iz,ib],
            x2i = y_nodes,
            pf1 = pf_c,
            )
        # Marginal utility, with consumption floored to keep it finite.
        int_ct_prime = jnp.maximum(c_next, 1e-20) ** (-GAMMA_C)

        # Expected marginal utility given the current income state.
        sol_ct2 = jnp.dot(p[iz,:], int_ct_prime[:, 0])
        # Multiplier: the non-negative gap in the Euler equation.
        gap = c_use ** (-GAMMA_C) - (BETA*R*sol_ct2)
        l_dec = l_dec.at[iz,ib].set(jnp.maximum(0, gap))