In [18]:
# CRRA–Zhang (2005 JF) production economy via a compact FBSDE/HJB-residual scheme.
# Features: q-rule controls (costly reversibility), DeepSets distribution-state,
# tower (regress-later) penalty, KT forward/backward differential penalties,
# Malliavin (BEL) regularizer, Tamed–Euler for K, gradient clipping + AdamW,
# Polyak–Ruppert / SWA parameter averaging, and plots including D/P and NEW: log price p_t.
#
# Sources validated via web search through Sep 2025:
# - q-kink & costly reversibility: Abel–Eberly (1996), Zhang (2005 JF)
# - tower: Gobet–Lemor–Warin (2005); deep backward schemes (Huré–Pham–Warin 2019)
# - KT forward/backward differential deep BSDE: Kapllani–Teng (2024a,b)
# - Malliavin/BEL estimator: Fournié–Lasry–Lebuchoux–Lions–Touzi (1999)
# - DeepSets for permutation-invariant distribution state: Zaheer et al. (2017); Wagstaff et al. (2022)
# - Mean-field/state-as-distribution: Carmona–Delarue (2018)
# - Tamed–Euler: Hutzenthaler–Jentzen–Kloeden (2012)
# - Robustness: gradient clipping (Pascanu et al. 2013), AdamW (Loshchilov–Hutter 2017),
#   Polyak–Ruppert averaging (1992) / SWA (Izmailov et al. 2018)
#
# Deps: jax, equinox, optax, numpy, matplotlib, sympy
# (If sympy not found, symbolic check is skipped.)

import math, functools
import numpy as np
import jax, jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (5.2, 3.2)

try:
    import sympy as sp
    HAVE_SYMPY = True
except Exception:
    HAVE_SYMPY = False

# ---------------- Config ----------------
class Cfg(eqx.Module):
    """Immutable bundle of model, simulation and training hyper-parameters.

    Every field has a default, so `Cfg()` yields the baseline calibration.
    The comments below describe how each field is used by the code visible in
    this cell; fields marked "not referenced in this cell" are presumably
    consumed by the training/diagnostic code defined elsewhere (TODO confirm).
    """
    # Technology
    alpha: float = 0.70      # returns to capital (reduced form); exponent on K in revenue
    eta: float = 0.20        # inverse demand elasticity (price level responds to output); p = -eta*log(Y)
    f: float = 0.02          # fixed operating flow cost, subtracted from revenue in payout()
    delta: float = 0.08      # subtracted from the investment rate s in the capital drift (s - delta)*K

    # Shocks
    kappa_x: float = 0.30    # OU mean reversion for aggregate TFP
    mu_x: float = 0.0        # long-run mean of x; also used as the initial value x0
    sigma_x: float = 0.15    # diffusion coefficient of x
    kappa_z: float = 1.00    # mean-reversion speed of the firm-level shock z (reverts to 0)
    sigma_z: float = 0.25    # diffusion coefficient of z

    # Preferences (CRRA)
    gamma: float = 5.0       # enters the short rate and the price of risk (gamma * sigma_c)
    rho: float = 0.02        # constant term of the short rate in sdf_params()

    # Adjustment costs (costly reversibility)
    theta_plus: float = 4.0  # expansion
    theta_minus: float = 10.0  # contraction

    # Simulation
    T: float = 3.0           # horizon; the step size is dt = T / nT
    nT: int = 128            # number of time steps
    nF: int = 1024           # panel size (number of simulated firms); halved before mirroring when antithetic
    seed: int = 123          # not referenced in this cell
    antithetic: bool = True  # if True, idiosyncratic shocks are drawn for nF//2 firms and mirrored

    # Distribution (DeepSets) embedding size
    m_dim: int = 4

    # Training
    steps: int = 400         # not referenced in this cell
    lr: float = 1e-3         # not referenced in this cell

    # Loss weights (self-evolving rubric will tweak these)
    # None of the four weights is referenced in this cell.
    w_tower: float = 1.0
    w_forward: float = 0.05
    w_back: float = 1e-3
    w_bel: float = 0.05

