# One-Account HANK Model with SRL

One-asset heterogeneous agent New Keynesian model (SRL Section 4.3). Households have **endogenous labour supply**; budget: $c_t + b_{t+1} = (1+r_t)b_t + w_t y_t n_t + d_t - T_t$, $b\geq 0$.

**Policy (consumption, labour):** $(c, n) = \pi(b, y; r, w)$ — the policy depends on **prices** $(r, w)$ (and on $d$, $T$ from firms and the government). As in the Huggett and Krusell–Smith notebooks, SRL applies a **stop-gradient** to $(r, w)$ so households take prices as given.

**Prices:** In full GE, $(r_t, w_t)$ from bond/labour market clearing; Taylor rule + Phillips curve. Here we use a **simplified price process** so the notebook runs without the full NK block; SRL on the household side is unchanged.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Reproducible draws for every simulation below, plus compact plot styling.
np.random.seed(42)
plt.rcParams.update({'font.size': 10, 'axes.unicode_minus': False})

## Calibration (SRL Table 6, Appendix A.3)

- **Preferences:** $\beta=0.975$; CRRA utility in consumption with $\sigma=1$ (i.e. log utility); separable labor disutility with Frisch elasticity $1/\eta$, $\eta=1$.
- **Taylor rule:** $\phi=1.5$, $\bar{R}=1.025$; monetary shock $e_t$ AR(1) $\rho_e=0.9$, $\nu_e=0.002$.
- **Aggregate TFP** $z_t$: $\rho_z=0.9$, $\nu_z=0.07$.
- **Idiosyncratic** $y$: same as Huggett (log AR(1), 3 states).
- **Bonds:** $b\geq 0$, fixed supply $B$; government $r_t B = T_t$.

In [None]:
# --- Preferences ---
beta, sigma = 0.975, 1.0
eta = 1.0       # inverse Frisch elasticity

# --- Monetary policy (Taylor rule) ---
phi_taylor, R_bar = 1.5, 1.025

# --- Shock processes: (persistence, innovation s.d.) ---
rho_z, nu_z = 0.9, 0.07    # aggregate TFP
rho_e, nu_e = 0.9, 0.002   # monetary-policy shock
rho_y, nu_y = 0.6, 0.2     # idiosyncratic productivity (log)
ny = 3

# --- Bonds and solution grids ---
B_supply = 10.0   # fixed bond supply
b_min, b_max, nb = 0.0, 100.0, 200
nz, ne = 50, 50
nr, r_min, r_max = 30, 0.01, 0.04
nw, w_min, w_max = 30, 0.7, 1.0

# --- Numerical floors and truncation horizon ---
c_min, n_min = 1e-3, 1e-4
T_trunc = 90

print("\n".join([
    "One-account HANK calibration (SRL 4.3, Table 6):",
    f"  β={beta}, σ={sigma}, η={eta}, R̄={R_bar}, φ={phi_taylor}",
    f"  ρz={rho_z}, νz={nu_z}; ρe={rho_e}, νe={nu_e}",
]))

In [None]:
def tauchen_ar1(rho, sigma_innov, n_states, m=3):
    """Discretise x' = rho * x + eps, eps ~ N(0, sigma_innov**2), with Tauchen's method.

    Parameters
    ----------
    rho : float
        Persistence. The grid width uses the stationary s.d.
        sigma_innov / sqrt(1 - rho**2), so |rho| < 1 is required.
    sigma_innov : float
        Innovation standard deviation (> 0).
    n_states : int
        Number of grid points (>= 1).
    m : float
        Grid half-width in stationary standard deviations.

    Returns
    -------
    x_grid : ndarray, shape (n_states,)
        Equally spaced grid on [-m * sigma_x, m * sigma_x] (``[0.0]`` if n_states == 1).
    P : ndarray, shape (n_states, n_states)
        Row-stochastic matrix, P[i, j] = Pr(x' = x_j | x = x_i).

    Raises
    ------
    ValueError
        If n_states < 1.
    """
    if n_states < 1:
        raise ValueError("n_states must be at least 1")
    if n_states == 1:
        # A single state is absorbing; the general formula would divide 0 by 0
        # for the grid step and return an all-NaN matrix.
        return np.zeros(1), np.ones((1, 1))

    sigma_x = sigma_innov / np.sqrt(1 - rho**2)
    x_min, x_max = -m * sigma_x, m * sigma_x
    x_grid = np.linspace(x_min, x_max, n_states)
    step = (x_max - x_min) / (n_states - 1)

    # centred[i, j] = x_j - rho * x_i: row i is today's state, column j tomorrow's.
    centred = x_grid[None, :] - rho * x_grid[:, None]
    upper = norm.cdf((centred + step / 2) / sigma_innov)
    lower = norm.cdf((centred - step / 2) / sigma_innov)
    P = upper - lower
    # The edge states absorb the tail probability mass.
    P[:, 0] = upper[:, 0]
    P[:, -1] = 1 - lower[:, -1]
    P = P / P.sum(axis=1, keepdims=True)
    return x_grid, P

# Idiosyncratic productivity: Tauchen grid in logs, normalised so that
# E[y] = 1 under the (approximate) stationary distribution.
log_y_grid, Ty = tauchen_ar1(rho_y, nu_y, ny)
invariant_y = np.linalg.matrix_power(Ty.T, 200)[:, 0]  # distribution after 200 steps from state 0
y_grid = np.exp(log_y_grid)
y_grid = y_grid / (y_grid @ invariant_y)

# Aggregate TFP (in levels) and the monetary-policy shock.
log_z_grid, Tz = tauchen_ar1(rho_z, nu_z, nz)
z_grid = np.exp(log_z_grid)
e_grid, Te = tauchen_ar1(rho_e, nu_e, ne)

# Evenly spaced grids for bonds, the interest rate and the wage.
b_grid, r_grid, w_grid = (
    np.linspace(lo, hi, num)
    for lo, hi, num in ((b_min, b_max, nb), (r_min, r_max, nr), (w_min, w_max, nw))
)

def u_separable(c, n):
    """Period utility u(c, n) = c^(1-σ)/(1-σ) - n^(1+η)/(1+η), log(c) when σ = 1.

    Consumption is floored at ``c_min``; hours are bounded to [n_min, 10].
    Works element-wise on scalars or NumPy arrays.
    """
    c_floor = np.maximum(c, c_min)
    n_bounded = np.maximum(np.minimum(n, 10.0), n_min)
    disutility = (n_bounded ** (1 + eta)) / (1 + eta)
    if abs(sigma - 1) < 1e-10:
        consumption_part = np.log(c_floor)
    else:
        consumption_part = (c_floor ** (1 - sigma)) / (1 - sigma)
    return consumption_part - disutility

# Quick sanity check on the normalised productivity grid and the TFP range.
print("y_grid:", y_grid)
print("z_grid (first/last):", *z_grid[[0, -1]])

## Simplified price process (placeholder for full NK)

In the full model, $(r_t, w_t, \Pi_t, d_t, T_t)$ come from firm optimality (Phillips curve), Taylor rule, Fisher equation, and market clearing. Here we use a **simplified law of motion** so we can run SRL on the household side: $r_t$ and $w_t$ follow exogenous AR(1)-like processes around steady state, $T_t = r_t B$.

In [None]:
r_ss = R_bar - 1.0
w_ss = 0.85
d_ss = 0.1

def prices_next(r_t, w_t, z_idx, e_idx, Tz, Te, z_grid, e_grid, persistence=0.7):
    """One step of the simplified (non-GE) price process.

    r' = r_ss + persistence * (r_t - r_ss) + 0.02 * (z' - 1) + 0.01 * e'
    w' = w_ss + persistence * (w_t - w_ss) + 0.15 * (z' - 1)

    Both are clipped to the (r, w) grid bounds. Taxes finance interest on the
    fixed bond supply, T' = r' * B_supply, and dividends scale with TFP,
    d' = d_ss * z'.

    Parameters
    ----------
    r_t, w_t : float
        Current interest rate and wage.
    z_idx, e_idx : int
        Indices of next period's TFP and monetary shocks in ``z_grid``/``e_grid``.
    Tz, Te : array
        Unused; kept for interface compatibility.
    z_grid, e_grid : array
        Shock grids.
    persistence : float
        Weight on the lagged deviation from steady state.

    Returns
    -------
    (r_next, w_next, T_next, d_next)
    """
    z_val = z_grid[z_idx]
    e_val = e_grid[e_idx]
    # Deviation form: the lag and steady-state weights sum to one, so that
    # (r_ss, w_ss) is the no-shock fixed point. The weights 0.7/0.2 summed to 0.9
    # and converged to 2/3 of steady state, pinning w at its lower clip bound.
    r_next = r_ss + persistence * (r_t - r_ss) + 0.02 * (z_val - 1) + 0.01 * e_val
    w_next = w_ss + persistence * (w_t - w_ss) + 0.15 * (z_val - 1)
    r_next = np.clip(r_next, r_min, r_max)
    w_next = np.clip(w_next, w_min, w_max)
    T_next = r_next * B_supply
    d_next = d_ss * z_val
    return r_next, w_next, T_next, d_next

def policy_placeholder(b, y, r, w, d, T, c_frac=0.6, n_frac=0.5):
    """Rule-of-thumb policy used before any training.

    Works ``n_frac`` hours (clipped to [n_min, 2]) and consumes ``c_frac`` of
    cash on hand, cash = (1 + r) b + w y n + d - T, with consumption floored
    at c_min. The remainder is saved and clipped to [b_min, b_max].

    NOTE(review): when cash < c_min (taxes r*B_supply can exceed income),
    c = c_min and b_next = b_min, so the budget constraint is violated by
    c_min - cash.

    Returns
    -------
    (c, b_next, n), arrays shaped like ``b``.
    """
    # Use a floating dtype explicitly: np.full_like on an integer-typed ``b``
    # would truncate n_frac (e.g. 0.5 -> 0).
    n = np.full(np.shape(b), np.clip(n_frac, n_min, 2.0), dtype=np.result_type(b, 1.0))
    cash = (1 + r) * b + w * y * n + d - T
    c = np.maximum(c_frac * cash, c_min)
    b_next = np.clip(cash - c, b_min, b_max)
    return c, b_next, n

def draw_next_state(y_idx, z_idx, e_idx, Ty, Tz, Te):
    """Sample next-period indices (y', z', e') from their transition matrices.

    Each index is drawn with ``np.random.choice`` from the row of its own
    matrix selected by the current index. The order y, z, e is fixed so that
    the global RNG stream is consumed deterministically.
    """
    current_and_matrix = ((y_idx, Ty), (z_idx, Tz), (e_idx, Te))
    y_next, z_next, e_next = (
        np.random.choice(P.shape[1], p=P[i, :]) for i, P in current_and_matrix
    )
    return y_next, z_next, e_next

## Simulation (household side with simplified prices)

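Given a policy $(c, n)(b, y, r, w, z)$ and a price process, we simulate $N$ agents for $T$ periods and compute the Monte Carlo value $\hat{v}_\pi = \sum_{t=0}^{T-1} \beta^t \frac{1}{N}\sum_{i=1}^{N} u(c_{it}, n_{it})$. In full GE, bond market clearing $B = \int b\, dG$ would pin down $r$; here we use the simplified $(r_t, w_t)$ process instead.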

In [None]:
def _draw_markov_indices(P, idx):
    """Vectorised draw of next-state indices for current states ``idx`` under row-stochastic ``P``.

    Inverse-CDF sampling: one uniform per element of ``idx``. The next state is
    the first column whose cumulative probability exceeds that uniform.
    """
    cdf = np.cumsum(P[idx], axis=1)
    u = np.random.random_sample(np.shape(idx))
    nxt = (cdf <= u[..., None]).sum(axis=-1)
    # Guard against rows whose cumulative sum falls a rounding error short of 1.
    return np.minimum(nxt, P.shape[1] - 1)


def simulate_hank(T, N_agents, policy_fn, y_grid, z_grid, e_grid, Ty, Tz, Te):
    """Simulate ``N_agents`` households for ``T`` periods under the simplified price process.

    Parameters
    ----------
    T : int
        Number of simulated periods; v_hat sums periods t = 0..T-1.
    N_agents : int
        Number of households (all start with b = 0).
    policy_fn : callable
        ``policy_fn(b, y, r, w, d, T) -> (c, b_next, n)`` acting on per-agent arrays.
    y_grid, z_grid, e_grid : ndarray
        Shock grids (idiosyncratic productivity, TFP, monetary shock).
    Ty, Tz, Te : ndarray
        Row-stochastic transition matrices for the respective shocks.

    Returns
    -------
    paths : dict
        'b', 'y', 'c', 'n' of shape (T+1, N_agents); 'z', 'r', 'w' of shape (T+1,).
    v_hat : float
        Monte Carlo value sum_t beta**t * mean_i u(c_it, n_it), t < T.
    """
    y_idx = np.random.randint(0, len(y_grid), size=N_agents)
    z_idx = np.random.randint(0, len(z_grid))
    e_idx = np.random.randint(0, len(e_grid))
    b = np.zeros(N_agents)
    r_t, w_t = r_ss, w_ss
    T_t = r_t * B_supply
    d_t = d_ss * z_grid[z_idx]

    paths = {
        'b': np.zeros((T + 1, N_agents)),
        'y': np.zeros((T + 1, N_agents)),
        'z': np.zeros(T + 1),
        'r': np.zeros(T + 1),
        'w': np.zeros(T + 1),
        'c': np.zeros((T + 1, N_agents)),
        'n': np.zeros((T + 1, N_agents)),
    }
    paths['b'][0] = b
    paths['y'][0] = y_grid[y_idx]
    paths['z'][0] = z_grid[z_idx]

    for t in range(T):
        c_t, b_next, n_t = policy_fn(b, y_grid[y_idx], r_t, w_t, d_t, T_t)
        paths['c'][t], paths['n'][t] = c_t, n_t
        paths['r'][t], paths['w'][t] = r_t, w_t
        b = b_next
        paths['b'][t + 1] = b
        # Aggregate shocks (z, e) are common to all agents: draw them once per
        # period. The y draw returned by draw_next_state is discarded.
        _, z_idx, e_idx = draw_next_state(0, z_idx, e_idx, Ty, Tz, Te)
        # Idiosyncratic shocks: a single vectorised draw for all agents.
        y_idx = _draw_markov_indices(Ty, y_idx)
        paths['y'][t + 1] = y_grid[y_idx]
        r_t, w_t, T_t, d_t = prices_next(r_t, w_t, z_idx, e_idx, Tz, Te, z_grid, e_grid)
        paths['z'][t + 1] = z_grid[z_idx]

    # Terminal period t = T: record prices and choices (not included in v_hat).
    paths['r'][T], paths['w'][T] = r_t, w_t
    paths['c'][T], _, paths['n'][T] = policy_fn(b, paths['y'][T], r_t, w_t, d_t, T_t)
    v_hat = np.sum([(beta**t) * np.mean(u_separable(paths['c'][t], paths['n'][t])) for t in range(T)])
    return paths, v_hat

In [None]:
def policy_fn(b, y, r, w, d, T):
    """Baseline rule for the simulation: consume 60% of cash on hand, work n = 0.5."""
    return policy_placeholder(b, y, r, w, d, T, n_frac=0.5, c_frac=0.6)

paths, v_hat = simulate_hank(100, 500, policy_fn, y_grid, z_grid, e_grid, Ty, Tz, Te)
c_mean, n_mean = np.nanmean(paths['c']), np.nanmean(paths['n'])
print("v̂_π =", v_hat)
print("Mean c, n:", c_mean, n_mean)

## SRL / SPG for one-account HANK

Policy parameters $\theta$ represent **consumption** and **labor** on a grid over $(b, y, r, w, z)$. We maximize $L(\theta) = \mathbb{E}[\sum_t \beta^t u(c_t, n_t)]$ with a **stop-gradient** on $(r_t, w_t)$ so households take prices as given. In the full model, $(r, w)$ would be determined by bond/labor market clearing and the NK block (Phillips curve + Taylor rule). In the SPG code below, $(r_t, w_t)$ are simply drawn independently and uniformly from their grids each period, while $z_t$ follows its (coarsened) Markov chain.

## Summary

- **Model:** One-account HANK: (c, n) = π(b, y; r, w); budget c + b' = (1+r)b + w·y·n + d − T; separable u(c,n).
- **State dynamics:** y, z, e from transition matrices Ty, Tz, Te via `draw_next_state` (same as Huggett/K-S).
- **Prices:** Simplified process here; full model would use bond/labour clearing + Taylor rule + Phillips curve.
- **SRL/SPG:** Maximize L(θ) with gradient-stop on (r, w); policy (c, n) on (b, y, r, w, z) grid.

In [None]:
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float32

# Coarse SPG grids; the policy tensor is indexed by (j, z, r, w), where
# j = ib * ny_spg + iy flattens the (b, y) individual state.
nb_spg, nr_spg, nw_spg, nz_spg = 30, 10, 10, 8
ny_spg = ny
J = nb_spg * ny_spg


def _to_t(arr):
    """Copy an array-like to a tensor with the notebook's dtype and device."""
    return torch.tensor(arr, dtype=dtype, device=device)


b_grid_t = _to_t(np.linspace(b_min, b_max, nb_spg))
r_grid_t = _to_t(np.linspace(r_min, r_max, nr_spg))
w_grid_t = _to_t(np.linspace(w_min, w_max, nw_spg))
# Subsample nz_spg evenly spaced TFP states from the fine grid.
iz = np.linspace(0, nz-1, nz_spg, dtype=int)
z_grid_t = _to_t(z_grid[iz])
y_grid_t = _to_t(y_grid)
Ty_t = _to_t(Ty)
# Restrict Tz to the retained states and renormalise each row to sum to one.
Tz_block = Tz[np.ix_(iz, iz)]
Tz_sub = Tz_block / Tz_block.sum(axis=1, keepdims=True)
Tz_t = _to_t(Tz_sub)

def theta_to_cn(theta_c, theta_n):
    """Map unconstrained parameters to admissible policies, element-wise.

    c = c_min + softplus(θ_c), so c > c_min
    n = n_min + (2 - n_min) * sigmoid(θ_n), so n lies in (n_min, 2)
    """
    consumption = c_min + torch.nn.functional.softplus(theta_c)
    hours = n_min + (2.0 - n_min) * torch.sigmoid(theta_n)
    return consumption, hours

def init_theta_hank(b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, c_frac=0.6, n_val=0.5, d_ss_val=0.1):
    """Initialise (θ_c, θ_n) so that ``theta_to_cn`` reproduces a rule-of-thumb policy.

    At every grid node, the initial policy works ``n_val`` hours and consumes
    ``c_frac`` of cash on hand:
    cash = (1 + r) b + w y n_val + d - T, with T = r * B_supply, d = d_ss_val * z.
    Consumption is floored at c_min. θ_c and θ_n invert the softplus and scaled
    sigmoid maps used by ``theta_to_cn``.

    Parameters
    ----------
    b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t : 1-D tensors
        Grids for bonds, idiosyncratic productivity, TFP, interest rate and wage.
    c_frac : float
        Share of cash on hand consumed.
    n_val : float
        Initial hours; must lie in (n_min, 2) for the sigmoid inversion.
    d_ss_val : float
        Dividend scale (d = d_ss_val * z).

    Returns
    -------
    theta_c, theta_n : tensors of shape (nb * ny, nz, nr, nw)
        The first axis flattens (b, y) as j = ib * ny + iy.
    """
    def _as_np(grid):
        # float64 copy of the grid values (same numbers the element-wise `.item()` reads gave)
        return grid.detach().cpu().double().numpy()

    b, y, z, r, w = (_as_np(g) for g in (b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t))
    nb_, ny_, nz_, nr_, nw_ = len(b), len(y), len(z), len(r), len(w)

    # Broadcast every grid onto the 5-D node array (b, y, z, r, w); this replaces
    # a ~72k-iteration Python loop that wrote one tensor element at a time.
    b = b[:, None, None, None, None]
    y = y[None, :, None, None, None]
    z = z[None, None, :, None, None]
    r = r[None, None, None, :, None]
    w = w[None, None, None, None, :]

    taxes = r * B_supply
    dividends = d_ss_val * z
    cash = (1 + r) * b + w * y * n_val + dividends - taxes
    c_init = np.maximum(c_frac * cash, c_min)
    # Inverse softplus; the 1e-8 keeps the log finite when c_init == c_min.
    theta_c_np = np.log(np.exp(c_init - c_min) - 1 + 1e-8).reshape(nb_ * ny_, nz_, nr_, nw_)
    # Inverse of n = n_min + (2 - n_min) * sigmoid(θ): θ = log((n - n_min) / (2 - n)).
    theta_n_val = np.log((n_val - n_min) / (2.0 - n_val + 1e-8))

    theta_c = torch.tensor(theta_c_np, dtype=dtype, device=device)
    theta_n = torch.full((nb_ * ny_, nz_, nr_, nw_), float(theta_n_val), dtype=dtype, device=device)
    return theta_c, theta_n

def u_torch(c_vec, n_vec, *, crra=None, inv_frisch=None, c_floor=None, n_floor=None):
    """Torch period utility u(c, n) = c^(1-σ)/(1-σ) - n^(1+η)/(1+η), log(c) when σ = 1.

    Mirrors ``u_separable``: consumption is floored at c_floor and hours are
    bounded to [n_floor, 10].

    Parameters
    ----------
    c_vec, n_vec : tensors
        Consumption and hours (element-wise, broadcastable).
    crra, inv_frisch, c_floor, n_floor : float, optional
        Overrides for sigma, eta, c_min, n_min. When None, the notebook's
        calibration values are used.

    Returns
    -------
    Tensor of utilities, differentiable w.r.t. the inputs.
    """
    crra = sigma if crra is None else crra
    inv_frisch = eta if inv_frisch is None else inv_frisch
    c_floor = c_min if c_floor is None else c_floor
    n_floor = n_min if n_floor is None else n_floor

    c_vec = torch.clamp(c_vec, min=c_floor)
    n_vec = torch.clamp(n_vec, min=n_floor, max=10.0)
    # σ = 1 (the calibrated value) needs the log limit: the power formula would
    # divide by zero, making every utility +inf and every gradient NaN.
    if abs(crra - 1.0) < 1e-10:
        u_c = torch.log(c_vec)
    else:
        u_c = (c_vec ** (1 - crra)) / (1 - crra)
    return u_c - (n_vec ** (1 + inv_frisch)) / (1 + inv_frisch)

In [None]:
# Simplified: (r,w) drawn from grid indices each period (no market clearing). SPG maximizes L(θ).
def _bond_transition(b_next, b_grid_t, Ty_rows, sigma_b=0.5):
    """Row-stochastic transition P[j, j'] over flattened (b, y) states.

    With j = ib * ny + iy, P[j, ib' * ny + iy'] = w_b[j, ib'] * Ty[iy(j), iy'].
    Here w_b spreads b_next[j] over the bond grid with normalised Gaussian kernel
    weights of bandwidth sigma_b. ``Ty_rows[j]`` is the Ty row of state j's y.
    """
    dist = b_next.unsqueeze(1) - b_grid_t.unsqueeze(0)            # (J, nb)
    w_b = torch.exp(-dist.pow(2) / (2 * sigma_b**2))
    w_b = w_b / (w_b.sum(dim=1, keepdim=True) + 1e-8)
    P = w_b.unsqueeze(2) * Ty_rows.unsqueeze(1)                   # (J, nb, ny)
    return P.reshape(P.shape[0], -1)


def spg_objective_hank(theta_c, theta_n, N_traj, T_horizon, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t,
                        Ty_t, Tz_t, nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, beta_t, d_ss_val=0.1, d0=None):
    """Monte Carlo SPG objective L(θ) for the household block.

    For each of ``N_traj`` trajectories, the cross-sectional distribution over
    flattened (b, y) states (j = ib * ny_spg + iy) is pushed forward
    ``T_horizon`` periods. Along the way it accumulates the discounted expected
    utility beta_t**t * Σ_j d_j u(c_j, n_j). Prices (r, w) are drawn i.i.d.
    uniformly from their grids and detached (stop-gradient), so households take
    them as given. The TFP index follows the Markov chain ``Tz_t``.

    Parameters
    ----------
    theta_c, theta_n : tensors (nb_spg * ny_spg, nz_spg, nr_spg, nw_spg)
        Unconstrained policy parameters (see ``theta_to_cn``).
    N_traj, T_horizon : int
        Number of aggregate trajectories and their length.
    b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t : 1-D tensors
        SPG grids.
    Ty_t, Tz_t : tensors
        Transition matrices for y and for the coarse z grid.
    beta_t : float
        Discount factor.
    d_ss_val : float
        Dividend scale, d = d_ss_val * z.
    d0 : tensor (J,), optional
        Initial distribution over (b, y) states; uniform by default.

    Returns
    -------
    Scalar tensor: mean of L over trajectories, differentiable w.r.t. θ.

    NOTE(review): consumption is a free parameter, and b' is only clamped at
    b_min. Consuming more than cash on hand is therefore not penalised, so the
    budget constraint is not enforced. Also, with sigma_b = 0.5 against a
    ~3.4 bond-grid step, the kernel is nearly one-hot, so gradients through b'
    are small.
    """
    J = nb_spg * ny_spg
    if d0 is None:
        d0 = torch.ones(J, device=device, dtype=dtype) / J

    # Loop invariants, hoisted out of the trajectory and time loops.
    c_all, n_all = theta_to_cn(theta_c, theta_n)
    b_vals = b_grid_t.repeat_interleave(ny_spg)          # b at each flattened state j
    y_vals = y_grid_t.repeat(nb_spg)                     # y at each flattened state j
    iy_of_j = torch.arange(J, device=b_grid_t.device) % ny_spg
    Ty_rows = Ty_t[iy_of_j]                               # (J, ny)
    Tz_np = Tz_t.detach().cpu().double().numpy()
    Tz_np = Tz_np / Tz_np.sum(axis=1, keepdims=True)      # exact float64 rows for np.random.choice

    L_list = []
    for _ in range(N_traj):
        iz = np.random.randint(0, nz_spg)
        ir = np.random.randint(0, nr_spg)
        iw = np.random.randint(0, nw_spg)
        dist_j = d0.clone()
        L_n = torch.zeros((), device=device, dtype=dtype)
        for t in range(T_horizon):
            # Stop-gradient on prices: households take (r, w) as given.
            r_t = r_grid_t[ir].detach()
            w_t = w_grid_t[iw].detach()
            T_t = r_t * B_supply
            d_t = d_ss_val * z_grid_t[iz]
            c_t = c_all[:, iz, ir, iw]
            n_t = n_all[:, iz, ir, iw]
            L_n = L_n + (beta_t ** t) * (dist_j @ u_torch(c_t, n_t))
            b_next = (1 + r_t) * b_vals + w_t * y_vals * n_t + d_t - T_t - c_t
            b_next = torch.clamp(b_next, b_min, b_max)
            P = _bond_transition(b_next, b_grid_t, Ty_rows)
            # Push mass forward: d'[j'] = Σ_j d[j] P[j, j'], i.e. P.T @ d.
            # (P @ d would leave a uniform d unchanged, since rows of P sum to 1.)
            dist_j = P.T @ dist_j
            iz = np.random.choice(nz_spg, p=Tz_np[iz])
            ir = np.random.randint(0, nr_spg)
            iw = np.random.randint(0, nw_spg)
        L_list.append(L_n)
    return torch.stack(L_list).mean()

In [None]:
# Initialise the policy parameters from the rule of thumb and mark them trainable.
theta_c, theta_n = (
    param.requires_grad_(True)
    for param in init_theta_hank(b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)
)
optimizer = torch.optim.Adam([theta_c, theta_n], lr=5e-4)
N_traj, T_horizon = 24, 40
n_epochs = 150
loss_hist = []
for epoch in range(n_epochs):
    optimizer.zero_grad()
    L = spg_objective_hank(theta_c, theta_n, N_traj, T_horizon, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t,
                           Ty_t, Tz_t, nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, beta)
    objective = L.item()
    loss_hist.append(objective)
    # Gradient ascent on L(θ), done as descent on -L.
    loss = -L
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 30 == 0:
        print(f"Epoch {epoch+1}, L(θ) = {objective:.6f}")

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(loss_hist)
ax.set_xlabel('Epoch')
ax.set_ylabel('L(θ)')
ax.set_title('One-account HANK SRL/SPG (household block)')
ax.grid(True, alpha=0.3)
plt.show()