# Krusell–Smith (1998) Model with SRL

Heterogeneous-agent economy: **wealth is productive capital** rented to a representative firm. Production $Y_t = z_t K_t^\alpha L_t^{1-\alpha}$ with **total labour $L_t = 1$** (fixed). Factor prices: $r_t = \alpha z_t K_t^{\alpha-1} - \delta$, $w_t = (1-\alpha) z_t K_t^\alpha$.

**Policy (consumption, labour):** $(c, n) = \pi(b, y; r, w)$ — consumption and labour supply depend on **prices** $(r, w)$.  
**Individual state:** $(b, y)$ — capital and idiosyncratic labour productivity.  
**Aggregate:** $K_t = \int b\, dG_t$; $(r_t, w_t) = P^*(K_t, z_t)$ with $L=1$.  
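**Budget:** $c + b' = (1+r)b + w\, n\, y$, $b \geq 0$, $n \in [0,1]$. We assume **total labour supply** $\int n\, dG = 1$. In the code $n \equiv 1$, and $y$ is normalized so that $E[y]=1$ under its stationary distribution. Effective labour $\int n\,y\, dG$ therefore also equals the firm's $L=1$ once the income distribution is stationary.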
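
The labour-market consistency condition can be checked in a few lines. This is a minimal sketch: the 3-state income grid and its stationary probabilities are hypothetical values chosen for illustration, not the calibrated ones. With $n \equiv 1$, effective labour is $\sum_j \pi_j\, n\, y_j$, and rescaling $y$ to mean one makes it equal to the firm's $L=1$.

```python
import numpy as np

# Hypothetical 3-state income grid and stationary probabilities (illustration only)
y_grid = np.array([0.3, 1.0, 1.8])
pi_y = np.array([0.25, 0.5, 0.25])

# Rescale so that E[y] = 1 under the stationary distribution
y_grid = y_grid / (y_grid @ pi_y)

# Everyone supplies n = 1, so effective labour is sum_j pi_j * n_j * y_j
n = np.ones_like(y_grid)
L_eff = pi_y @ (n * y_grid)
print(L_eff)
```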

In [None]:
# ========== Dependencies ==========
# numpy for arrays/grids; matplotlib for plots; scipy.stats.norm supplies the normal CDF for Tauchen
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

np.random.seed(42)   # fix the seed for reproducibility
plt.rcParams['font.size'] = 10
plt.rcParams['axes.unicode_minus'] = False  # keep minus signs from rendering as boxes

## Calibration (SRL Section 4.2 & Appendix A.2, Tables 4 & 5)

- **Preferences:** $\beta=0.95$, CRRA $\sigma=3$.
- **Production:** capital share $\alpha=0.36$, depreciation $\delta=0.08$.
- **Idiosyncratic income** $y$: same as Huggett ($\rho_y=0.6$, $\nu_y=0.2$), 3 states. The code below applies Tauchen to $y$ in levels with mean 1 rather than to $\log y$.
- **Aggregate TFP** $z$: log AR(1), $\rho_z=0.9$, $\nu_z=0.03$.
- **Capital:** $b\geq 0$, grid $n_b=200$, $b_{\max}=100$.
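
For reference, the width of the Tauchen grids used below ($\pm 3$ unconditional standard deviations, `m=3`) follows from the stationary standard deviation of an AR(1) with persistence $\rho$ and innovation s.d. $\nu$:

$$
\sigma_x = \frac{\nu}{\sqrt{1-\rho^2}}, \qquad
\sigma_y = \frac{0.2}{\sqrt{1-0.6^2}} = 0.25, \qquad
\sigma_{\log z} = \frac{0.03}{\sqrt{1-0.9^2}} \approx 0.0688 .
$$

Before renormalization, the $y$ grid therefore spans $1 \pm 0.75 = [0.25, 1.75]$ and the $\log z$ grid spans roughly $\pm 0.206$.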

In [None]:
# ========== Calibration (SRL Section 4.2 & Appendix A.2, Tables 4 & 5) ==========

# --- Preferences ---
beta = 0.95       # discount factor
sigma = 3.0       # CRRA coefficient σ

# --- Production ---
alpha = 0.36      # capital share
delta = 0.08      # depreciation rate

# Idiosyncratic income y: AR(1) in levels, 3 states (same as Huggett)
rho_y = 0.6
nu_y = 0.2
ny = 3

# Aggregate TFP z: log z ~ AR(1)
rho_z = 0.9
nu_z = 0.03

# Capital grid
b_min = 0.0
b_max = 100.0
nb = 200
J = nb * ny       # number of composite states, j = ib*ny + iy

# Price grids for the policy lookup table (Table 5)
nr = 30
r_min, r_max = 0.02, 0.07
nw = 50
w_min, w_max = 0.9, 1.5
nz = 30

# Numerical truncation (same style as Huggett)
c_min = 1e-3
T_trunc = 90
etrunc = 1e-2

# ========== SPG training (same style as Huggett Table 3) ==========
N_epoch = 200
N_warmup = 25
lr_ini = 5e-4
lr_decay = 0.5
N_sample = 32      # trajectories per update
e_converge = 3e-4

print("Krusell-Smith calibration (SRL 4.2, App A.2):")
print(f"  β={beta}, σ={sigma}, α={alpha}, δ={delta}")
print(f"  ρy={rho_y}, νy={nu_y}, ρz={rho_z}, νz={nu_z}")
print(f"  b∈[0,{b_max}], nb={nb}, ny={ny}, J={J}; r∈[{r_min},{r_max}], nr={nr}; w∈[{w_min},{w_max}], nw={nw}")
print(f"  T_trunc={T_trunc}, c_min={c_min}")
print("  SPG: Nepoch=%d, Nwarmup=%d, lr_ini=%s, lr_decay=%s, Nsample=%d, econverge=%s" % (N_epoch, N_warmup, lr_ini, lr_decay, N_sample, e_converge))

In [None]:
# ========== Tauchen (1986) discretization of an AR(1), unified form shared with Huggett ==========
def tauchen_ar1(rho, sigma_innov, n_states, m=3, mean=0.0):
    """Unified Tauchen. mean=0 → zero-mean (e.g. log z); mean=μ → level process (e.g. y with E[y]=1)."""
    std = sigma_innov / np.sqrt(1 - rho**2)
    if mean == 0:
        x_min, x_max = -m * std, m * std
    else:
        x_min = max(1e-6, mean - m * std)
        x_max = mean + m * std
    x_grid = np.linspace(x_min, x_max, n_states)
    step = (x_max - x_min) / (n_states - 1) if n_states > 1 else 1.0
    mu_i = (1 - rho) * mean + rho * x_grid
    z_lo = (x_grid - mu_i[:, None] + step / 2) / sigma_innov
    z_hi = (x_grid - mu_i[:, None] - step / 2) / sigma_innov
    P = np.zeros((n_states, n_states))
    P[:, 0] = norm.cdf(z_lo[:, 0])
    P[:, -1] = 1 - norm.cdf(z_hi[:, -1])
    if n_states > 2:
        P[:, 1:-1] = norm.cdf(z_lo[:, 1:-1]) - norm.cdf(z_hi[:, 1:-1])
    P = P / P.sum(axis=1, keepdims=True)
    return x_grid, P

# Idiosyncratic y: level AR(1) with mean 1, then rescaled so E[y] = 1 under its stationary distribution (as in Huggett)
y_grid, Ty = tauchen_ar1(rho_y, nu_y, ny, m=3, mean=1.0)
invariant_y = np.linalg.matrix_power(Ty.T, 200)[:, 0]
y_grid = y_grid / (y_grid @ invariant_y)

# Aggregate z: log z ~ AR(1); z is then rescaled so E[z] = 1 under its stationary distribution
log_z_grid, Tz = tauchen_ar1(rho_z, nu_z, nz, mean=0.0)
z_grid = np.exp(log_z_grid)
invariant_z = np.linalg.matrix_power(Tz.T, 200)[:, 0]
z_grid = z_grid / (z_grid @ invariant_z)

b_grid = np.linspace(b_min, b_max, nb)
r_grid = np.linspace(r_min, r_max, nr)
w_grid = np.linspace(w_min, w_max, nw)

# ---------- CRRA utility (log(c) when σ = 1) ----------
def u(c, sig=sigma):
    c = np.maximum(c, c_min)
    if np.abs(sig - 1.0) < 1e-10:
        return np.log(c)
    return (c ** (1 - sig)) / (1 - sig)

def u_prime(c, sig=sigma):
    return np.maximum(c, c_min) ** (-sig)

print("y_grid (idiosyncratic):", y_grid)
print("z_grid (first/last):", z_grid[0], z_grid[-1])

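The cell above finds stationary distributions by raising $T^\top$ to the 200th power. The sketch below shows an alternative that is not the notebook's method: solve $\pi^\top T = \pi^\top$ with $\sum_i \pi_i = 1$ directly as a stacked linear system. It is illustrated on a hypothetical 2-state chain whose answer is known to be $\pi = (2/3, 1/3)$, and cross-checked against the matrix-power approach.

```python
import numpy as np

def stationary_distribution(P):
    """Stationary distribution of a row-stochastic matrix P via a linear solve."""
    n = P.shape[0]
    # Stack (P^T - I) pi = 0 with the normalization 1^T pi = 1
    A = np.vstack([P.T - np.eye(n), np.ones((1, n))])
    rhs = np.zeros(n + 1)
    rhs[-1] = 1.0
    pi, *_ = np.linalg.lstsq(A, rhs, rcond=None)
    return pi

P = np.array([[0.9, 0.1],
              [0.2, 0.8]])
pi = stationary_distribution(P)
# Cross-check with the matrix-power approach used in the notebook
pi_power = np.linalg.matrix_power(P.T, 200)[:, 0]
print(pi, pi_power)
```

The linear solve does not depend on how fast the chain mixes. The matrix-power method needs the second eigenvalue raised to the 200th power to be negligible, which is easily satisfied here ($0.7^{200}$).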
## Production and prices

Given aggregate capital $K$ and TFP $z$: $r^K = \alpha z K^{\alpha-1}$, $w = (1-\alpha)z K^\alpha$ (with $L=1$). Net return $r = r^K - \delta$.
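
Because the technology has constant returns to scale, gross factor payments exhaust output: $r^K K + w L = Y$. The check below is self-contained. It uses the calibration's $\alpha = 0.36$; the values of $K$ and $z$ are arbitrary and chosen only for illustration.

```python
import numpy as np

alpha = 0.36

def factor_prices(K, z, L=1.0):
    """Gross rental rate r^K and wage w from Y = z K^alpha L^(1-alpha)."""
    rK = alpha * z * K ** (alpha - 1) * L ** (1 - alpha)
    w = (1 - alpha) * z * K ** alpha * L ** (-alpha)
    return rK, w

K, z = 4.8, 1.02            # illustrative values only
rK, w = factor_prices(K, z)
Y = z * K ** alpha * 1.0 ** (1 - alpha)
print(rK * K + w * 1.0, Y)  # the two agree up to rounding
```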

In [None]:
def K_to_prices(K, z):
    """(K, z) -> (r_net, w). Total labour L=1 (fixed)."""
    K = np.maximum(K, 1e-8)
    rK = alpha * z * (K ** (alpha - 1))
    w = (1 - alpha) * z * (K ** alpha)
    r_net = rK - delta
    return r_net, w

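# Benchmark only: representative-agent (complete-markets) steady state, where 1 = β(1 + r) => r_ss = 1/β - 1.
# With uninsurable idiosyncratic risk, precautionary saving typically pushes the equilibrium r below this value.
r_ss = 1.0 / beta - 1
z_ss = 1.0
K_ss = (alpha * z_ss / (r_ss + delta)) ** (1 / (1 - alpha))   # invert r = α z K^(α-1) - δ
r_ss_check, w_ss_check = K_to_prices(K_ss, z_ss)
print(f"Steady state (z=1): r_ss={r_ss:.4f}, K_ss={K_ss:.4f}, w_ss={w_ss_check:.4f}")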

## Policy and simulation (SRL: price-based)

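**Policy (consumption, labour):** $(c, n) = \pi(b, y; r, w)$ depends on the individual state and on **prices** $(r, w)$. In general equilibrium, $(r_t, w_t) = P^*(K_t, z_t)$ with $K_t = \int b\, dG_t$ and **total labour** $L=1$. When differentiating with respect to the policy parameters, we apply a **stop-gradient** to $(r,w)$, so each agent takes prices as given.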

In [None]:
def draw_next_state(y_idx, z_idx, Ty, Tz):
    """Draw next-period (y', z') grid indices from Ty and Tz (same as Huggett)."""
    y_next = np.random.choice(Ty.shape[1], p=Ty[y_idx, :])
    z_next = np.random.choice(Tz.shape[1], p=Tz[z_idx, :])
    return y_next, z_next

def policy_placeholder(b, y, r, w, z, save_frac=0.25, n_labour=1.0):
    """Placeholder rule: save a constant fraction of cash on hand. Budget: c + b' = (1+r)*b + w*n*y.
    Returns raw (c, n, b') with n = n_labour (inelastic labour); z is unused."""
    cash = (1 + r) * b + w * n_labour * y
    b_next = np.clip(save_frac * cash, b_min, b_max)
    c = np.maximum(cash - b_next, c_min)
    n = np.full(np.shape(b), n_labour, dtype=float)   # same shape as b, scalar or array
    return c, n, b_next

# ========== Policy π(b, y, r, w, z) = (c, n, b') from the grid θ; returns continuous values only (the lottery is applied in the transition) ==========
def policy_from_grid_ks(b, y, r, w, z, theta_grid, b_grid, y_grid, z_grid, r_grid, w_grid, ny, c_min_val=1e-3, n_labour=1.0):
    """Input: (b, y), (r, w, z) and theta_grid = consumption grid of shape (J, nz, nr, nw).
    Output: raw (c, n, b_next); no lottery here."""
    def to_np(a):
        # Accept torch tensors or array-likes
        return a.detach().cpu().numpy() if hasattr(a, 'detach') else np.asarray(a)
    theta_grid, b_grid, y_grid, z_grid, r_grid, w_grid = map(to_np, (theta_grid, b_grid, y_grid, z_grid, r_grid, w_grid))

    def floor_idx(grid, x):
        # Largest grid point <= x (continuous variables b, r, w), clipped to the grid
        return np.atleast_1d(np.clip(np.searchsorted(grid, x, side='right') - 1, 0, len(grid) - 1))

    def nearest_idx(grid, x):
        # Nearest grid point (discrete states y, z); robust to float32/float64 rounding of grid values
        return np.abs(grid[None, :] - np.atleast_1d(x)[:, None]).argmin(axis=1)

    ib, ir, iw = floor_idx(b_grid, b), floor_idx(r_grid, r), floor_idx(w_grid, w)
    iy, iz = nearest_idx(y_grid, y), nearest_idx(z_grid, z)
    # np.atleast_1d takes no dtype argument, so convert explicitly
    b, y, r, w = (np.atleast_1d(np.asarray(v, dtype=float)) for v in (b, y, r, w))
    j = ib * ny + iy
    c = np.maximum(theta_grid[j, iz, ir, iw], c_min_val)
    cash = (1 + r) * b + w * n_labour * y
    b_next = np.clip(cash - c, b_min, b_max)
    c = np.maximum(cash - b_next, c_min_val)   # re-derive c from the clipped b'
    n = np.full_like(c, n_labour)
    if c.size == 1:
        return c.ravel()[0], n.ravel()[0], b_next.ravel()[0]
    return c, n, b_next

def b_next_to_grid_lottery(b_next_feasible, b_grid):
    """Young-style random rounding of continuous b' to one of the two neighbouring points of b_grid (only needed when a discrete b' is required)."""
    b_next_feasible = np.asarray(b_next_feasible)
    i_lo = np.clip(np.searchsorted(b_grid, b_next_feasible, side='right') - 1, 0, len(b_grid) - 2)
    gap = b_grid[i_lo + 1] - b_grid[i_lo]
    lam = np.clip((b_next_feasible - b_grid[i_lo]) / (gap + 1e-12), 0.0, 1.0)
    lottery = np.random.rand(len(np.atleast_1d(b_next_feasible))) < lam
    return np.where(lottery, b_grid[i_lo + 1], b_grid[i_lo])

def aggregate_capital(b_vec):
    return np.mean(b_vec)

def simulate_krusell_smith(T, N_agents, policy_fn, y_grid, z_grid, Ty, Tz, b0=None, z0_idx=None):
    """Simulate K-S: K_t = mean(b_t), (r_t, w_t) = P*(K_t, z_t) with L = 1, (c, n, b') = π(b, y, r, w).
    y and z follow the Markov chains Ty and Tz."""
    if z0_idx is None:
        z0_idx = np.random.randint(0, len(z_grid))
    y_idx = np.random.randint(0, len(y_grid), size=N_agents)
    if b0 is None:
        b0 = np.full(N_agents, K_ss)   # each agent starts at K_ss, so K_0 = mean(b0) = K_ss
    b = np.asarray(b0, dtype=float).copy()
    z_idx = z0_idx
    Ty_cdf = np.cumsum(Ty, axis=1)     # row CDFs for vectorized inverse-CDF draws of y'
    paths = {'b': np.zeros((T+1, N_agents)), 'y': np.zeros((T+1, N_agents)), 'z': np.zeros(T+1),
             'r': np.zeros(T+1), 'w': np.zeros(T+1), 'K': np.zeros(T+1),
             'c': np.zeros((T+1, N_agents)), 'n': np.zeros((T+1, N_agents))}
    paths['b'][0], paths['y'][0], paths['z'][0] = b, y_grid[y_idx], z_grid[z_idx]
    for t in range(T + 1):
        K_t = np.mean(b)
        z_t = z_grid[z_idx]
        r_t, w_t = K_to_prices(K_t, z_t)
        c_t, n_t, b_next = policy_fn(b, y_grid[y_idx], r_t, w_t, z_t)
        paths['r'][t], paths['w'][t], paths['K'][t] = r_t, w_t, K_t
        paths['c'][t], paths['n'][t] = c_t, n_t
        if t == T:
            break   # period T: record prices and choices with the same policy_fn, no further transition
        b = b_next
        z_idx = np.random.choice(len(z_grid), p=Tz[z_idx])
        u_draw = np.random.rand(N_agents, 1)
        y_idx = np.minimum((u_draw > Ty_cdf[y_idx]).sum(axis=1), len(y_grid) - 1)
        paths['b'][t+1], paths['y'][t+1], paths['z'][t+1] = b, y_grid[y_idx], z_grid[z_idx]
    v_hat = np.sum([(beta**t) * np.mean(u(paths['c'][t])) for t in range(T)])
    return paths, v_hat

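`b_next_to_grid_lottery` rounds $b'$ up with probability $\lambda$, the relative distance from the lower grid point. The expected grid point therefore equals $b'$. The sketch below re-implements the same rounding rule on an illustrative grid, not the calibrated one, and checks this property by Monte Carlo.

```python
import numpy as np

rng = np.random.default_rng(0)
b_grid = np.linspace(0.0, 10.0, 11)   # illustrative grid with spacing 1
b_next = 3.3                          # a continuous choice between 3 and 4

# Lower bracketing index and lottery weight lambda = (b' - b_lo) / (b_hi - b_lo)
i_lo = np.clip(np.searchsorted(b_grid, b_next, side='right') - 1, 0, len(b_grid) - 2)
lam = (b_next - b_grid[i_lo]) / (b_grid[i_lo + 1] - b_grid[i_lo])

# Many independent draws: go to the upper point with probability lambda
draws = np.where(rng.random(200_000) < lam, b_grid[i_lo + 1], b_grid[i_lo])
print(lam, draws.mean())   # the sample mean is close to 3.3
```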
In [None]:
policy_fn = lambda b, y, r, w, z: policy_placeholder(b, y, r, w, z, save_frac=0.25)
T_sim, N_agents = 200, 2000
paths, v_hat = simulate_krusell_smith(T_sim, N_agents, policy_fn, y_grid, z_grid, Ty, Tz)
print("Simulation done. v̂_π =", v_hat)
print("Mean K:", np.mean(paths['K']))
print("Mean r, w:", np.mean(paths['r']), np.mean(paths['w']))
print("Mean n (total labour ≈ 1):", np.nanmean(paths['n']))

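Drawing $y'$ agent by agent with `np.random.choice` is slow when there are thousands of agents. A vectorized alternative draws one uniform per agent and inverts the cumulative row of the transition matrix. The sketch below is not the notebook's own code; the function name and the 3-state matrix are hypothetical.

```python
import numpy as np

def draw_markov_vectorized(idx, T, rng):
    """Next-state indices for all agents at once via the inverse CDF of each row of T."""
    cdf = np.cumsum(T[idx], axis=1)          # (N, n) cumulative transition probabilities
    u = rng.random((len(idx), 1))
    nxt = (u > cdf).sum(axis=1)              # first column where u <= cdf
    return np.minimum(nxt, T.shape[1] - 1)   # guard against round-off in cdf[:, -1]

rng = np.random.default_rng(1)
T = np.array([[0.7, 0.2, 0.1],
              [0.2, 0.6, 0.2],
              [0.1, 0.2, 0.7]])
idx = np.zeros(100_000, dtype=int)           # everyone starts in state 0
nxt = draw_markov_vectorized(idx, T, rng)
freq = np.bincount(nxt, minlength=3) / len(nxt)
print(freq)                                  # close to T[0] = [0.7, 0.2, 0.1]
```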
### Part A: Distribution simulation under a given policy (same as Huggett)

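The distribution $d$ is stored either as an $(n_b, n_y)$ matrix or as a flat $(J,)$ vector. Current capital is $K = d \cdot b$, and prices are $(r,w) = P^*(K,z)$. `update_d_direct` maps $d_t \to d_{t+1}$ with Young's lottery method without building the transition matrix $A$.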
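
Young's non-stochastic method splits the mass at each state between the two grid points that bracket $b'$, using weights $(1-\lambda, \lambda)$. This toy sketch uses a hypothetical 5-point grid, a single income state, and a made-up savings rule. It checks the two properties the method relies on: total mass is conserved, and mean capital next period equals $\sum_j d_j\, b'_j$ exactly. The subsequent multiplication by $T_y$ only moves mass across income states, so it changes neither quantity.

```python
import numpy as np

b_grid = np.linspace(0.0, 4.0, 5)          # toy grid with spacing 1
d = np.array([0.1, 0.3, 0.3, 0.2, 0.1])    # toy distribution over b
b_next = np.clip(0.8 * b_grid + 0.5, b_grid[0], b_grid[-1])   # made-up savings rule

# Bracketing index and weight on the upper point
i_lo = np.clip(np.searchsorted(b_grid, b_next, side='right') - 1, 0, len(b_grid) - 2)
lam = (b_next - b_grid[i_lo]) / (b_grid[i_lo + 1] - b_grid[i_lo])

# Scatter mass to the two neighbours (np.add.at accumulates repeated indices)
d_new = np.zeros_like(d)
np.add.at(d_new, i_lo, (1 - lam) * d)
np.add.at(d_new, i_lo + 1, lam * d)

print(d_new.sum(), d_new @ b_grid, d @ b_next)
```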

In [None]:
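def d_to_mat(d):
    """Return the distribution as an (nb, ny) matrix, reshaping a flat (J,) vector if needed."""
    d = np.asarray(d)
    if d.ndim == 1:
        return d.reshape(nb, ny)
    return d

def d_to_flat(d):
    """Return the distribution as a flat (J,) vector with j = ib*ny + iy."""
    return np.asarray(d).reshape(nb, ny).ravel()

# Aggregate capital K = d·b (weighted by current assets), used in (r, w) = K_to_prices(K, z)
def aggregate_capital_from_d(d, b_grid, y_grid, ny):
    """K = Σ_j d(j) * b(j) = d·b_vals."""
    b_vals = np.repeat(b_grid, ny)                 # b at each composite state j = ib*ny + iy
    return np.dot(np.asarray(d).ravel(), b_vals)   # row-major ravel keeps j = ib*ny + iy for any grid size

# Direct map d_t → d_{t+1} (Young's method) without building A; policy_fn returns continuous (c, n, b_next)
def update_d_direct(d, r, w, z, policy_fn, b_grid, y_grid, Ty):
    """Compute d_new from d and the policy: scatter Young lottery weights into Q, then d_new = Q @ Ty."""
    nb_, ny_ = len(b_grid), len(y_grid)
    J_ = nb_ * ny_
    d_flat = np.asarray(d).ravel()   # accepts (nb_, ny_) or (J_,) input, independent of the global nb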
    b_flat = np.repeat(b_grid, ny_)
    y_flat = np.tile(y_grid, nb_)
    c, n, b_next_feasible = policy_fn(b_flat, y_flat, r, w, z)
    i_lo = np.clip(np.searchsorted(b_grid, b_next_feasible, side='right') - 1, 0, nb_ - 2)
    gap = b_grid[i_lo + 1] - b_grid[i_lo]
    lam = np.clip((b_next_feasible - b_grid[i_lo]) / (gap + 1e-12), 0.0, 1.0)
    iy_all = np.arange(J_) % ny_
    Q = np.zeros((nb_, ny_))
    np.add.at(Q, (i_lo, iy_all), (1 - lam) * d_flat)
    np.add.at(Q, (i_lo + 1, iy_all), lam * d_flat)
    d_new = Q @ Ty
    d_new = d_new / (d_new.sum() + 1e-20)
    return d_new

# Optional: distribution-level simulation (each period K_t = d·b, (r, w) = P*(K, z), d_{t+1} = update_d_direct)
# paths_d = simulate_krusell_smith_distribution(T, policy_fn, b_grid, y_grid, z_grid, Ty, Tz, d0=None, z0_idx=None)
# Implemented the same way as Huggett's simulate_huggett; omitted here.

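The distribution-level simulator `simulate_krusell_smith_distribution` is omitted in the source. Below is a minimal hypothetical sketch of such a loop, written generically so that it runs on its own:

- `price_fn` stands in for `K_to_prices`.
- `update_fn` stands in for a wrapper around `update_d_direct`.
- The toy saving rule, grid, and two-state $z$ chain are made up for illustration.

```python
import numpy as np

def simulate_distribution(T, d0, b_vals, z_grid, Tz, price_fn, update_fn, z0_idx=0, rng=None):
    """Hypothetical sketch: each period K_t = d_t . b, (r_t, w_t) = price_fn(K_t, z_t),
    d_{t+1} = update_fn(d_t, r_t, w_t, z_t); then z_{t+1} is drawn from row z_idx of Tz."""
    rng = np.random.default_rng() if rng is None else rng
    d, z_idx = np.asarray(d0, dtype=float), z0_idx
    K, r, w, z = (np.zeros(T) for _ in range(4))
    for t in range(T):
        z[t] = z_grid[z_idx]
        K[t] = d @ b_vals
        r[t], w[t] = price_fn(K[t], z[t])
        d = update_fn(d, r[t], w[t], z[t])
        z_idx = rng.choice(len(z_grid), p=Tz[z_idx])
    return {'K': K, 'r': r, 'w': w, 'z': z, 'd_T': d}

alpha, delta = 0.36, 0.08
b_vals = np.linspace(0.0, 20.0, 41)

def price_fn(K, z):
    K = max(K, 1e-8)
    return alpha * z * K ** (alpha - 1) - delta, (1 - alpha) * z * K ** alpha

def update_fn(d, r, w, z):
    # Toy rule (illustration only): save 90% of cash on hand, single income state y = 1
    b_next = np.clip(0.9 * ((1 + r) * b_vals + w), b_vals[0], b_vals[-1])
    i_lo = np.clip(np.searchsorted(b_vals, b_next, side='right') - 1, 0, len(b_vals) - 2)
    lam = (b_next - b_vals[i_lo]) / (b_vals[i_lo + 1] - b_vals[i_lo])
    d_new = np.zeros_like(d)
    np.add.at(d_new, i_lo, (1 - lam) * d)
    np.add.at(d_new, i_lo + 1, lam * d)
    return d_new

z_grid = np.array([0.97, 1.03])
Tz = np.array([[0.9, 0.1], [0.1, 0.9]])
d0 = np.full(len(b_vals), 1.0 / len(b_vals))
out = simulate_distribution(100, d0, b_vals, z_grid, Tz, price_fn, update_fn, rng=np.random.default_rng(0))
print(out['K'][-5:])
```

Unlike the agent-based `simulate_krusell_smith`, this loop has no cross-sectional sampling noise. The only randomness is the aggregate $z$ path.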

## SRL / SPG: Grid policy $(c,n)(b,y,r,w)$ with gradient-stop on $(r,w)$

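The policy is $(c, n) = \pi(b, y; r, w)$ with total labour $L=1$. As in Huggett, we maximize $L(\theta) = \mathbb{E}\left[\sum_t \beta^t d_t^\top u(c_t)\right]$. Here $L(\theta)$ denotes the objective, not labour. $K_t = d_t^\top b$ and $(r_t,w_t)=P^*(K_t,z_t)$ are **detached**, i.e. treated as constants in the gradient.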
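
Why detach $(r,w)$? Consider a toy two-period example that is hypothetical and not taken from the source. Each agent saves $s$, aggregate capital is $K=s$, and $r = \alpha K^{\alpha-1} - \delta$. Differentiating with $r$ frozen reproduces the price-taking agent's first-order condition. The total derivative additionally picks up the pecuniary effect $\partial r/\partial K$, which an atomistic agent does not internalize. Central finite differences make the gap concrete.

```python
import numpy as np

beta, sigma, alpha, delta, W = 0.95, 3.0, 0.36, 0.08, 2.0

def u(c):
    return c ** (1 - sigma) / (1 - sigma)

def price_r(K):
    # Net return from Cobb-Douglas with L = 1 and z = 1
    return alpha * K ** (alpha - 1) - delta

def objective(s, r):
    """Two-period toy utility: consume W - s today and (1 + r) s tomorrow."""
    return u(W - s) + beta * u((1 + r) * s)

s0, h = 0.8, 1e-6
r0 = price_r(s0)   # in this toy, aggregate capital K equals individual saving s

# Total derivative: the price moves with s (what a planner would internalize)
g_total = (objective(s0 + h, price_r(s0 + h)) - objective(s0 - h, price_r(s0 - h))) / (2 * h)
# Stop-gradient derivative: r frozen at r0 (what a price-taking agent sees)
g_stop = (objective(s0 + h, r0) - objective(s0 - h, r0)) / (2 * h)
print(g_total, g_stop)
```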

In [None]:
# ========== Part B: SRL/SPG training (θ differentiable; d as a (J,) vector; soft weights; P* detached) ==========
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

# Coarser grids for SPG to speed up training (Table 5 uses nb=200, nr=30, nw=50)
nb_spg, nr_spg, nw_spg, nz_spg = 50, 15, 25, 10
ny_spg = ny
J = nb_spg * ny_spg   # note: overrides the calibration-cell J for the SPG part

iz_spg = np.linspace(0, nz-1, nz_spg, dtype=int)   # subset of the full z grid
b_grid_t = torch.tensor(np.linspace(b_min, b_max, nb_spg), dtype=dtype, device=device)
r_grid_t = torch.tensor(np.linspace(r_min, r_max, nr_spg), dtype=dtype, device=device)
w_grid_t = torch.tensor(np.linspace(w_min, w_max, nw_spg), dtype=dtype, device=device)
z_grid_t = torch.tensor(z_grid[iz_spg], dtype=dtype, device=device)
y_grid_t = torch.tensor(y_grid, dtype=dtype, device=device)
Ty_t = torch.tensor(Ty, dtype=dtype, device=device)
# Tz restricted to the z subset and row-renormalized (an approximation to the full chain)
Tz_sub = Tz[np.ix_(iz_spg, iz_spg)]
Tz_sub = Tz_sub / Tz_sub.sum(axis=1, keepdims=True)
Tz_t = torch.tensor(Tz_sub, dtype=dtype, device=device)
nz_spg = Tz_t.shape[0]

# Policy parameters θ → consumption grid c = softplus(θ) + c_min (keeps c > 0 and differentiable, as in Huggett)
def theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, c_min_val=None):
    """θ (J, nz, nr, nw) -> c >= c_min."""
    if c_min_val is None:
        c_min_val = c_min
    return torch.nn.functional.softplus(theta) + c_min_val

# Initial values: c from a constant saving-rate rule, then invert softplus for θ (as in Huggett's init_theta)
def init_theta_ks(b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, save_frac=0.25):
    """Vectorized: c = max((1 - save_frac) * cash, c_min) on the (J, nz, nr, nw) grid; returns θ with softplus(θ) + c_min = c."""
    nb_, ny_ = len(b_grid_t), len(y_grid_t)
    nz_, nr_, nw_ = len(z_grid_t), len(r_grid_t), len(w_grid_t)
    b_j = b_grid_t.repeat_interleave(ny_).view(-1, 1, 1)   # b at composite state j = ib*ny + iy
    y_j = y_grid_t.repeat(nb_).view(-1, 1, 1)              # y at composite state j
    r = r_grid_t.view(1, nr_, 1)
    w = w_grid_t.view(1, 1, nw_)
    cash = (1 + r) * b_j + w * y_j                         # (J, nr, nw); z does not enter cash
    c_grid = torch.clamp((1 - save_frac) * cash, min=c_min)
    c_grid = c_grid.unsqueeze(1).expand(-1, nz_, -1, -1)   # (J, nz, nr, nw)
    x = torch.clamp(c_grid - c_min, min=1e-8)
    return x + torch.log(-torch.expm1(-x))                 # stable inverse of softplus (no exp overflow)

def K_from_d(d, b_grid_t, ny_spg):
    """Aggregate capital K = d^T b (over b indices)."""
    b_vals = b_grid_t.repeat_interleave(ny_spg)
    return (d * b_vals).sum()

def P_star_detach_ks(theta, d, iz, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, ny_spg):
    """(K,z) -> (r,w) from production; K = d·b. Return (r,w) detached."""
    K = K_from_d(d, b_grid_t, ny_spg).item()
    z_val = z_grid_t[iz].item()
    r_net, w_val = K_to_prices(K, z_val)
    r_t = torch.tensor(r_net, device=device, dtype=dtype)
    w_t = torch.tensor(w_val, device=device, dtype=dtype)
    return r_t.detach(), w_t.detach()

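def rw_to_ir_iw(r_val, w_val, r_grid_t, w_grid_t):
    """Map prices (r, w) to price-grid indices with the same floor lookup as policy_from_grid_ks."""
    rn, wn = r_grid_t.cpu().numpy(), w_grid_t.cpu().numpy()
    ir = int(np.clip(np.searchsorted(rn, r_val, side='right') - 1, 0, len(rn)-1))
    iw = int(np.clip(np.searchsorted(wn, w_val, side='right') - 1, 0, len(wn)-1))
    return ir, iw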

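Any initialization that inverts $x = \operatorname{softplus}(\theta)$ (here $x = c - c_{\min}$) with the naive formula $\theta = \log(e^{x}-1)$ risks overflow for large $x$. In `float32`, $e^{x}$ overflows above roughly $x \approx 88.7$. A numerically stable equivalent is $\theta = x + \log(1 - e^{-x})$, computed with `expm1`. The NumPy sketch below checks the round trip; the test values are arbitrary.

```python
import numpy as np

def softplus(theta):
    # log(1 + e^theta), computed stably via logaddexp
    return np.logaddexp(0.0, theta)

def inv_softplus(x):
    """Stable inverse of softplus for x > 0: theta = x + log(1 - e^{-x})."""
    x = np.asarray(x, dtype=np.float64)
    return x + np.log(-np.expm1(-x))

x = np.array([1e-6, 0.5, 5.0, 82.0, 500.0])
theta = inv_softplus(x)
print(np.max(np.abs(softplus(theta) - x) / x))   # relative round-trip error
```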
In [None]:
# ========== d_t -> d_{t+1} directly (vectorized Q @ Ty, no J×J matrix), same as Huggett's update_G_pi_direct ==========
def update_d_pi_direct_ks(theta, d, iz, ir, iw, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg, sigma_b=0.5):
    """d_new = (implied A) @ d, computed as Q(ib', iy) = Σ_ib w_b(ib*ny+iy, ib') d(ib, iy) followed by d_new = Q @ Ty; no loop over J, and A is never built."""
    J = nb_spg * ny_spg
    d_mat = d.view(nb_spg, ny_spg) if d.dim() == 1 else d
    z_val = z_grid_t[iz]
    r_val = r_grid_t[ir]
    w_val = w_grid_t[iw]
    c = theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)[:, iz, ir, iw]
    b_next = (1 + r_val) * b_grid_t.repeat_interleave(ny_spg) + y_grid_t.repeat(nb_spg) * w_val - c
    b_next = torch.clamp(b_next, b_min, b_max)
    dist = b_next.unsqueeze(1) - b_grid_t.unsqueeze(0)
    w_b = torch.exp(-dist.pow(2) / (2 * sigma_b**2))
    w_b = w_b / (w_b.sum(dim=1, keepdim=True) + 1e-8)
    M = w_b.view(nb_spg, ny_spg, nb_spg).permute(2, 0, 1)
    Q = (M * d_mat.unsqueeze(0)).sum(dim=1)
    d_new = (Q @ Ty_t).reshape(-1)
    return d_new / (d_new.sum() + 1e-20)

def build_A_pi_ks(theta, iz, ir, iw, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg, sigma_b=0.5):
    """(Optional) Build full A_π; use update_d_pi_direct_ks in training. Evolution: d_new = A @ d."""
    J = nb_spg * ny_spg
    z_val = z_grid_t[iz]
    r_val = r_grid_t[ir]
    w_val = w_grid_t[iw]
    c = theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)[:, iz, ir, iw]
    b_next = (1 + r_val) * b_grid_t.repeat_interleave(ny_spg) + y_grid_t.repeat(nb_spg) * w_val - c
    b_next = torch.clamp(b_next, b_min, b_max)
    dist = b_next.unsqueeze(1) - b_grid_t.unsqueeze(0)
    w_b = torch.exp(-dist.pow(2) / (2 * sigma_b**2))
    w_b = w_b / (w_b.sum(dim=1, keepdim=True) + 1e-8)
    A = torch.zeros(J, J, dtype=dtype, device=device)
    for j in range(J):
        ib, iy = j // ny_spg, j % ny_spg
        for ibp in range(nb_spg):
            for iyp in range(ny_spg):
                jp = ibp * ny_spg + iyp
                A[jp, j] = w_b[j, ibp] * Ty_t[iy, iyp]
    return A

def u_torch(c_vec, sig=sigma):
    """CRRA; log(c) when σ = 1 (as in Huggett)."""
    c_vec = torch.clamp(c_vec, min=c_min)
    if abs(sig - 1.0) < 1e-8:
        return torch.log(c_vec)
    return (c_vec ** (1 - sig)) / (1 - sig)

# Initial d0: iterate update_d_pi_direct_ks (no A matrix) at mid-grid (z, r, w), as in Huggett's steady_state_G0; these are not general-equilibrium prices
def steady_state_d0_ks(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, n_iter=150):
    J = nb_spg * ny_spg
    iz_mid = nz_spg // 2
    ir_mid = max(0, nr_spg // 2)
    iw_mid = max(0, nw_spg // 2)
    with torch.no_grad():
        d = torch.ones(J, device=device, dtype=dtype) / J
        for _ in range(n_iter):
            d = update_d_pi_direct_ks(theta, d, iz_mid, ir_mid, iw_mid, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg)
    return d

def spg_objective_ks(theta, N_traj, T_horizon, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, Tz_t,
                       nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, beta_t, d0=None, warm_up=False):
    """warm_up=True: d stays at d0 (no update). warm_up=False: d evolves via update_d_pi_direct_ks (d_new = A @ d)."""
    J = nb_spg * ny_spg
    if d0 is None:
        d0 = torch.ones(J, device=device, dtype=dtype) / J
    # The consumption grid depends only on θ, so build it once per call (gradients still flow to θ)
    c = theta_to_c_grid(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)
    Tz_np = Tz_t.detach().cpu().numpy().astype(np.float64)
    Tz_np = Tz_np / Tz_np.sum(axis=1, keepdims=True)   # exact row sums for np.random.choice
    L_list = []
    for _ in range(N_traj):
        iz = np.random.randint(0, nz_spg)
        d = d0.clone()
        L_n = torch.tensor(0.0, device=device, dtype=dtype)
        for t in range(T_horizon):
            # Prices from the current distribution, treated as constants (stop-gradient)
            r_t, w_t = P_star_detach_ks(theta, d, iz, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, ny_spg)
            ir, iw = rw_to_ir_iw(r_t.item(), w_t.item(), r_grid_t, w_grid_t)
            c_t = c[:, iz, ir, iw]
            L_n = L_n + (beta_t ** t) * (d @ u_torch(c_t))
            if not warm_up:
                d = update_d_pi_direct_ks(theta, d, iz, ir, iw, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg)
            iz = np.random.choice(nz_spg, p=Tz_np[iz])
        L_list.append(L_n)
    return torch.stack(L_list).mean()

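`update_d_pi_direct_ks` never forms the $J\times J$ matrix $A$; it applies $A_{j',j} = w_b(j, i_b')\,T_y(i_y, i_y')$ implicitly. The NumPy sketch below confirms that the direct map $Q\,T_y$ equals $A d$ with $j = i_b n_y + i_y$. It uses random toy inputs with hypothetical sizes, and row-normalized random weights stand in for the Gaussian kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
nb, ny = 6, 3
J = nb * ny

w_b = rng.random((J, nb)); w_b /= w_b.sum(axis=1, keepdims=True)   # weight of state j on b-node ib'
Ty = rng.random((ny, ny)); Ty /= Ty.sum(axis=1, keepdims=True)     # income transitions
d = rng.random(J); d /= d.sum()                                    # distribution, j = ib*ny + iy

# Explicit A[j', j] = w_b[j, ib'] * Ty[iy, iy'] with j' = ib'*ny + iy'
iy_of_j = np.arange(J) % ny
A = np.einsum('jb,jk->bkj', w_b, Ty[iy_of_j]).reshape(J, J)

# Direct map: Q[ib', iy] = sum_ib w_b[ib*ny+iy, ib'] * d[ib, iy], then Q @ Ty
M = w_b.reshape(nb, ny, nb)                  # M[ib, iy, ib']
Q = np.einsum('iyb,iy->by', M, d.reshape(nb, ny))
d_direct = (Q @ Ty).reshape(-1)

print(np.max(np.abs(A @ d - d_direct)))
```

The direct map costs $O(J\,n_b)$ memory instead of $O(J^2)$, which is why training uses it rather than `build_A_pi_ks`.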
In [None]:
# ========== Init θ and run SPG in two phases (warm-up with d fixed at d0, then d evolves), as in Huggett ==========
T_horizon = min(T_trunc, 50)   # periods per trajectory

theta = init_theta_ks(b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)
theta = theta.requires_grad_(True)
optimizer = torch.optim.Adam([theta], lr=lr_ini)
d0_steady = steady_state_d0_ks(theta, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t, Ty_t, nb_spg, ny_spg, nz_spg, nr_spg, nw_spg)

loss_hist = []
for epoch in range(N_epoch):
    warm_up = (epoch < N_warmup)
    t0 = max(epoch - N_warmup, 0) / max(N_epoch - N_warmup, 1)
    lr_t = lr_ini * (lr_decay ** t0)
    for g in optimizer.param_groups:
        g['lr'] = lr_t
    theta_old = theta.detach().clone()
    optimizer.zero_grad()
    d0_phase = d0_steady.detach() if warm_up else None
    L = spg_objective_ks(theta, N_sample, T_horizon, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t,
                          Ty_t, Tz_t, nb_spg, ny_spg, nz_spg, nr_spg, nw_spg, beta, d0=d0_phase, warm_up=warm_up)
    loss_hist.append(L.item())
    (-L).backward()
    optimizer.step()
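    param_change = (theta.detach() - theta_old).abs().max().item()
    # Test convergence only after warm-up. With Adam, |Δθ| is roughly bounded by lr_t,
    # so this test also fires once lr_t falls below e_converge.
    if not warm_up and param_change < e_converge:
        print(f"Converged at epoch {epoch+1}, |Δθ| = {param_change:.2e}")
        break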
    if (epoch + 1) % 50 == 0 or (epoch + 1) <= 3 or (epoch + 1) == N_warmup:
        phase = "warm-up (d fixed)" if warm_up else "adaptive (d evolves)"
        print(f"Epoch {epoch+1}, L(θ) = {L.item():.6f}, lr = {lr_t:.2e}, {phase}")
plt.figure(figsize=(8, 4))
plt.plot(loss_hist)
plt.xlabel('Epoch')
plt.ylabel('L(θ)')
plt.title('Krusell-Smith SRL/SPG (two-phase)')
plt.grid(True, alpha=0.3)
plt.show()

## Using Trained Policy

After training, use `policy_from_grid_ks` with the consumption grid from θ (same idea as Huggett's policy_from_grid).

In [None]:
# ---------- Use the trained policy: θ → consumption grid → policy_from_grid_ks ----------
try:
    theta_trained = theta.detach()
    c_grid_ks = theta_to_c_grid(theta_trained, b_grid_t, y_grid_t, z_grid_t, r_grid_t, w_grid_t)
    c_grid_ks_np = c_grid_ks.cpu().numpy()
    def policy_trained_ks(b, y, r, w, z):
        return policy_from_grid_ks(b, y, r, w, z, c_grid_ks_np,
            b_grid_t.cpu().numpy(), y_grid_t.cpu().numpy(), z_grid_t.cpu().numpy(),
            r_grid_t.cpu().numpy(), w_grid_t.cpu().numpy(), ny_spg)
    print("Testing trained policy:")
    b_test = np.array([1.0, 5.0, 10.0])
    y_test = y_grid[1]
    r_test, w_test = (r_min + r_max) / 2.0, (w_min + w_max) / 2.0
    z_test = z_grid[len(z_grid)//2]
    c_test, n_test, b_next_test = policy_trained_ks(b_test, y_test, r_test, w_test, z_test)
    print(f"  b = {b_test}, y = {y_test:.3f}, r = {r_test:.3f}, w = {w_test:.3f}, z = {z_test:.3f}")
    print(f"  c = {c_test}, n = {n_test}, b_next = {b_next_test}")
except NameError:
    print('Run the SPG training cell above first to define theta.')

## Summary

- **Model:** Krusell–Smith (1998): $(c, n) = \pi(b, y; r, w)$; budget $c + b' = (1+r)b + w\,n\,y$; **total labour $L = 1$**; prices $(r, w) = P^*(K, z)$.
- **State dynamics:** $y$ and $z$ follow the Markov chains $T_y$ and $T_z$ (same as Huggett).
- **Policy:** raw $(c, b')$ from $\theta$. The transition in `update_d_pi_direct_ks` uses Gaussian-kernel soft weights over the $b$ grid and never builds the full $A$ matrix, the same design as Huggett.
- **SRL/SPG:** two-phase training. During warm-up, $d$ stays fixed at $d_0$, the stationary distribution under the initial $\theta$ at mid-grid prices. Afterwards $d$ evolves via the direct map $d_{t+1} = A_\pi d_t$, with $A_\pi$ implied but never formed. Gradients are stopped on $(r, w)$, and $L(\theta) = \mathbb{E}\left[\sum_t \beta^t d_t^\top u(c_t)\right]$.