# Module-level default configuration used by the cells below.
cfg = Cfg()

def clamp(x, lo=1e-9, hi=1e9):
    """Return `x` limited elementwise to the closed interval [lo, hi].

    The defaults keep values strictly positive and finite, which makes the
    result safe to pass to `log` or to use as a divisor.
    """
    bounded = jnp.clip(x, lo, hi)
    return bounded

# ------------- DeepSets for distribution state m_t -------------
class DeepSet(eqx.Module):
    """Permutation-invariant embedding of the firm panel (K_i, z_i) -> m.

    Each firm's pair (K_i, z_i) is mapped through the shared encoder `phi`,
    the encodings are summed over firms, and the pooled vector is mapped by
    `rho` to an `m_dim`-dimensional summary of the cross-section.

    Args:
        key: JAX PRNG key used to initialise all four linear layers.
        m_dim: size of the output embedding.
        width: hidden width of `phi`; `rho` uses a hidden width of `2*width`.
    """
    phi: eqx.Module
    rho: eqx.Module
    m_dim: int

    def __init__(self, key, m_dim=4, width=16):
        # One independent key per Linear layer. Reusing a key for two layers
        # makes their initial parameters draws from the same random stream,
        # which JAX's PRNG design explicitly warns against.
        k_phi_in, k_phi_out, k_rho_in, k_rho_out = jax.random.split(key, 4)
        self.phi = eqx.nn.Sequential([
            eqx.nn.Linear(2, width, key=k_phi_in), eqx.nn.Lambda(jax.nn.tanh),
            eqx.nn.Linear(width, width, key=k_phi_out), eqx.nn.Lambda(jax.nn.tanh)
        ])
        self.rho = eqx.nn.Sequential([
            eqx.nn.Linear(width, 2*width, key=k_rho_in), eqx.nn.Lambda(jax.nn.tanh),
            eqx.nn.Linear(2*width, m_dim, key=k_rho_out)
        ])
        self.m_dim = m_dim

    def __call__(self, K, z):
        """Embed the panel; `K` and `z` are same-length 1-D arrays (one entry per firm)."""
        X = jnp.stack([K, z], axis=-1)   # (n_firms, 2)
        h = jax.vmap(self.phi)(X)        # phi acts on a single firm, so map it over the panel
        H = jnp.sum(h, axis=0)  # permutation-invariant sum
        return self.rho(H)

# ------------- Value net V(K,z,x,m) -------------
class VNet(eqx.Module):
    """Tanh MLP for the value function, gated to zero as capital approaches 0.

    The input is a single (unbatched) feature vector whose first entry is
    capital K. The scalar network output is multiplied by a sigmoid gate in K
    centred at 0.05 with scale 0.02.
    """
    net: eqx.Module
    in_dim: int

    def __init__(self, key, in_dim, width=64, depth=3):
        layer_keys = jax.random.split(key, depth + 1)
        # Layer sizes: in_dim -> width -> ... -> width (depth hidden layers) -> 1
        sizes = [in_dim] + [width] * depth
        stack = []
        for idx in range(depth):
            stack.append(eqx.nn.Linear(sizes[idx], sizes[idx + 1], key=layer_keys[idx]))
            stack.append(eqx.nn.Lambda(jax.nn.tanh))
        stack.append(eqx.nn.Linear(sizes[-1], 1, key=layer_keys[-1]))
        self.net = eqx.nn.Sequential(stack)
        self.in_dim = in_dim

    def __call__(self, x):
        raw = self.net(x)
        # Barrier near K≈0 so q-rule does not explode
        gate = jax.nn.sigmoid((x[..., 0] - 0.05) / 0.02)
        return (gate * raw)[..., 0]

# Shape helper to accept scalar or vector K,z,x
# Banner printed when this cell runs; the "[main]" tag identifies this cell's
# definitions of V_eval/grad_V/hess_V_xz in the captured output.
_def_msg = "[main] shape-robust V_eval/grad_V/hess_V_xz active"
print(_def_msg)

