# Krusell–Smith (1998) Model with SRL

Heterogeneous-agent economy: households as in Huggett, but **wealth is productive capital** rented to a representative firm. Production $Y_t = z_t K_t^\alpha L_t^{1-\alpha}$; $L_t=1$ (inelastic). Factor prices: $r_t^K = \alpha Y_t/K_t$, $w_t = (1-\alpha)Y_t/L_t$; net return $r_t = r_t^K - \delta$.

**Individual state:** $(b, y)$ — capital and idiosyncratic labor productivity.  
**Aggregate:** $K_t = \int b\, dG_t(b,y)$; prices $(r_t, w_t)$ from $(K_t, z_t)$.  
**Budget:** $c + b' = (1+r)b + w\,y$ (each household supplies one unit of labor inelastically), $b \geq 0$.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Seed NumPy's global RNG so the np.random draws below are reproducible.
np.random.seed(42)
# Matplotlib defaults used by the figures in this notebook.
plt.rcParams['font.size'] = 10
plt.rcParams['axes.unicode_minus'] = False

## Calibration (SRL Section 4.2 & Appendix A.2, Tables 4 & 5)

- **Preferences:** $\beta=0.95$, CRRA $\sigma=3$.
- **Production:** capital share $\alpha=0.36$, depreciation $\delta=0.08$.
- **Idiosyncratic income** $y$: same as Huggett (log AR(1), $\rho_y=0.6$, $\nu_y=0.2$), 3 states.
- **Aggregate TFP** $z$: log AR(1), $\rho_z=0.9$, $\nu_z=0.03$.
- **Capital:** $b\geq 0$, grid $n_b=200$, $b_{\max}=100$.

In [None]:
# Preferences: discount factor and CRRA coefficient
beta = 0.95
sigma = 3.0

# Production: capital share and depreciation rate
alpha = 0.36
delta = 0.08

# Idiosyncratic income (same as Huggett): log AR(1) persistence,
# innovation std. dev., and number of discretized states
rho_y = 0.6
nu_y = 0.2
ny = 3

# Aggregate TFP: log AR(1) persistence and innovation std. dev.
rho_z = 0.9
nu_z = 0.03

# Capital grid bounds (b_min = 0 is the borrowing limit b >= 0) and size
b_min = 0.0
b_max = 100.0
nb = 200

# Price grids for policy (Table 5): net return r and wage w.
# nz is the number of discretized TFP states passed to tauchen_ar1 below.
nr = 30
r_min, r_max = 0.02, 0.07
nw = 50
w_min, w_max = 0.9, 1.5
nz = 30

# Consumption floor applied inside u, u_prime, and the policy functions
c_min = 1e-3
# NOTE(review): T_trunc and etrunc are not referenced anywhere in the visible
# code — presumably truncation settings from the paper; confirm or remove.
T_trunc = 90
etrunc = 1e-2

print("Krusell-Smith calibration (SRL 4.2, App A.2):")
print(f"  β={beta}, σ={sigma}, α={alpha}, δ={delta}")
print(f"  ρy={rho_y}, νy={nu_y}, ρz={rho_z}, νz={nu_z}")
print(f"  b∈[0,{b_max}], nb={nb}; r∈[{r_min},{r_max}], nr={nr}; w∈[{w_min},{w_max}], nw={nw}")

In [None]:
def tauchen_ar1(rho, sigma_innov, n_states, m=3):
    """Discretize a zero-mean AR(1) process with Tauchen's method.

    Args:
        rho: autoregressive coefficient.
        sigma_innov: standard deviation of the innovation.
        n_states: number of grid points.
        m: half-width of the grid in units of the stationary std. dev.

    Returns:
        (x_grid, P): the evenly spaced grid on [-m*sd, m*sd] and the
        (n_states, n_states) transition matrix with rows normalized to one.
    """
    stationary_sd = sigma_innov / np.sqrt(1 - rho**2)
    lower, upper = -m * stationary_sd, m * stationary_sd
    x_grid = np.linspace(lower, upper, n_states)
    half_width = ((upper - lower) / (n_states - 1)) / 2

    # gap[i, j] = x_j - rho * x_i: distance from the conditional mean to node j.
    gap = x_grid[np.newaxis, :] - rho * x_grid[:, np.newaxis]
    cdf_right = norm.cdf((gap + half_width) / sigma_innov)
    cdf_left = norm.cdf((gap - half_width) / sigma_innov)

    # Interior nodes get the mass of their bin; the two end nodes absorb the tails.
    P = cdf_right - cdf_left
    P[:, 0] = cdf_right[:, 0]
    P[:, -1] = 1 - cdf_left[:, -1]

    row_mass = P.sum(axis=1, keepdims=True)
    return x_grid, P / row_mass

# Idiosyncratic income: discretize log y, exponentiate, then rescale so that
# the mean of y under the invariant distribution equals one.
log_y_grid, Ty = tauchen_ar1(rho_y, nu_y, ny)
y_grid = np.exp(log_y_grid)
# Invariant distribution approximated by a high power of the transposed
# transition matrix; column 0 of (Ty.T)^200 is used as the stationary weights.
invariant_y = np.linalg.matrix_power(Ty.T, 200)[:, 0]
y_grid = y_grid / (y_grid @ invariant_y)

# Aggregate TFP: same construction, normalized to mean one.
log_z_grid, Tz = tauchen_ar1(rho_z, nu_z, nz)
z_grid = np.exp(log_z_grid)
invariant_z = np.linalg.matrix_power(Tz.T, 200)[:, 0]
z_grid = z_grid / (z_grid @ invariant_z)

# Evenly spaced grids for capital and for the prices (r, w).
b_grid = np.linspace(b_min, b_max, nb)
r_grid = np.linspace(r_min, r_max, nr)
w_grid = np.linspace(w_min, w_max, nw)

def u(c):
    """CRRA utility c^(1-sigma)/(1-sigma), with c floored at c_min first."""
    curvature = 1 - sigma
    floored = np.maximum(c, c_min)
    return (floored ** curvature) / curvature

def u_prime(c):
    """Marginal CRRA utility c^(-sigma), with c floored at c_min first."""
    floored = np.maximum(c, c_min)
    return floored ** (-sigma)

# Quick look at the normalized income grid and the range of the TFP grid.
print("y_grid:", y_grid)
print("z_grid (first/last):", z_grid[0], z_grid[-1])

## Production and prices

Given aggregate capital $K$ and TFP $z$: $r^K = \alpha z K^{\alpha-1}$, $w = (1-\alpha)z K^\alpha$ (with $L=1$). Net return $r = r^K - \delta$.

In [None]:
def K_to_prices(K, z):
    """Map aggregate capital and TFP to factor prices, with labor fixed at L=1.

    Args:
        K: aggregate capital (floored at 1e-8 before use).
        z: TFP level.

    Returns:
        (r_net, w): net return alpha*z*K^(alpha-1) - delta and wage
        (1-alpha)*z*K^alpha.
    """
    capital = np.maximum(K, 1e-8)  # keep K^(alpha-1) finite at K = 0
    wage = (1 - alpha) * z * (capital ** alpha)
    rental_rate = alpha * z * (capital ** (alpha - 1))
    return rental_rate - delta, wage

# Example: steady-state K from β: 1 = β(1 + r) => r_ss = 1/beta - 1.
# K_ss then inverts r_ss + delta = alpha * z_ss * K^(alpha-1).
# NOTE(review): this is the benchmark implied by β(1+r)=1; the notebook uses
# K_ss only as a reference level / default initial capital.
r_ss = 1.0 / beta - 1
z_ss = 1.0
K_ss = (alpha * z_ss / (r_ss + delta)) ** (1 / (1 - alpha))
# Round-trip through K_to_prices to obtain the matching wage.
r_ss_check, w_ss_check = K_to_prices(K_ss, z_ss)
print(f"Steady state (z=1): r_ss={r_ss:.4f}, K_ss={K_ss:.4f}, w_ss={w_ss_check:.4f}")

## Policy and simulation (SRL: price-based)

Policy $c(b,y,r,w,z)$; in GE, $(r_t, w_t) = P^*(K_t, z_t)$ with $K_t = \int b\, dG_t$. We **stop-gradient** on $(r,w)$ so agents take prices as given.

In [None]:
def policy_placeholder(b, y, r, w, z, save_frac=0.25):
    """Rule-of-thumb policy: save a fixed share of cash on hand.

    Cash on hand is (1+r)*b + w*y. Savings are save_frac of that, kept inside
    [b_min, b_max]; consumption is the remainder, floored at c_min. `z` is
    accepted for interface compatibility and not used.

    Returns:
        (c, b_next)
    """
    resources = (1 + r) * b + w * y
    target_savings = save_frac * resources
    savings = np.minimum(np.maximum(target_savings, b_min), b_max)
    consumption = np.maximum(resources - savings, c_min)
    return consumption, savings

def aggregate_capital(b_vec, y_vec):
    """Return aggregate capital as the cross-sectional mean of holdings.

    `y_vec` is accepted but not used.
    """
    mean_holdings = np.mean(b_vec)
    return mean_holdings

def simulate_krusell_smith(T, N_agents, policy_fn, y_grid, z_grid, Ty, Tz, b0=None, z0_idx=None):
    """Simulate the Krusell-Smith economy for a panel of agents.

    Each period: K_t = mean(b_t), (r_t, w_t) = K_to_prices(K_t, z_t), then
    policy_fn(b, y, r, w, z) -> (c, b_next); finally the idiosyncratic and
    aggregate Markov states are advanced.

    Args:
        T: number of transitions; paths have T+1 entries along time.
        N_agents: number of simulated households.
        policy_fn: callable (b, y, r, w, z) -> (c, b_next), vectorized over agents.
        y_grid, z_grid: arrays of income / TFP levels.
        Ty, Tz: row-stochastic transition matrices for the y and z indices.
        b0: optional initial holdings (length N_agents). Defaults to every
            agent holding K_ss.
        z0_idx: optional initial TFP index; drawn uniformly when None.

    Returns:
        (paths, v_hat): `paths` is a dict with keys 'b', 'y', 'c' of shape
        (T+1, N_agents) and 'z', 'r', 'w', 'K' of shape (T+1,); `v_hat` is
        sum_{t<T} beta^t * mean(u(c_t)).
    """
    y_grid = np.asarray(y_grid)
    z_grid = np.asarray(z_grid)
    n_y, n_z = len(y_grid), len(z_grid)

    if z0_idx is None:
        z0_idx = np.random.randint(0, n_z)
    y_idx = np.random.randint(0, n_y, size=N_agents)
    if b0 is None:
        # K_t is the cross-sectional MEAN of b, so starting the economy at
        # K_ss means every agent holds K_ss (not K_ss / N_agents).
        b = np.full(N_agents, K_ss)
    else:
        b = np.array(b0, dtype=float)
    z_idx = int(z0_idx)

    # Cumulative transition rows, computed once for inverse-CDF sampling.
    Ty_cum = np.cumsum(np.asarray(Ty), axis=1)
    Tz_cum = np.cumsum(np.asarray(Tz), axis=1)

    paths = {'b': np.zeros((T+1, N_agents)), 'y': np.zeros((T+1, N_agents)), 'z': np.zeros(T+1),
             'r': np.zeros(T+1), 'w': np.zeros(T+1), 'K': np.zeros(T+1), 'c': np.zeros((T+1, N_agents))}
    paths['b'][0], paths['y'][0], paths['z'][0] = b, y_grid[y_idx], z_grid[z_idx]

    for t in range(T):
        K_t = np.mean(b)
        z_t = z_grid[z_idx]
        r_t, w_t = K_to_prices(K_t, z_t)
        c_t, b_next = policy_fn(b, y_grid[y_idx], r_t, w_t, z_t)
        paths['r'][t], paths['w'][t], paths['K'][t] = r_t, w_t, K_t
        paths['c'][t] = c_t
        b = b_next
        paths['b'][t+1] = b

        # Advance y for all agents at once: the new index is the number of
        # cumulative probabilities strictly below the draw (what
        # searchsorted returns), capped so rounding in the cumulative sum
        # cannot produce an out-of-range index.
        shocks = np.random.rand(N_agents)
        y_idx = np.minimum((Ty_cum[y_idx] < shocks[:, None]).sum(axis=1), n_y - 1)
        z_idx = min(int(np.searchsorted(Tz_cum[z_idx], np.random.rand())), n_z - 1)

        # Record the states that actually prevail in period t+1 (i.e. after
        # the transition draw), so paths['y'][t+1] lines up with paths['b'][t+1].
        paths['y'][t+1] = y_grid[y_idx]
        paths['z'][t+1] = z_grid[z_idx]

    # Terminal period: prices from the final distribution, consumption from
    # the same policy that generated the rest of the path.
    paths['K'][T] = np.mean(b)
    paths['r'][T], paths['w'][T] = K_to_prices(paths['K'][T], paths['z'][T])
    paths['c'][T] = policy_fn(b, paths['y'][T], paths['r'][T], paths['w'][T], paths['z'][T])[0]

    v_hat = np.sum([(beta**t) * np.mean(u(paths['c'][t])) for t in range(T)])
    return paths, v_hat

In [None]:
# Baseline run: fixed-savings-rate placeholder policy, 2000 agents, 200 periods.
policy_fn = lambda b, y, r, w, z: policy_placeholder(b, y, r, w, z, save_frac=0.25)
T_sim, N_agents = 200, 2000
paths, v_hat = simulate_krusell_smith(T_sim, N_agents, policy_fn, y_grid, z_grid, Ty, Tz)
# Summary statistics are time averages over all T_sim + 1 recorded periods.
print("Simulation done. v̂_π =", v_hat)
print("Mean K:", np.mean(paths['K']))
print("Mean r, w:", np.mean(paths['r']), np.mean(paths['w']))

## SRL / SPG: Grid policy $c(b,y,r,w,z)$ with gradient-stop on $(r,w)$

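Same idea as Huggett: maximize $L(\theta) = \mathbb{E}[\sum_t \beta^t d_t^\top u(c_t)]$, where $d_t$ is the distribution over individual states. Aggregate capital $K_t$ is computed from $d_t$, and the prices $(r_t,w_t)=P^*(K_t,z_t)$ are **detached** from the computation graph.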

In [None]:
import torch

# Run on GPU when available; all SPG tensors share this device and dtype.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

# Coarser grid for SPG (Table 5: nb=200, nr=30, nw=50; we use smaller for speed)
nb_spg, nr_spg, nw_spg, nz_spg = 50, 15, 25, 10
ny_spg = ny
# Number of individual states; the flat index used below is j = ib * ny_spg + iy.
J = nb_spg * ny_spg

b_grid_t = torch.tensor(np.linspace(b_min, b_max, nb_spg), dtype=dtype, device=device)
r_grid_t = torch.tensor(np.linspace(r_min, r_max, nr_spg), dtype=dtype, device=device)
w_grid_t = torch.tensor(np.linspace(w_min, w_max, nw_spg), dtype=dtype, device=device)
# TFP grid: nz_spg evenly spaced indices picked out of the nz-point z_grid.
z_grid_t = torch.tensor(z_grid[np.linspace(0, nz-1, nz_spg, dtype=int)], dtype=dtype, device=device)
y_grid_t = torch.tensor(y_grid, dtype=dtype, device=device)
Ty_t = torch.tensor(Ty, dtype=dtype, device=device)
# Same index selection as used for z_grid_t above, now applied to Tz:
# keep the selected rows/columns and renormalize each row to sum to one.
iz_spg = np.linspace(0, nz-1, nz_spg, dtype=int)
Tz_sub = Tz[np.ix_(iz_spg, iz_spg)]
Tz_sub = Tz_sub / Tz_sub.sum(axis=1, keepdims=True)
Tz_t = torch.tensor(Tz_sub, dtype=dtype, device=device)
# Re-derived from the matrix; equals the nz_spg set above.
nz_spg = Tz_t.shape[0]

def theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t):
    """Map unconstrained parameters θ to consumption c = softplus(θ) + c_min.

    The grid arguments are accepted for a uniform call signature and are not
    used; the result has the same shape as `theta`.
    """
    positive_part = torch.nn.functional.softplus(theta)
    return positive_part + c_min

def init_theta_ks(b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, save_frac=0.25):
    """Initial θ reproducing the fixed-savings-rate rule c = (1-save_frac)*cash.

    Cash on hand is (1+r)*b + w*y, evaluated on every (b, y, r, w) grid node;
    it does not depend on z, so the result is replicated along the z axis.
    θ is the inverse softplus of c - c_min, so that softplus(θ) + c_min ≈ c.

    Args:
        b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t: 1-D grid tensors.
        save_frac: share of cash on hand that is saved.

    Returns:
        Tensor of shape (nb*ny, nz, nr, nw), flat individual index
        j = ib * ny + iy, on the dtype/device of `b_grid_t`.
    """
    n_b, n_y = b_grid_t.numel(), y_grid_t.numel()
    n_z, n_r, n_w = z_grid_t.numel(), r_grid_t.numel(), w_grid_t.numel()

    # Flat individual states (b varies slowest), broadcast against (r, w).
    b_flat = b_grid_t.repeat_interleave(n_y).view(-1, 1, 1)
    y_flat = y_grid_t.repeat(n_b).view(-1, 1, 1)
    gross_return = (1 + r_grid_t).view(1, n_r, 1)
    wage = w_grid_t.view(1, 1, n_w)

    cash = gross_return * b_flat + wage * y_flat                # (J, nr, nw)
    c_grid = torch.clamp((1 - save_frac) * cash, min=c_min)

    # Inverse softplus in the overflow-free form x + log(1 - exp(-x)).
    # The 1e-8 floor plays the role of the +1e-8 guard in log(exp(x)-1+1e-8):
    # a node sitting exactly at c_min maps to log(1e-8) instead of -inf.
    x = torch.clamp(c_grid - c_min, min=1e-8)
    theta0 = x + torch.log(-torch.expm1(-x))

    # Replicate along z and materialize, so the result owns its memory and
    # can be updated in place by an optimizer.
    return theta0.unsqueeze(1).expand(-1, n_z, -1, -1).contiguous()

def K_from_d(d, b_grid_t, ny_spg):
    """Aggregate capital implied by a distribution over flat (b, y) states.

    Each capital node is repeated ny_spg times so that entry j of the repeated
    vector is the capital level of flat state j; the result is sum_j d_j * b_j.
    """
    capital_per_state = torch.repeat_interleave(b_grid_t, ny_spg)
    return torch.sum(d * capital_per_state)

def P_star_detach_ks(theta, d, iz, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, ny_spg):
    """Equilibrium prices for distribution `d` and TFP index `iz`, without gradients.

    K is computed from `d` and converted to a Python number, so the returned
    scalar tensors (r, w) carry no autograd history. `theta` and the y/r/w
    grids are accepted but not used.
    """
    capital = K_from_d(d, b_grid_t, ny_spg).item()
    tfp = z_grid_t[iz].item()
    prices = K_to_prices(capital, tfp)
    r_t, w_t = (torch.tensor(p, device=device, dtype=dtype).detach() for p in prices)
    return r_t, w_t

def _nearest_index(grid, value):
    """Index of the node of ascending `grid` closest to `value` (clipped to the grid)."""
    hi = int(np.clip(np.searchsorted(grid, value), 0, len(grid) - 1))
    lo = max(hi - 1, 0)
    # searchsorted alone returns the first node >= value, which always rounds
    # up; step back to the lower neighbour when it is strictly closer.
    return lo if abs(value - grid[lo]) < abs(grid[hi] - value) else hi


def rw_to_ir_iw(r_val, w_val, r_grid_t, w_grid_t):
    """Map prices (r, w) to the indices of the nearest nodes of the price grids.

    Args:
        r_val, w_val: scalar net return and wage.
        r_grid_t, w_grid_t: ascending 1-D torch tensors of grid nodes.

    Returns:
        (ir, iw) as Python ints; values outside a grid map to its end node.
    """
    rn, wn = r_grid_t.cpu().numpy(), w_grid_t.cpu().numpy()
    return _nearest_index(rn, r_val), _nearest_index(wn, w_val)

In [None]:
def build_A_pi_ks(theta, iz, ir, iw, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg, sigma_b=0.5):
    """Transition matrix over flat (b, y) states induced by the policy at (iz, ir, iw).

    Savings b' = (1+r)*b + w*y - c are clamped to [b_min, b_max] and spread
    over the capital grid with normalized Gaussian weights of bandwidth
    `sigma_b` (differentiable in θ); income moves according to `Ty_t`.

    Args:
        theta: policy parameters, indexed as [j, iz, ir, iw].
        iz, ir, iw: indices of the TFP / net-return / wage nodes.
        b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t: 1-D grid tensors.
        Ty_t: (ny_spg, ny_spg) row-stochastic income transition matrix.
        nb_spg, ny_spg: grid sizes; flat index is j = ib * ny_spg + iy.
        sigma_b: bandwidth of the Gaussian kernel on the capital grid.

    Returns:
        A of shape (J, J) with A[jp, j] = P(j -> jp), i.e. each COLUMN sums to
        (approximately) one, so a distribution is propagated by d_next = A @ d.
    """
    J = nb_spg * ny_spg
    r_val = r_grid_t[ir]
    w_val = w_grid_t[iw]
    c = theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)[:, iz, ir, iw]

    b_now = b_grid_t.repeat_interleave(ny_spg)
    y_now = y_grid_t.repeat(nb_spg)
    b_next = torch.clamp((1 + r_val) * b_now + y_now * w_val - c, b_min, b_max)

    # Soft assignment of each b' to the capital nodes: w_b[j, ibp].
    dist = b_next.unsqueeze(1) - b_grid_t.unsqueeze(0)
    w_b = torch.exp(-dist.pow(2) / (2 * sigma_b**2))
    w_b = w_b / (w_b.sum(dim=1, keepdim=True) + 1e-8)

    # Row j of Ty_rows is Ty_t[j % ny_spg], the income transition out of state j.
    Ty_rows = Ty_t.repeat(nb_spg, 1)
    # from_to[j, ibp * ny_spg + iyp] = w_b[j, ibp] * Ty[iy(j), iyp]; one
    # broadcast product replaces the element-by-element fill of a J x J matrix.
    from_to = (w_b.unsqueeze(2) * Ty_rows.unsqueeze(1)).reshape(J, J)
    return from_to.t().contiguous()

def u_torch(c_vec):
    """CRRA utility on a tensor, with consumption floored at c_min first."""
    curvature = 1 - sigma
    floored = torch.clamp(c_vec, min=c_min)
    return floored.pow(curvature) / curvature

def spg_objective_ks(theta, N_traj, T_horizon, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, Tz_t,
                       nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, beta_t, d0=None):
    """Monte-Carlo estimate of L(θ) = E[ sum_t beta^t * d_t' u(c_t) ].

    For each of `N_traj` simulated TFP paths, the distribution `d` over flat
    (b, y) states is pushed forward for `T_horizon` periods. Each period the
    prices come from P_star_detach_ks (no gradient through prices), are mapped
    to grid indices, and the policy at those indices yields both the period
    payoff and the transition matrix.

    Args:
        theta: policy parameters of shape (J, nz, nr, nw).
        N_traj, T_horizon: number of TFP paths and periods per path.
        b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t: 1-D grid tensors.
        Ty_t, Tz_t: row-stochastic transition matrices for y and z.
        nb_spg, ny_spg, nz_spg: grid sizes (nr_spg, nw_spg are not used).
        beta_t: discount factor.
        d0: optional initial distribution of length J; uniform when None.

    Returns:
        Scalar tensor: the average discounted payoff across trajectories.
    """
    J = nb_spg * ny_spg
    if d0 is None:
        d0 = torch.ones(J, device=device, dtype=dtype) / J

    # θ is fixed within one evaluation, so the consumption grid and the
    # cumulative TFP transition rows are computed once, outside the loops.
    c_all = theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)
    Tz_cum = np.cumsum(Tz_t.detach().cpu().numpy(), axis=1)

    traj_values = []
    for _ in range(N_traj):
        iz = np.random.randint(0, nz_spg)
        d = d0.clone()
        L_n = torch.tensor(0.0, device=device, dtype=dtype)
        for t in range(T_horizon):
            r_t, w_t = P_star_detach_ks(theta, d, iz, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, ny_spg)
            ir, iw = rw_to_ir_iw(r_t.item(), w_t.item(), r_grid_t, w_grid_t)
            L_n = L_n + (beta_t ** t) * (d @ u_torch(c_all[:, iz, ir, iw]))
            A = build_A_pi_ks(theta, iz, ir, iw, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg)
            # build_A_pi_ks fills A[jp, j] = P(j -> jp) (columns sum to one),
            # so next-period mass is d'[jp] = sum_j A[jp, j] d[j] = (A @ d)[jp].
            # A.T @ d would sum over the wrong index and not conserve mass.
            d = A @ d
            iz = min(int(np.searchsorted(Tz_cum[iz], np.random.rand())), nz_spg - 1)
        traj_values.append(L_n)
    return torch.stack(traj_values).mean()

In [None]:
# Start from the fixed-savings-rate policy and make θ a trainable leaf tensor.
theta = init_theta_ks(b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)
theta = theta.requires_grad_(True)
optimizer = torch.optim.Adam([theta], lr=5e-4)
N_traj, T_horizon = 32, 50
n_epochs = 200
loss_hist = []
for epoch in range(n_epochs):
    optimizer.zero_grad()
    L = spg_objective_ks(theta, N_traj, T_horizon, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t,
                          Ty_t, Tz_t, nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, beta)
    loss_hist.append(L.item())
    # The optimizer minimizes, so back-propagate -L to maximize L(θ).
    (-L).backward()
    optimizer.step()
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}, L(θ) = {L.item():.6f}")
# Objective value per epoch (higher is better).
plt.figure(figsize=(8, 4))
plt.plot(loss_hist)
plt.xlabel('Epoch')
plt.ylabel('L(θ)')
plt.title('Krusell-Smith SRL/SPG')
plt.grid(True, alpha=0.3)
plt.show()