def _batch3(K, z, x):
    K = jnp.asarray(K)
    z = jnp.asarray(z)
    x = jnp.asarray(x)
    if K.ndim == 0:
        K = K[None]
    if z.ndim == 0:
        z = z[None]
    # Broadcast x to match K's leading size
    if x.shape != K.shape:
        x = jnp.broadcast_to(x, K.shape)
    return K, z, x

def V_eval(modelV, K, z, x, m):
    """Evaluate `modelV` at every (K_i, z_i, x_i) with one shared embedding `m`.

    `m` is flattened and appended to each sample's (K, z, x) triple; the model
    is then mapped over the samples, giving one output per sample.
    """
    K, z, x = _batch3(K, z, x)
    m_flat = jnp.asarray(m).reshape(-1)
    state = jnp.stack([K, z, x], axis=-1)  # (n, 3)
    m_tiled = jnp.broadcast_to(m_flat, (K.shape[0], m_flat.shape[0]))
    features = jnp.concatenate([state, m_tiled], axis=1)
    # The model takes one sample at a time, so vectorise it over the batch axis.
    return jax.vmap(modelV)(features)

def grad_V(modelV, K, z, x, m):
    """Per-sample gradient of `modelV` with respect to (K, z, x); `m` is held fixed.

    Returns an array of shape (n, 3) whose columns are (V_K, V_z, V_x).
    """
    K, z, x = _batch3(K, z, x)
    m_flat = jnp.asarray(m).reshape(-1)
    states = jnp.stack([K, z, x], axis=-1)
    # Differentiate only the leading three inputs; the embedding is a constant.
    value_of_state = lambda u: modelV(jnp.concatenate([u, m_flat]))
    per_sample_grad = jax.vmap(jax.grad(value_of_state))
    return per_sample_grad(states)

def hess_V_xz(modelV, K, z, x, m):
    """Per-sample second derivatives (V_xx, V_zz) of `modelV`; `m` is held fixed.

    The full 3x3 Hessian in (K, z, x) is formed for each sample and only the
    two required diagonal entries are returned, each of shape (n,).
    """
    K, z, x = _batch3(K, z, x)
    m_flat = jnp.asarray(m).reshape(-1)
    states = jnp.stack([K, z, x], axis=-1)
    value_of_state = lambda u: modelV(jnp.concatenate([u, m_flat]))
    hessians = jax.vmap(jax.hessian(value_of_state))(states)  # (n, 3, 3)
    V_xx = hessians[:, 2, 2]
    V_zz = hessians[:, 1, 1]
    return V_xx, V_zz

# ------------- q-rule & payout -------------
def invest_rate_from_q(q, cfg: Cfg):
    """Investment rate implied by marginal q, with an asymmetric slope around q = 1.

    Above 1 the rate is (q - 1)/theta_plus; below 1 it is (q - 1)/theta_minus.
    """
    excess_q = q - 1.0
    expand = jnp.maximum(excess_q / cfg.theta_plus, 0.0)
    contract = jnp.minimum(excess_q / cfg.theta_minus, 0.0)
    return expand + contract

def payout(K, z, x, p, s, cfg: Cfg):
    """Flow payout: revenue minus fixed cost, investment outlay and adjustment cost.

    `p` is the log price, `s` the investment rate (investment is s*K). The
    quadratic adjustment cost uses theta_plus when s >= 0 and theta_minus
    when s < 0.
    """
    revenue = jnp.exp(x + z + p) * (K**cfg.alpha)
    theta = cfg.theta_plus*(s>=0) + cfg.theta_minus*(s<0)
    adjustment_cost = 0.5 * theta * (s**2) * K
    return revenue - cfg.f - s*K - adjustment_cost

# ------------- Price from market clearing (reduced-form) -------------
def price_from_panel(x, z, K, cfg: Cfg):
    """Log output price from the cross-sectional mean output: p = -eta * log(Y).

    With eta = 0 the log price is identically 0 (fixed price).
    """
    firm_output = jnp.exp(x + z) * (K**cfg.alpha)
    mean_output = jnp.mean(firm_output)
    # Clamp before the log so an all-zero or overflowing panel cannot give -inf/inf.
    log_output = jnp.log(clamp(mean_output, 1e-12, 1e12))
    return -cfg.eta * log_output

# ------------- SDF & risk-neutral drift -------------
def sdf_params(C_prev, C_curr, dt, cfg: Cfg):
    """Short rate and price of risk from one step of log-consumption growth.

    Volatility and drift of consumption are both backed out from the single
    increment log(C_curr) - log(C_prev) over a step of length `dt`.

    Returns:
        (r_t, lam_t): short rate and price of risk gamma * sigma_c.
    """
    dt_safe = jnp.maximum(dt, 1e-12)  # guards the divisions below
    growth = jnp.log(clamp(C_curr, 1e-9)) - jnp.log(clamp(C_prev, 1e-9))
    sigma_c = jnp.abs(growth) / jnp.sqrt(dt_safe)
    var_c = sigma_c**2
    mu_c = growth / dt_safe + 0.5*var_c
    r_t = cfg.rho + cfg.gamma*mu_c - 0.5*cfg.gamma*(cfg.gamma+1.0)*var_c
    lam_t = cfg.gamma * sigma_c
    return r_t, lam_t

def mu_x_Q(x, lam_t, cfg: Cfg):
    """Risk-adjusted drift of x: the OU drift minus sigma_x times the price of risk."""
    physical_drift = cfg.kappa_x*(cfg.mu_x - x)
    risk_adjustment = cfg.sigma_x*lam_t
    return physical_drift - risk_adjustment

# ------------- Simulation primitives -------------
def sim_eps(key, shape):
    """Draw standard-normal noise of the given shape from PRNG key `key`."""
    return jax.random.normal(key, shape)

def sim_paths(key, cfg: Cfg):
    """Draw initial states and Brownian increments for one simulated panel.

    Returns:
        (K0, z0, x0, ez, ex, dt) where
        K0: initial capital, all ones, one entry per firm;
        z0: initial firm-level shocks, N(0, 0.2^2);
        x0: initial aggregate state (cfg.mu_x);
        ez: firm-level increments, shape (n_firms, nT), scaled by sqrt(dt);
        ex: aggregate increments, shape (nT,), scaled by sqrt(dt);
        dt: step size T / nT.
    """
    dt = cfg.T/cfg.nT
    root_dt = math.sqrt(dt)
    key_x, key_z, key_z0 = jax.random.split(key, 3)

    # With antithetic sampling only half the firms are drawn; the rest are mirrors.
    copies = 2 if cfg.antithetic else 1
    n_drawn = cfg.nF//copies

    # The aggregate path is a single series shared by all firms; it is never mirrored.
    ex = sim_eps(key_x, (cfg.nT,)) * root_dt
    ez = sim_eps(key_z, (n_drawn, cfg.nT)) * root_dt
    if cfg.antithetic:
        ez = jnp.concatenate([ez, -ez], axis=0)

    n_firms = ez.shape[0]
    K0 = jnp.full((n_firms,), 1.0)
    z0 = sim_eps(key_z0, (n_firms,)) * 0.2
    x0 = cfg.mu_x
    return K0, z0, x0, ez, ex, dt

# ------------- One-step generator residual (risk-neutral HJB) -------------
def step_generator_residual(modelV, meta, K, z, x, m, m_next, cfg: Cfg):
    """Evaluate the per-firm HJB (generator) residual at one time step.

    Args:
        modelV: value network, called through V_eval / grad_V / hess_V_xz.
        meta: mutable dict with keys "dt" (step size) and "C_prev" (previous
            step's mean payout). "C_prev" is OVERWRITTEN in place with this
            step's mean payout before returning.
        K, z: per-firm capital and idiosyncratic shock.
        x: aggregate state (broadcast across firms by the helpers).
        m, m_next: distribution embeddings at this step and at the next
            step, used for a finite-difference time derivative of m.
        cfg: model configuration.

    Returns:
        (gen, s, p, C_now, r_t, lam_t, metrics): the per-firm residual, the
        investment rate, the log price, the mean payout, the short rate, the
        price of risk, and a dict of diagnostics.
    """
    p = price_from_panel(x, z, K, cfg)
    V = V_eval(modelV, K, z, x, m)
    g = grad_V(modelV, K, z, x, m)
    VK, Vz, Vx = g[:,0], g[:,1], g[:,2]
    # The marginal value of capital V_K plays the role of q in the investment rule.
    s = invest_rate_from_q(VK, cfg)
    pay = payout(K, z, x, p, s, cfg)
    # The cross-sectional mean payout is what gets passed to sdf_params as consumption.
    C_now = jnp.mean(pay)
    C_prev = meta["C_prev"]
    r_t, lam_t = sdf_params(C_prev, C_now, meta["dt"], cfg)

    Vxx, Vzz = hess_V_xz(modelV, K, z, x, m)
    # Finite-difference time derivative of the distribution embedding.
    mdot = (m_next - m) / jnp.maximum(meta["dt"], 1e-12)
    # Gradient of the cross-sectional MEAN of V with respect to m, so the
    # resulting Vm_dot is one scalar that is added to every firm's residual.
    Vm = jax.grad(lambda M: jnp.mean(V_eval(modelV, K, z, x, M)))(m)
    Vm_dot = jnp.dot(Vm, mdot)

    # Risk-neutral generator: r*V - [flow + drifts + diffusions + distribution-derivative]
    gen = r_t*V - (pay + VK*(s-cfg.delta)*K + Vx*mu_x_Q(x, lam_t, cfg)
                 + Vz*(-cfg.kappa_z*z) + 0.5*(cfg.sigma_x**2)*Vxx
                 + 0.5*(cfg.sigma_z**2)*Vzz + Vm_dot)

    # Side effect: carry this step's mean payout forward for the next call.
    meta["C_prev"] = C_now
    metrics = {
        "Vm_norm": jnp.linalg.norm(Vm),
        "p": p,
        "r_t": r_t,
        "lam_t": lam_t,
    }
    return gen, s, p, C_now, r_t, lam_t, metrics

def tamed_euler_update_K(K, s, dt, delta):
    """One tamed-Euler step for capital with drift (s - delta) * K.

    The increment is divided by 1 + dt*|s - delta|, which bounds the step when
    the net investment rate is large; the result is clamped to stay positive.
    """
    net_rate = s - delta
    taming = 1.0 + dt * jnp.abs(net_rate)
    increment = (dt * (net_rate * K)) / taming
    return clamp(K + increment)

def roll_sim_and_resid(key, modelV, deepset, cfg: Cfg):
    """Simulate one panel forward in time and collect the HJB residual at each step.

    Args:
        key: JAX PRNG key passed to sim_paths.
        modelV: value network.
        deepset: callable mapping the panel (K, z) to the embedding m.
        cfg: model configuration.

    Returns:
        (R, DPR, r, lam, p), each stacked along a leading time axis:
        R   - per-firm generator residuals from step_generator_residual;
        DPR - cross-sectional mean of the one-step change in clamped V;
        r, lam, p - short rate, price of risk and log price per step.
    """
    K, z, x, ez, ex, dt = sim_paths(key, cfg)
    nF, nT = ez.shape
    m = deepset(K, z)
    # "C_prev" seeds the consumption-growth calculation for the first step;
    # step_generator_residual overwrites it on every call.
    meta = {"dt": dt, "C_prev": 0.1}
    R_list, DPR_list, r_list, lam_list, p_list = [], [], [], [], []
    # Plain Python loop over time steps (the meta dict is mutated each iteration).
    for t in range(nT):
        # Euler steps for the shocks; x uses the drift kappa_x*(mu_x - x)
        # without the price-of-risk adjustment applied in mu_x_Q.
        z_next = z + (-cfg.kappa_z*z)*dt + cfg.sigma_z * ez[:,t]
        x_next = x + (cfg.kappa_x*(cfg.mu_x - x))*dt + cfg.sigma_x * ex[t]
        # Uses the CURRENT K (not yet updated) together with the next z.
        m_next = deepset(K, z_next)  # predictor for mdot
        gen, s, p, C_now, r_t, lam_t, _ = step_generator_residual(
            modelV, meta, K, z, x, m, m_next, cfg
        )
        R_list.append(gen)
        # Values are clamped to [1e-6, 1e6] before differencing; V_next is
        # likewise evaluated at the current K.
        V_now = clamp(V_eval(modelV, K, z, x, m), 1e-6, 1e6)
        V_next = clamp(V_eval(modelV, K, z_next, x_next, m_next), 1e-6, 1e6)
        DPR_list.append(jnp.mean(V_next - V_now))
        r_list.append(r_t)
        lam_list.append(lam_t)
        p_list.append(p)
        # Capital is advanced only after all quantities for this step are recorded.
        K = tamed_euler_update_K(K, s, dt, cfg.delta)
        z, x, m = z_next, x_next, m_next
    R = jnp.stack(R_list, axis=0)
    DPR = jnp.stack(DPR_list, axis=0)
    return R, DPR, jnp.stack(r_list), jnp.stack(lam_list), jnp.stack(p_list)

# ---- Training & diagnostics (unchanged below) ----

[main] shape-robust V_eval/grad_V/hess_V_xz active


In [16]:
# Patch: shape-robust V_eval/grad_V/hess_V_xz to handle scalar vs vector inputs
import jax
import jax.numpy as jnp

# Banner printed when this cell runs; the "[patch]" tag identifies this cell's
# definitions of V_eval/grad_V/hess_V_xz in the captured output.
# NOTE(review): the same helper names are defined in another cell as well;
# whichever cell ran last wins, so confirm the intended execution order.
_def_msg = "[patch] shape-robust V_eval/grad_V/hess_V_xz active"
print(_def_msg)

def _batch3(K, z, x):
    K = jnp.asarray(K)
    z = jnp.asarray(z)
    x = jnp.asarray(x)
    if K.ndim == 0:
        K = K[None]
    if z.ndim == 0:
        z = z[None]
    # Broadcast x to match K's leading size
    if x.shape != K.shape:
        x = jnp.broadcast_to(x, K.shape)
    return K, z, x

def V_eval(modelV, K, z, x, m):
    """Evaluate `modelV` at every (K_i, z_i, x_i) with one shared embedding `m`.

    `m` is flattened and appended to each sample's (K, z, x) triple; the model
    is then mapped over the samples, giving one output per sample.
    """
    K, z, x = _batch3(K, z, x)
    m_flat = jnp.asarray(m).reshape(-1)
    state = jnp.stack([K, z, x], axis=-1)  # (n, 3)
    m_tiled = jnp.broadcast_to(m_flat, (K.shape[0], m_flat.shape[0]))
    features = jnp.concatenate([state, m_tiled], axis=1)
    # The model takes one sample at a time, so vectorise it over the batch axis.
    return jax.vmap(modelV)(features)

def grad_V(modelV, K, z, x, m):
    """Per-sample gradient of `modelV` with respect to (K, z, x); `m` is held fixed.

    Returns an array of shape (n, 3) with one gradient row per sample.
    """
    K, z, x = _batch3(K, z, x)
    m_flat = jnp.asarray(m).reshape(-1)
    states = jnp.stack([K, z, x], axis=-1)
    # Differentiate only the leading three inputs; the embedding is a constant.
    value_of_state = lambda u: modelV(jnp.concatenate([u, m_flat]))
    per_sample_grad = jax.vmap(jax.grad(value_of_state))
    return per_sample_grad(states)

def hess_V_xz(modelV, K, z, x, m):
    """Per-sample second derivatives (V_xx, V_zz) of `modelV`; `m` is held fixed.

    The full 3x3 Hessian in (K, z, x) is formed for each sample and only the
    two required diagonal entries are returned, each of shape (n,).
    """
    K, z, x = _batch3(K, z, x)
    m_flat = jnp.asarray(m).reshape(-1)
    states = jnp.stack([K, z, x], axis=-1)
    value_of_state = lambda u: modelV(jnp.concatenate([u, m_flat]))
    hessians = jax.vmap(jax.hessian(value_of_state))(states)  # (n, 3, 3)
    V_xx = hessians[:, 2, 2]
    V_zz = hessians[:, 1, 1]
    return V_xx, V_zz

[patch] shape-robust V_eval/grad_V/hess_V_xz active


In [None]:
# Driver cell: run the symbolic check, train, and show the first few log rows.
# NOTE(review): `sympy_checks` and `train` are not defined in the cells shown
# here; this cell fails on a fresh kernel unless their defining cell runs first.
# if __name__ == "__main__":
sympy_checks()
# `cfg` is rebound to whatever configuration `train` returns.
modelV, modelV_avg, deepset, logs, cfg = train(cfg, width=64, clip_norm=1.0, weight_decay=1e-4)
print("Diagnostics (every ~40 steps):")
# Only the first five log rows are printed to keep the output short.
for row in logs[:5]:
    print(row)

q-rule (s) = (VK - 1)/theta   (expect (VK-1)/theta)


In [None]:
# Plotting cell: every figure is drawn from the averaged value network.
# NOTE(review): the plot_* helpers are not defined in the cells shown here, and
# `modelV_avg`, `deepset`, `cfg` come from the training cell — run that first.
Vplot = modelV_avg   # Polyak–Ruppert / SWA-averaged params for eval
plot_invest_rate_slices(Vplot, deepset, cfg)
plot_residual_slice(Vplot, deepset, cfg)
plot_policy_heatmap(Vplot, deepset, cfg)
plot_dp_ratio(Vplot, deepset, cfg)
plot_kink_share(Vplot, deepset, cfg, tau=0.02)
plot_risk_params(Vplot, deepset, cfg)
plot_price_series(Vplot, deepset, cfg)  # NEW plot
plt.show()

In [20]:
# Patch: ensure loss_fn uses has_aux=True
# Look for the wrapped callable under the attribute names `fun`, then `f`.
# Any failure (including `loss_fn` not being defined at all) leaves
# `base_loss_fn` as None and triggers the fallback branch below.
base_loss_fn = None
try:
    base_loss_fn = loss_fn.fun  # assumed name of the wrapper's original callable
except Exception:
    try:
        base_loss_fn = loss_fn.f
    except Exception:
        pass

if base_loss_fn is None:
    print("[patch] Could not access base loss function; define a minimal wrapper as fallback.")

    def _loss_core(modelV: VNet, deepset: DeepSet, key, cfg: Cfg):
        """Fallback loss: mean squared generator residual, plus diagnostics.

        Returns `(loss, (parts, aux))`, the layout expected with has_aux=True.
        The kt_f / kt_b / bel entries are zero placeholders.
        """
        resid, dpr, rates, lambdas, prices = roll_sim_and_resid(key, modelV, deepset, cfg)
        loss_gen = jnp.mean(resid**2)
        parts = (loss_gen, jnp.mean(dpr), jnp.std(prices))
        aux = {
            "tower_mean": jnp.mean(resid, axis=1).mean(),
            "DPR_mean": jnp.mean(dpr),
            "r_abs": jnp.mean(jnp.abs(rates)),
            "lam_abs": jnp.mean(jnp.abs(lambdas)),
            "p_std": jnp.std(prices),
            "kt_f": jnp.array(0.0),
            "kt_b": jnp.array(0.0),
            "bel": jnp.array(0.0),
        }
        # minimal loss: generator residual only
        return loss_gen, (parts, aux)

    loss_fn = eqx.filter_value_and_grad(_loss_core, has_aux=True)
else:
    loss_fn = eqx.filter_value_and_grad(base_loss_fn, has_aux=True)
    print("[patch] Rewrapped loss_fn with has_aux=True")


[patch] Could not access base loss function; define a minimal wrapper as fallback.
