In [None]:
# Minimal CRRA–Zhang(2005) mean‑field Deep‑FBSDE scaffold (compact)
# New this iteration: Spectral‑norm penalty (stabilizer) + RQMC Halton+Brownian‑bridge (robustifier),
# Martingale diagnostic plot (new), and an extra economic plot (term structure vs x).
# Core advances kept: tower/regress‑later; BEL/Malliavin; Kapllani–Teng forward/backward (quadruple); D/P plots; SymPy checks.

import math, numpy as np, jax, jax.numpy as jnp, equinox as eqx, optax, matplotlib.pyplot as plt
# Compact default figure size, with grid lines enabled on every axes.
plt.rcParams.update({"figure.figsize": (5.0, 3.1), "axes.grid": True})

# SymPy is optional: HAVE_SYMPY records whether the import succeeded so that
# sympy_check() can skip itself when the package is missing.
try:
    import sympy as sp
except Exception:
    HAVE_SYMPY = False
else:
    HAVE_SYMPY = True

# Equinox Sequential forwards key to layers; define activations that accept key
class Tanh(eqx.Module):
    """Parameter-free tanh activation that can sit inside ``eqx.nn.Sequential``.

    ``key`` is accepted only so the layer tolerates being called with that
    keyword; the activation itself is deterministic and ignores it.
    """

    def __call__(self, x, *, key=None):
        del key  # unused
        return jnp.tanh(x)

# ------------------------ Config ------------------------
class Cfg(eqx.Module):
    """Every tunable constant of the scaffold, grouped by role.

    Field order and defaults are unchanged; they are simply listed one per line.
    """
    # --- preferences / technology ---
    gamma: float = 5.0           # multiplies consumption drift/vol in sdf_bits
    rho: float = 0.02            # discount rate in the HJB/price residuals and tower steps
    alpha: float = 0.70          # exponent on K in payout revenue
    delta: float = 0.08          # subtracted from the investment rate in the K drift
    f: float = 0.036             # fixed term subtracted in payout
    # --- OU shocks ---
    kappa_x: float = 0.30        # mean-reversion speed of x
    mu_x: float = 0.0            # mean-reversion level of x
    sigma_x: float = 0.15        # volatility of x
    kappa_z: float = 1.00        # mean-reversion speed of z (towards 0)
    sigma_z: float = 0.25        # volatility of z
    # --- costly reversibility ---
    theta_plus: float = 4.0      # adjustment-cost coefficient when s >= 0
    theta_minus: float = 10.0    # adjustment-cost coefficient when s < 0
    # --- FV grid (tiny) ---
    K_min: float = 0.3
    K_max: float = 1.7
    nK: int = 24
    Z_min: float = -1.2
    Z_max: float = 1.2
    nZ: int = 24
    # --- horizon ---
    T: float = 2.0
    nT: int = 96                 # dt = T / nT
    # --- compression ---
    m_dim: int = 8               # number of rows of the projection applied to phi
    # --- training ---
    steps: int = 180
    seed: int = 42
    base_lr: float = 1e-3
    weight_decay: float = 1e-4
    clip: float = 1.0            # global-norm clip passed to optax
    # --- stabilizers / robustifiers ---
    use_sam: bool = True
    sam_rho: float = 0.03
    use_ema: bool = True
    ema_beta: float = 0.995
    use_gc: bool = True
    use_agc: bool = True
    agc_clip: float = 0.01
    agc_eps: float = 1e-3
    # spectral-norm penalty
    use_spec: bool = True
    spec_iters: int = 2          # power-iteration steps per Linear layer
    spec_tau: float = 2.0        # penalise spectral norms above this threshold
    w_spec: float = 2e-4
    # RQMC Halton + Brownian bridge
    use_rqmc: bool = True
    # --- loss weights ---
    w_gen: float = 1.0
    w_towerY: float = 1.0
    w_bel: float = 0.06
    w_kt_fwd: float = 0.06
    w_kt_bwd: float = 0.002
    w_quadZ: float = 0.06
    w_quadA: float = 0.02
    w_quadB: float = 0.02
    w_cross: float = 0.02
    w_price_pde: float = 1.0
    w_price_tower: float = 0.6
    w_smooth: float = 0.01

# Module-level default configuration; the training and plotting calls below pass it explicitly.
cfg=Cfg()
def clamp(x,lo=1e-12,hi=1e12):
    """Limit ``x`` elementwise to the closed interval ``[lo, hi]``."""
    bounded = jnp.clip(x, lo, hi)
    return bounded

# ------------------------ Economics ------------------------
def q_rule(VK,cfg:Cfg):
    """Map the marginal value of capital ``VK`` to an investment rate.

    The gap ``VK - 1`` is scaled by ``1/theta_plus`` where it is positive and by
    ``1/theta_minus`` where it is negative; exactly one of the two pieces is non-zero.
    """
    gap = VK - 1.0
    invest = jnp.maximum(gap / cfg.theta_plus, 0.0)
    disinvest = jnp.minimum(gap / cfg.theta_minus, 0.0)
    return invest + disinvest

def payout(K,z,x,s,cfg:Cfg):
    """Payout: revenue minus the fixed term, investment outlay, and quadratic adjustment cost.

    The adjustment-cost coefficient is ``theta_plus`` where ``s >= 0`` and
    ``theta_minus`` where ``s < 0``.
    """
    revenue = jnp.exp(x+z)*(K**cfg.alpha)
    # Boolean masks select the coefficient that applies to the sign of s.
    theta = cfg.theta_plus*(s>=0) + cfg.theta_minus*(s<0)
    adjustment = 0.5*theta*(s**2)*K
    return revenue - cfg.f - s*K - adjustment

def sdf_bits(C_prev,C_curr,dt,cfg:Cfg):
    """Short rate and price of risk implied by one step of log growth from ``C_prev`` to ``C_curr``.

    Both levels are clamped before the logarithm. Volatility is the absolute log
    change scaled by ``1/sqrt(dt)``, and the drift adds back half the variance.
    Returns ``(r_t, lam)``.
    """
    dt_safe = jnp.maximum(dt, 1e-12)
    log_growth = jnp.log(clamp(C_curr)) - jnp.log(clamp(C_prev))
    sigma_c = jnp.abs(log_growth)/jnp.sqrt(dt_safe)
    mu_c = log_growth/dt_safe + 0.5*sigma_c**2
    r_t = cfg.rho + cfg.gamma*mu_c - 0.5*cfg.gamma*(cfg.gamma+1.0)*sigma_c**2
    lam = cfg.gamma*sigma_c
    return r_t, lam

# ------------------------ FV φ: upwind in K, Chang–Cooper in z ------------------------
class FV(eqx.Module):
    """Finite-volume state: grid axes, cell widths, the gridded distribution, and a projection."""
    K: jnp.ndarray     # K-axis nodes
    Z: jnp.ndarray     # z-axis nodes
    dK: float          # K spacing
    dZ: float          # z spacing
    phi: jnp.ndarray   # distribution on the (K, z) grid
    P: jnp.ndarray     # projection applied to the flattened phi

    def m(self):
        """Return the compressed summary ``P @ phi.ravel()``."""
        flat_phi = self.phi.ravel()
        return self.P @ flat_phi

def make_fv(cfg:Cfg,key):
    """Build the initial finite-volume state.

    The starting distribution is a Gaussian-shaped bump centred at ``K=1, z=0``,
    normalised to sum to one. The projection is a random normal matrix with
    unit-norm rows drawn from ``key``.
    """
    K_axis = jnp.linspace(cfg.K_min, cfg.K_max, cfg.nK)
    Z_axis = jnp.linspace(cfg.Z_min, cfg.Z_max, cfg.nZ)
    dK = float(K_axis[1]-K_axis[0])
    dZ = float(Z_axis[1]-Z_axis[0])
    KK, ZZ = jnp.meshgrid(K_axis, Z_axis, indexing="ij")
    bump = jnp.exp(-((KK-1.0)**2)/0.1)*jnp.exp(-(ZZ**2)/0.5)
    phi0 = bump/jnp.sum(bump)
    proj = jax.random.normal(key, (cfg.m_dim, cfg.nK*cfg.nZ))
    proj = proj/jnp.linalg.norm(proj, axis=1, keepdims=True)
    return FV(K_axis, Z_axis, dK, dZ, phi0, proj)

def cc_flux(phiL,phiR,a,D,dz):
    """Chang–Cooper flux ``a*phi_face - D*(phiR - phiL)/dz`` through one cell face.

    Args:
        phiL, phiR: values in the cells left and right of the face.
        a: advection velocity at the face.
        D: diffusion coefficient (floored at 1e-12 when forming the cell Péclet number).
        dz: cell spacing.

    The face value is ``d*phiR + (1-d)*phiL`` with ``d = 1/xi - 1/(exp(xi)-1)`` and
    ``xi = a*dz/D``. With this weighting the discrete flux vanishes on the
    zero-flux profile ``phi ∝ exp(a*z/D)``, and it tends to the upwind value
    (``phiL`` for ``a >> 0``, ``phiR`` for ``a << 0``). The previous version had
    the two weights swapped, which made the scheme downwind for strong advection.
    """
    xi=a*dz/jnp.maximum(D,1e-12)
    def delta(x):  # Chang–Cooper weight
        small=jnp.abs(x)<1e-4
        # Evaluate the closed form away from 0 only, so the branch that
        # jnp.where discards never produces inf/NaN (values or gradients).
        x_safe=jnp.where(small,1.0,x)
        series=0.5 - x/12.0 + x**3/720.0
        exact=1.0/x_safe - 1.0/jnp.expm1(x_safe)
        return jnp.where(small,series,exact)
    d=delta(xi)
    return a*(d*phiR + (1.0-d)*phiL) - D*(phiR-phiL)/dz

def fv_step(state:FV,s_field,cfg:Cfg,dt):
    """Advance ``state.phi`` by one explicit step of size ``dt`` and return the updated state.

    K direction: first-order upwind advection with velocity ``(s_field - delta)*K``.
    z direction: Chang–Cooper flux with drift ``-kappa_z*z`` and diffusion ``sigma_z**2/2``.

    Fluxes through the outer faces of the grid are set to zero, and the
    divergence is taken over every cell including the boundary rows and columns.
    Previously the boundary cells were given zero divergence while their interior
    neighbours still exchanged flux with them, so mass was lost or created there
    and only hidden by the final renormalisation.

    The result is clipped to stay positive and renormalised to sum to one.
    """
    K,Z,dK,dZ,phi=state.K,state.Z,state.dK,state.dZ,state.phi
    KK,ZZ=jnp.meshgrid(K,Z,indexing="ij")
    aK=(s_field - cfg.delta)*KK; aZ=-cfg.kappa_z*ZZ; Dz=0.5*(cfg.sigma_z**2)
    # upwind in K on interior faces: take the cell the flow is coming from
    aKf=0.5*(aK[1:,:]+aK[:-1,:]); upK=jnp.where(aKf>=0, phi[:-1,:], phi[1:,:])
    FK=jnp.pad(aKf*upK,((1,1),(0,0)))            # zero-flux faces at both ends of the K axis
    divK_full=(FK[1:,:]-FK[:-1,:])/dK
    # Chang–Cooper in z on interior faces
    aZf=0.5*(aZ[:,1:]+aZ[:,:-1]); FZ=cc_flux(phi[:,:-1],phi[:,1:],aZf,Dz,dZ)
    FZ=jnp.pad(FZ,((0,0),(1,1)))                 # zero-flux faces at both ends of the z axis
    divZ_full=(FZ[:,1:]-FZ[:,:-1])/dZ
    phi_new=phi - dt*(divK_full+divZ_full)
    phi_new=jnp.clip(phi_new,1e-18,1e18); phi_new/=jnp.sum(phi_new)
    return eqx.tree_at(lambda s: s.phi, state, phi_new)

# ------------------------ Networks ------------------------
class SplitNet(eqx.Module):
    """Shared tanh trunk feeding four linear heads: scalar ``Y`` and ``dW``-vectors ``Z``, ``A``, ``B``."""
    trunk: eqx.Module
    hY: eqx.nn.Linear
    hZ: eqx.nn.Linear
    hA: eqx.nn.Linear
    hB: eqx.nn.Linear
    in_dim: int

    def __init__(self,key,in_dim,width=64,depth=3,dW=2):
        keys = jax.random.split(key, depth+5)
        fan_in = in_dim
        layers = []
        for i in range(depth):
            layers.append(eqx.nn.Linear(fan_in, width, key=keys[i]))
            layers.append(Tanh())
            fan_in = width
        self.trunk = eqx.nn.Sequential(layers)
        # The last four keys are reserved for the heads.
        self.hY = eqx.nn.Linear(width, 1, key=keys[-4])
        self.hZ = eqx.nn.Linear(width, dW, key=keys[-3])
        self.hA = eqx.nn.Linear(width, dW, key=keys[-2])
        self.hB = eqx.nn.Linear(width, dW, key=keys[-1])
        self.in_dim = in_dim

    def __call__(self,x):
        """Evaluate on one feature vector (1-D ``x``) or on a batch of rows (otherwise)."""
        def single(u):
            h = self.trunk(u)
            return self.hY(h)[...,0], self.hZ(h), self.hA(h), self.hB(h)
        if x.ndim == 1:
            return single(x)
        Y, Z, A, B = jax.vmap(single)(x)
        return Y, Z, A, B

class PriceCritic(eqx.Module):
    """Small tanh MLP with a single scalar output per input row."""
    body: eqx.Module
    head: eqx.nn.Linear

    def __init__(self,key,in_dim,width=48,depth=2):
        keys = jax.random.split(key, depth+2)
        fan_in = in_dim
        layers = []
        for i in range(depth):
            layers.append(eqx.nn.Linear(fan_in, width, key=keys[i]))
            layers.append(Tanh())
            fan_in = width
        self.body = eqx.nn.Sequential(layers)
        self.head = eqx.nn.Linear(width, 1, key=keys[-1])

    def __call__(self,x):
        """Evaluate on one feature vector (1-D ``x``) or on a batch of rows (otherwise)."""
        def single(u):
            return self.head(self.body(u))[...,0]
        return single(x) if x.ndim == 1 else jax.vmap(single)(x)

def pack_features(K,Z,X,m):
    """Build one feature row ``[K_i, Z_i, X_i, m...]`` per point; ``m`` is repeated on every row."""
    n_pts = K.shape[0]
    coord_cols = [K[:,None], Z[:,None], X[:,None]]
    m_rows = jnp.tile(m[None,:], (n_pts, 1))
    return jnp.concatenate(coord_cols + [m_rows], axis=1)

# ------------------------ Grad/Hess helpers ------------------------
def value_grads(model:SplitNet,K,Z,X,m):
    """Value, gradient and Hessian of the model's first output w.r.t. ``(K, Z, X)`` at each point.

    ``m`` is appended to every point as a fixed extra input and is not differentiated.
    Returns ``(Y, G, H)``, each with a leading axis over points; ``G`` and ``H`` are
    indexed in the order ``(K, Z, X)``.
    """
    def scalar_value(u):
        return model(jnp.concatenate([u, m]))[0]
    points = jnp.stack([K, Z, X], -1)
    values = jax.vmap(scalar_value)(points)
    grads = jax.vmap(jax.grad(scalar_value))(points)
    hessians = jax.vmap(jax.hessian(scalar_value))(points)
    return values, grads, hessians

# ------------------------ RQMC Halton + Brownian bridge ------------------------
def halton(n,base):
    """Return the radical-inverse (van der Corput) values in ``base`` for indices ``1..n``."""
    def radical_inverse(index):
        value, scale = 0.0, 1.0
        while index > 0:
            scale /= base
            index, digit = divmod(index, base)
            value += scale*digit
        return value
    return jnp.array([radical_inverse(i) for i in range(1, n+1)])
def rqmc_gaussian(n,dim,key=None):
    """Map ``n`` Halton points in ``dim`` dimensions to standard normals via the inverse CDF.

    Args:
        n: number of points.
        dim: number of dimensions; one prime base per dimension, at most 8.
        key: optional PRNG key. When given, one uniform shift per dimension is
            added modulo 1 before the Gaussian transform, so the point set is
            actually randomised. With ``key=None`` the points are the plain
            deterministic Halton points, as before.

    Returns:
        Array of shape ``[n, dim]``.

    Raises:
        ValueError: if ``dim`` is outside ``1..8``. Previously ``dim > 8`` silently
            returned only 8 columns.
    """
    # Local import: ``import jax`` alone is not guaranteed to expose ``jax.scipy.special``.
    from jax.scipy.special import erfinv
    primes=[2,3,5,7,11,13,17,19]
    if not 1<=dim<=len(primes):
        raise ValueError(f"dim must be between 1 and {len(primes)}, got {dim}")
    U=jnp.stack([halton(n,p) for p in primes[:dim]],-1)  # [n,dim] in (0,1)
    if key is not None:
        shift=jax.random.uniform(key,(dim,))
        U=jnp.mod(U+shift,1.0)
        # Keep points strictly inside (0,1) so erfinv stays finite.
        tiny=jnp.finfo(U.dtype).eps
        U=jnp.clip(U,tiny,1.0-tiny)
    return math.sqrt(2.0)*erfinv(2.0*U-1.0)
def brownian_bridge_gauss(eps):
    """Brownian-bridge placeholder: hand back ``eps`` untouched.

    Per the original inline note, ``eps`` is expected to be ``[n,2]`` and only a
    single time step is used, so no bridge construction is applied.
    """
    return eps

# ------------------------ Core generator + quadruple ------------------------
def gen_core(model,state:FV,x,meta,cfg:Cfg):
    """Evaluate the phi-weighted HJB residual on the FV grid and advance the distribution one step.

    Args:
        model: value network; its first output is differentiated w.r.t. ``(K, z, x)``.
        state: current finite-volume state.
        x: aggregate shock level (scalar), used at every grid point.
        meta: dict with ``"dt"`` and ``"C_prev"``. ``"C_prev"`` is overwritten in
            place with the aggregate payout computed here, so repeated calls with
            the same dict chain the growth calculation from step to step.
        cfg: configuration.

    Returns:
        ``(gen, nxt, metrics, aux, bundle, quads)`` where
        ``gen`` is the phi-weighted residual (scalar),
        ``nxt`` the state after one ``fv_step``,
        ``metrics`` a dict with ``D``, ``r``, ``lam``,
        ``aux = (s, VK)`` on the grid,
        ``bundle = (Kv, Zv, Xv, m, pay, muQ)`` for the price residual, and
        ``quads = (LZ, LA, LB, Lcross)`` the head-matching losses.

    ``metrics`` holds 0-d arrays rather than Python floats: ``float()`` cannot be
    applied to traced values, and this function is called under grad/jit/jacrev.
    Callers that need Python numbers can still wrap the entries in ``float()``
    outside a trace.
    """
    dt=meta["dt"]; m=state.m(); K,Z=state.K,state.Z; KK,ZZ=jnp.meshgrid(K,Z,indexing="ij")
    Kv, Zv = KK.ravel(), ZZ.ravel(); Xv=jnp.full_like(Kv,x)
    Y,G,H = value_grads(model,Kv,Zv,Xv,m); VK,VZ,VX = G[:,0],G[:,1],G[:,2]
    Vxx,Vzz = H[:,2,2],H[:,1,1]
    s = q_rule(VK,cfg).reshape(KK.shape)
    pay = payout(KK,ZZ,x,s,cfg)
    D=jnp.sum(pay*state.phi)
    # Aggregate payout plays the role of consumption in the SDF pieces.
    r_t,lam=sdf_bits(meta["C_prev"],D,dt,cfg); meta["C_prev"]=D
    muQ=cfg.kappa_x*(cfg.mu_x-x) - cfg.sigma_x*lam

    # Weak HJB residual under φ (one‑point quadrature)
    gen_point = (cfg.rho*Y
                 - (pay.ravel()
                    + VK*((s.ravel()-cfg.delta)*Kv)
                    + VX*muQ + VZ*(-cfg.kappa_z*Zv)
                    + 0.5*(cfg.sigma_x**2)*Vxx + 0.5*(cfg.sigma_z**2)*Vzz))
    gen = jnp.sum(gen_point.reshape(KK.shape)*state.phi)

    # Quadruple targets from AD (cheap); columns are ordered (x, z)
    Xin=pack_features(Kv,Zv,Xv,m)
    _, Zhat, Ahat, Bhat = model(Xin)
    Ztar=jnp.stack([VX*cfg.sigma_x, VZ*cfg.sigma_z],1)
    Atar=jnp.stack([Vxx*cfg.sigma_x, Vzz*cfg.sigma_z],1)
    # B: Jacobian of A along Brownian shocks (Malliavin→Brownian map)
    def A_vec(u):
        Ku,Zu,Xu=u
        _,_,H_ = value_grads(model,jnp.array([Ku]),jnp.array([Zu]),jnp.array([Xu]),m)
        Vxx_,Vzz_= H_[0,2,2],H_[0,1,1]
        return jnp.array([Vxx_*cfg.sigma_x, Vzz_*cfg.sigma_z])
    # Jmat is [n, 2, 3]: d(A)/d(K, z, x) at every grid point.
    Jmat = jax.vmap(lambda k,z: jax.jacrev(A_vec)(jnp.array([k,z,x])))(Kv,Zv)
    # Contract with the shock loadings on (K, z, x); K carries no direct shock.
    Btar = jnp.einsum('nij,j->ni', Jmat, jnp.array([0.0, cfg.sigma_z, cfg.sigma_x]))

    LZ=jnp.mean((Zhat-Ztar)**2); LA=jnp.mean((Ahat-Atar)**2); LB=jnp.mean((Bhat-Btar)**2)
    Lcross=jnp.mean((H[:,2,1]-H[:,1,2])**2)  # symmetry

    nxt=fv_step(state,s,cfg,dt)
    metrics=dict(D=D, r=r_t, lam=lam)
    aux=(s, VK.reshape(KK.shape))
    bundle=(Kv,Zv,Xv,m, pay, muQ)
    return gen,nxt,metrics,aux,bundle,(LZ,LA,LB,Lcross)

def price_pde_res(price_net:PriceCritic,bundle,state:FV,cfg:Cfg):
    """Mean-squared residual of ``rho*P - payout - L[P]`` over the bundled points.

    ``L[P]`` contains the x drift ``muQ``, the z drift ``-kappa_z*z`` and both
    diffusion terms. ``state`` is accepted but not used.
    Returns ``(mean squared residual, P at every point)``.

    NOTE(review): the residual has no K-advection term, whereas the value residual
    in gen_core does — confirm whether that omission is intended.
    """
    Kv, Zv, Xv, m, pay, muQ = bundle
    def P(u):
        return price_net(jnp.concatenate([u, m]))
    points = jnp.stack([Kv, Zv, Xv], -1)
    Pvec = jax.vmap(P)(points)
    first = jax.vmap(jax.grad(P))(points)
    second = jax.vmap(jax.hessian(P))(points)
    PZ, PX = first[:,1], first[:,2]
    Pxx, Pzz = second[:,2,2], second[:,1,1]
    LQ = PX*muQ + PZ*(-cfg.kappa_z*Zv) + 0.5*(cfg.sigma_x**2)*Pxx + 0.5*(cfg.sigma_z**2)*Pzz
    residual = cfg.rho*Pvec - pay.ravel() - LQ
    return jnp.mean(residual**2), Pvec

def tower_step(func,K0,Z0,X0,m,dt,epsx,epsz,disc,flow,cfg:Cfg):
    """One-step tower (dynamic-programming) penalty for ``func``.

    Compares ``func`` at the start points with its discounted value one Euler step
    later plus ``dt*flow``, and returns the mean squared difference.

    Args:
        func: callable ``func(K, Z, X, m)`` returning one scalar per point. It is
            evaluated point by point, so it only has to accept scalar ``K, Z, X``.
            Previously it was called with whole arrays, which fails for pointwise
            callables that build ``jnp.array([K, Z, X])`` and concatenate ``m``.
        K0, Z0, X0: start points (scalars or equal-shaped arrays).
        m: extra inputs passed unchanged to ``func``.
        dt: step size.
        epsx, epsz: standard-normal draws for the x and z shocks.
        disc: discount rate applied over ``dt``.
        flow: running payoff per unit time.
        cfg: configuration (uses ``sigma_x``, ``kappa_z``, ``sigma_z``).

    NOTE(review): the x step has no drift term while the z step has mean
    reversion; the visible callers start at x = 0 — confirm this is intended for
    other start points.
    """
    X1=X0 + cfg.sigma_x*jnp.sqrt(dt)*epsx
    Z1=Z0 - cfg.kappa_z*Z0*dt + cfg.sigma_z*jnp.sqrt(dt)*epsz

    def evaluate(k,z,x):
        k,z,x=jnp.broadcast_arrays(jnp.asarray(k),jnp.asarray(z),jnp.asarray(x))
        if k.ndim==0:
            return func(k,z,x,m)
        flat=jax.vmap(lambda ki,zi,xi: func(ki,zi,xi,m))(k.ravel(),z.ravel(),x.ravel())
        return flat.reshape(k.shape+flat.shape[1:])

    lhs=evaluate(K0,Z0,X0); rhs=jnp.exp(-disc*dt)*evaluate(K0,Z1,X1) + dt*flow
    return jnp.mean((lhs-rhs)**2)

def bel_reg(model,state:FV,cfg:Cfg,epsx,epsz,h=1e-2,n=64):
    """Penalise the gap between AD gradients and one-sample shock-weighted differences.

    On ``n`` points along the K axis (with ``z = x = 0``) the x- and z-gradients from
    autodiff are compared with ``(V(shocked) - V0) * eps / (sqrt(h) * sigma)``.
    Returns ``(loss, (Vx_ad, belx))``.
    """
    ex, ez = epsx[:n], epsz[:n]
    root_h = jnp.sqrt(h)
    Ki = jnp.linspace(cfg.K_min+0.1, cfg.K_max-0.1, n)
    Zi = jnp.zeros(n)
    Xi = jnp.zeros(n)
    Xi1 = Xi + cfg.sigma_x*root_h*ex
    Zi1 = Zi - cfg.kappa_z*Zi*h + cfg.sigma_z*root_h*ez
    m = state.m()

    _, G, _ = value_grads(model, Ki, Zi, Xi, m)
    Vx_ad, Vz_ad = G[:,2], G[:,1]

    def V(Ku,Zu,Xu):
        return model(jnp.concatenate([jnp.array([Ku,Zu,Xu]), m]))[0]
    V_batch = jax.vmap(V)
    V0 = V_batch(Ki, Zi, Xi)
    dV_x = V_batch(Ki, Zi, Xi1) - V0
    dV_z = V_batch(Ki, Zi1, Xi) - V0
    belx = (dV_x*ex/root_h)/cfg.sigma_x
    belz = (dV_z*ez/root_h)/cfg.sigma_z
    loss = jnp.mean((Vx_ad-belx)**2) + jnp.mean((Vz_ad-belz)**2)
    return loss, (Vx_ad, belx)

def kt_backward(model,state: FV,cfg:Cfg):
    """Curvature penalty on a 4x4 interior (K, z) grid at ``x = 0``.

    Sums the mean of ``Vxx**2 + Vzz**2`` and the mean squared difference between
    the two mixed (x, z) Hessian entries.
    """
    K_pts = jnp.linspace(cfg.K_min+0.1, cfg.K_max-0.1, 4)
    Z_pts = jnp.linspace(cfg.Z_min+0.1, cfg.Z_max-0.1, 4)
    KK, ZZ = jnp.meshgrid(K_pts, Z_pts, indexing="ij")
    x_zero = jnp.zeros_like(KK).ravel()
    _, _, H = value_grads(model, KK.ravel(), ZZ.ravel(), x_zero, state.m())
    Vxx, Vzz = H[:,2,2], H[:,1,1]
    curvature = jnp.mean(Vxx**2 + Vzz**2)
    asymmetry = jnp.mean((H[:,2,1]-H[:,1,2])**2)
    return curvature + asymmetry

def kt_forward(model,state:FV,cfg:Cfg):
    """Mean squared sensitivity of the generator residual to ``(K, z, x)`` at ``(1, 0, 0)``.

    The residual is evaluated with all of ``phi`` concentrated on the grid cell
    nearest to ``(K, z)``.
    """
    def residual(u):
        Ku, Zu, Xu = u
        i_K = jnp.argmin((state.K-Ku)**2)
        i_Z = jnp.argmin((state.Z-Zu)**2)
        point_mass = jnp.zeros_like(state.phi).at[i_K, i_Z].set(1.0)
        st = eqx.tree_at(lambda s:s.phi, state, point_mass)
        return gen_core(model, st, Xu, {"dt":1e-2,"C_prev":0.1}, cfg)[0]
    base_point = jnp.array([1.0,0.0,0.0])
    J = jax.jacrev(residual)(base_point)
    return jnp.mean(J**2)

# ------------------------ Stabilizers ------------------------
def spectral_pen(module: eqx.Module, iters=1, tau=1.5):
    """Sum over every ``eqx.nn.Linear`` in ``module`` of ``max(0, sigma_max - tau)**2``.

    ``sigma_max`` is estimated with ``iters`` steps of power iteration on ``W^T W``
    from a fixed random start vector (PRNG seed 0).

    Changes from the earlier version:
      * Linear layers are collected with a pytree traversal, so layers held in
        dicts or nested containers are found as well (the attribute walk only
        followed Module fields and lists/tuples directly on a Module).
      * The start vector is normalised, so ``iters=0`` still yields ``||W v||`` for a
        unit vector instead of an arbitrarily scaled number.

    Returns 0.0 if the module contains no Linear layer.
    """
    def is_linear(node):
        return isinstance(node, eqx.nn.Linear)
    linears=[node for node in jax.tree_util.tree_leaves(module,is_leaf=is_linear) if is_linear(node)]
    key=jax.random.PRNGKey(0)
    pen=0.0
    for lin in linears:
        W=lin.weight
        v=jax.random.normal(key,(W.shape[1],)); v=v/(jnp.linalg.norm(v)+1e-12)
        # power iteration (few steps)
        for _ in range(iters):
            v=jnp.matmul(W.T,jnp.matmul(W,v)); v=v/(jnp.linalg.norm(v)+1e-12)
        s=jnp.sqrt(jnp.dot(v,jnp.matmul(W.T,jnp.matmul(W,v))))
        pen=pen+jnp.maximum(0.0,s-tau)**2
    return pen

def gradient_centralize(grads):
    """Subtract from every gradient leaf of rank >= 2 its mean over all axes except the last.

    Leaves of rank 0 or 1 (e.g. biases) and ``None`` entries are returned unchanged.
    Uses ``jax.tree_util.tree_map``; the top-level ``jax.tree_map`` alias is
    deprecated and has been removed from recent JAX releases.

    NOTE(review): the mean runs over the leading axes. Equinox Linear weights are
    stored as (out, in), so this centres across output units — confirm that is the
    intended convention.
    """
    def gc(g):
        if g is None or g.ndim<2:
            return g
        return g - g.mean(axis=tuple(range(g.ndim-1)), keepdims=True)
    return jax.tree_util.tree_map(gc, grads)

def adaptive_grad_clip(params,grads,clip=0.01,eps=1e-3):
    """Rescale each gradient leaf whose norm exceeds ``clip * (||param|| + eps)``.

    Norms are taken over the whole leaf (one norm per tensor). Leaves below the
    threshold are returned unchanged. ``params`` and ``grads`` must share the same
    tree structure.

    Uses ``jax.tree_util.tree_map``; the top-level ``jax.tree_map`` alias is
    deprecated and has been removed from recent JAX releases.
    """
    def agc(p,g):
        if g is None or p is None: return g
        pn=jnp.linalg.norm(p); gn=jnp.linalg.norm(g)
        limit=clip*(pn+eps)
        return jnp.where(gn>limit, g*(limit/(gn+1e-12)), g)
    return jax.tree_util.tree_map(agc, params, grads)

def replace_params(model,p):
    """Return ``model`` with its array leaves replaced by the corresponding leaves of ``p``.

    ``p`` is expected to have the structure of ``eqx.filter(model, eqx.is_array)``:
    arrays where the model has arrays and ``None`` elsewhere. ``eqx.combine`` takes
    the first non-``None`` leaf at every position, so arrays come from ``p`` and
    everything else from ``model``. The earlier implementation called
    ``eqx.tree_map``, which is not part of Equinox's public API.
    """
    return eqx.combine(p, model)

# ------------------------ Loss ------------------------
def _loss_terms(nets, state0:FV, key, cfg:Cfg):
    """Scalar training loss plus auxiliary outputs for the ``(model, price_net)`` pair.

    Returns ``(total, (parts, dp_heat))``. ``parts`` maps loss-component names to
    0-d arrays (no Python floats, because this runs under tracing) and
    ``dp_heat`` is payout divided by clamped price on the FV grid.
    """
    model, price_net = nets
    dt=cfg.T/cfg.nT; st=state0; x=0.0; meta={"dt":dt,"C_prev":0.1}
    # light unroll of φ to stabilize generator statistics
    # NOTE(review): the evolved state and shock from this loop are not used by
    # anything below (the residual is evaluated at state0) — confirm whether the
    # evolved state was meant to feed the loss.
    for _ in range(cfg.nT//3):
        g,st,_,_,_,_ = gen_core(model,st,x,meta,cfg)
        x = x + cfg.kappa_x*(cfg.mu_x-x)*dt

    g0,_,mt,aux,bundle,quads = gen_core(model,state0,0.0,{"dt":dt,"C_prev":0.1},cfg)
    sfield,_=aux; LZ,LA,LB,Lcross=quads; Lg=g0**2

    # price PDE + tower
    Lpde,Pvec = price_pde_res(price_net,bundle,state0,cfg)
    # RQMC Gaussians (new robustifier)
    n=64
    if cfg.use_rqmc:
        G=brownian_bridge_gauss(rqmc_gaussian(n,2)); epsx,epsz = G[:,0], G[:,1]
    else:
        k1,k2=jax.random.split(key); epsx=jax.random.normal(k1,(n,)); epsz=jax.random.normal(k2,(n,))

    # tower penalties; the evaluators accept either scalar or equal-length 1-D K/Z/X
    def _features(K0,Z0,X0,m):
        pts=jnp.stack(jnp.broadcast_arrays(jnp.asarray(K0),jnp.asarray(Z0),jnp.asarray(X0)),-1)
        return jnp.concatenate([pts, jnp.broadcast_to(m, pts.shape[:-1]+m.shape)],-1)
    def Yfun(K0,Z0,X0,m): return model(_features(K0,Z0,X0,m))[0]
    def Pfun(K0,Z0,X0,m): return price_net(_features(K0,Z0,X0,m))
    Zline=jnp.linspace(state0.Z.min(),state0.Z.max(),n); oneK=jnp.ones(n); zeros=jnp.zeros(n)
    LtY = tower_step(Yfun,oneK,Zline,zeros,state0.m(),1e-2,epsx,epsz,disc=cfg.rho,flow=jnp.zeros(n),cfg=cfg)
    LtP = tower_step(Pfun,oneK,Zline,zeros,state0.m(),1e-2,epsx,epsz,disc=cfg.rho,
                     flow=payout(oneK,Zline,0.0,jnp.zeros(n),cfg),cfg=cfg)
    # BEL/Malliavin
    Lbel, _ = bel_reg(model,state0,cfg,epsx,epsz,h=1e-2,n=64)
    # Kapllani–Teng fwd/bwd
    Lf = kt_forward(model,state0,cfg); Lb = kt_backward(model,state0,cfg)

    # small pathwise smoothness (first‑order AD vs small kicks)
    def smooth_pen():
        eps=1e-2; m=state0.m(); Kv=jnp.array([1.0]); Zv=jnp.array([0.0]); X0=jnp.array([0.0])
        Y0,_,_,_=model(pack_features(Kv,Zv,X0,m)); Yx,_,_,_=model(pack_features(Kv,Zv,X0+eps,m))
        Yz,_,_,_=model(pack_features(Kv,Zv+eps,X0,m)); _,G,_=value_grads(model,Kv,Zv,X0,m)
        return ((Yx-Y0-G[:,2][0]*eps)**2 + (Yz-Y0-G[:,1][0]*eps)**2).mean()
    Lsmooth=smooth_pen()

    # spectral‑norm penalty (new stabilizer)
    Lspec = spectral_pen(model,cfg.spec_iters,cfg.spec_tau) + spectral_pen(price_net,cfg.spec_iters,cfg.spec_tau)
    total=( cfg.w_gen*Lg + cfg.w_price_pde*Lpde + cfg.w_price_tower*LtP + cfg.w_towerY*LtY
            + cfg.w_bel*Lbel + cfg.w_kt_fwd*Lf + cfg.w_kt_bwd*Lb
            + cfg.w_quadZ*LZ + cfg.w_quadA*LA + cfg.w_quadB*LB + cfg.w_cross*Lcross
            + cfg.w_smooth*Lsmooth + (cfg.w_spec*Lspec if cfg.use_spec else 0.0) )
    # D/P field for plots
    K,Z=state0.K,state0.Z; KK,ZZ=jnp.meshgrid(K,Z,indexing="ij")
    D_field=payout(KK,ZZ,0.0,sfield,cfg); P_field=Pvec.reshape(KK.shape)
    dp_heat=(D_field/clamp(P_field))
    parts=dict(HJB=Lg,price=Lpde,towerY=LtY,towerP=LtP,
               BEL=Lbel,KTf=Lf,KTb=Lb,spec=jnp.asarray(Lspec),smooth=Lsmooth)
    return total,(parts,dp_heat)

def loss_fn(model:SplitNet,price_net:PriceCritic,state0:FV, key, cfg:Cfg):
    """Value, auxiliary outputs and gradients of the training loss.

    Returns ``((total, (parts, dp_heat)), (grad_model, grad_price))``.

    The two networks are differentiated jointly by passing them as one tuple in
    the first argument, with ``has_aux=True`` for the auxiliary outputs. The
    earlier decorator form differentiated only ``model`` and lacked ``has_aux``
    even though a tuple was returned, while the caller indexes the gradients as
    ``g[0]`` / ``g[1]``.
    """
    value_and_grad=eqx.filter_value_and_grad(_loss_terms,has_aux=True)
    return value_and_grad((model,price_net),state0,key,cfg)

# ------------------------ Train ------------------------
def train(cfg:Cfg,width=64):
    """Train the value network and the price critic for ``cfg.steps`` steps.

    Each step evaluates ``loss_fn``, optionally repeats it at a SAM-perturbed
    point, applies gradient centralisation and adaptive clipping when enabled, and
    then takes a clipped AdamW update. An exponential moving average of the
    parameters is tracked and swapped in at the end when ``cfg.use_ema`` is set.

    Returns:
        ``(model, price, state, logs, heat)``: the two networks, the initial FV
        state, a list of log dicts (one every 40 steps, Python floats), and the
        D/P field from the last step.

    Implementation notes:
      * Tree maps use ``jax.tree_util.tree_map``; ``eqx.tree_map`` is not part of
        Equinox's public API.
      * Parameter updates use ``eqx.apply_updates`` / ``eqx.combine`` directly on the
        modules.
      * Conversion to Python floats happens outside the jitted step, because
        ``float()`` cannot be applied to traced values.
    """
    key=jax.random.PRNGKey(cfg.seed); k1,k2,k3=jax.random.split(key,3)
    model=SplitNet(k1,3+cfg.m_dim,width=width,depth=3)
    price=PriceCritic(k2,3+cfg.m_dim,width=48,depth=2)
    state=make_fv(cfg,k3)

    def params_of(net):
        return eqx.filter(net,eqx.is_array)

    base_opt=optax.chain(optax.clip_by_global_norm(cfg.clip),
                         optax.adamw(learning_rate=cfg.base_lr,weight_decay=cfg.weight_decay))
    opt_state=base_opt.init((params_of(model),params_of(price)))
    ema_m=params_of(model); ema_p=params_of(price)

    @eqx.filter_jit
    def step(model,price,opt_state,ema_m,ema_p,key,cfg:Cfg):
        (L,(parts,heat)), g = loss_fn(model,price,state,key,cfg)
        # SAM: re-evaluate the gradient at a point shifted along the normalised gradient
        if cfg.use_sam:
            gnorm=jnp.sqrt(sum(jnp.sum(x*x) for x in jax.tree_util.tree_leaves(g)))
            eps=cfg.sam_rho/(gnorm+1e-12)
            adv=jax.tree_util.tree_map(lambda x:x*eps,g)
            (L,(parts,heat)), g = loss_fn(eqx.apply_updates(model,adv[0]),
                                          eqx.apply_updates(price,adv[1]), state, key, cfg)
        # GC/AGC
        if cfg.use_gc:  g=(gradient_centralize(g[0]),gradient_centralize(g[1]))
        if cfg.use_agc:
            g=(adaptive_grad_clip(params_of(model),g[0],cfg.agc_clip,cfg.agc_eps),
               adaptive_grad_clip(params_of(price),g[1],cfg.agc_clip,cfg.agc_eps))
        updates,opt_state=base_opt.update(g,opt_state,params=(params_of(model),params_of(price)))
        model=eqx.apply_updates(model,updates[0])
        price=eqx.apply_updates(price,updates[1])
        # EMA
        if cfg.use_ema:
            def blend(e,f): return cfg.ema_beta*e + (1-cfg.ema_beta)*f
            ema_m=jax.tree_util.tree_map(blend, ema_m, params_of(model))
            ema_p=jax.tree_util.tree_map(blend, ema_p, params_of(price))
        return model,price,opt_state,ema_m,ema_p,L,parts,heat

    logs=[]; heat=None
    for t in range(cfg.steps):
        key,ku=jax.random.split(key)
        model,price,opt_state,ema_m,ema_p,L,parts,heat=step(model,price,opt_state,ema_m,ema_p,ku,cfg)
        if (t+1)%40==0:
            logs.append({"t":t+1,"loss":float(L), **{name: float(val) for name,val in parts.items()}})
    if cfg.use_ema:
        # EMA trees hold arrays only; non-array leaves come from the live modules.
        model=eqx.combine(ema_m,model)
        price=eqx.combine(ema_p,price)
    return model,price,state,logs,heat

# ------------------------ SymPy quick checks ------------------------
def sympy_check():
    """Symbolically solve the investment first-order condition and print it.

    Does nothing beyond printing a notice when SymPy could not be imported.
    """
    if not HAVE_SYMPY:
        print("[SymPy] not available"); return
    s, K, theta, VK = sp.symbols('s K theta VK', real=True)
    objective = VK*s*K - s*K - sp.Rational(1,2)*theta*s**2*K
    stationarity = sp.diff(objective, s)
    roots = sp.solve(stationarity, s)
    foc = sp.simplify(roots[0])
    print("[SymPy] Abel–Eberly FOC s* =", foc, " (expected (VK-1)/theta)")

# ------------------------ Main ------------------------

sympy_check()
M, PNET, ST, logs, heat = train(cfg, width=64)
print("Diagnostics (~every 40 steps):")
# Plain loop rather than a list comprehension used for its side effects, which
# also left a list of None values as the cell's last expression.
for _log_row in logs[:4]:
    print(_log_row)


[SymPy] Abel–Eberly FOC s* = (VK - 1)/theta  (expected (VK-1)/theta)


TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (2,).

In [None]:
# ------------------------ Diagnostics & plots ------------------------
def plot_policy(model,state,cfg):
    """Heat map of the investment rate over a 120x120 (K, z) grid at ``x = 0``."""
    K_axis = jnp.linspace(cfg.K_min, cfg.K_max, 120)
    Z_axis = jnp.linspace(cfg.Z_min, cfg.Z_max, 120)
    KK, ZZ = jnp.meshgrid(K_axis, Z_axis, indexing="ij")
    m = state.m()
    x_zero = jnp.zeros_like(KK).ravel()
    _, G, _ = value_grads(model, KK.ravel(), ZZ.ravel(), x_zero, m)
    s = q_rule(G[:,0], cfg).reshape(KK.shape)
    bounds = [float(K_axis.min()), float(K_axis.max()), float(Z_axis.min()), float(Z_axis.max())]
    plt.figure()
    # Transpose so that K runs along the horizontal axis.
    plt.imshow(np.array(s.T), origin='lower', aspect='auto', extent=bounds)
    plt.colorbar(label="i/K")
    plt.xlabel("K")
    plt.ylabel("z")
    plt.title("Policy (x=0)")
    plt.tight_layout()

def plot_policy_sections(model,state,cfg):
    """Line plot of the investment rate against K at ``z = -sigma_z, 0, +sigma_z`` (``x = 0``)."""
    K_axis = jnp.linspace(cfg.K_min, cfg.K_max, 120)
    m = state.m()
    x_zero = jnp.zeros_like(K_axis)
    plt.figure()
    for z in jnp.array([-cfg.sigma_z, 0.0, cfg.sigma_z]):
        _, G, _ = value_grads(model, K_axis, jnp.full_like(K_axis, z), x_zero, m)
        s = q_rule(G[:,0], cfg)
        plt.plot(np.array(K_axis), np.array(s), label=f"z={float(z):+.2f}")
    plt.xlabel("K")
    plt.ylabel("i/K")
    plt.title("Economic: policy by z")
    plt.legend()
    plt.tight_layout()

def dp_timeseries_and_heat(model,price,state,heat,cfg):
    """Simulate ``cfg.nT`` steps along the deterministic x path and draw the diagnostic figures.

    Figures: the aggregate D/P path, the D/P cross-section ``heat`` (if given), the
    short-rate path, and the discounted aggregate price.

    Aggregates at each step are weighted with the distribution the fields were
    computed on (the state before ``gen_core`` advances it). Previously the
    already-advanced distribution was used, which pairs step-t fields with the
    step-(t+1) distribution.

    NOTE(review): the discount factor accumulates ``cfg.rho`` while the axis label
    refers to ``r`` — confirm which rate is intended.
    """
    dt=cfg.T/cfg.nT; x=0.0; st=state; meta={"dt":dt,"C_prev":0.1}; DP=[]; RF=[]; MTE=[]
    disc=1.0
    for _ in range(cfg.nT):
        phi_t=st.phi  # distribution matching the fields computed in this step
        K,Z=st.K,st.Z; KK,ZZ=jnp.meshgrid(K,Z,indexing="ij")
        g,st_next,mt,aux,bundle,_=gen_core(model,st,x,meta,cfg)
        Lpde,Pvec=price_pde_res(price,bundle,st,cfg)
        D_field=payout(KK,ZZ,x,aux[0],cfg); P_field=Pvec.reshape(KK.shape)
        P_agg=jnp.sum(P_field*phi_t)
        # D/P time series
        DP.append(float(jnp.sum(D_field*phi_t)/clamp(P_agg)))
        RF.append(float(mt["r"]))
        # NEW diagnostic: martingale test under Q (discounted price increment mean ~ 0)
        disc*=math.exp(-cfg.rho*dt)
        MTE.append(float(disc*P_agg))
        x = x + cfg.kappa_x*(cfg.mu_x-x)*dt
        st=st_next
    t=np.linspace(0,float(cfg.T),len(DP))
    plt.figure(); plt.plot(t,DP); plt.xlabel("t"); plt.ylabel("D/P"); plt.title("D/P path"); plt.tight_layout()
    if heat is not None:
        K,Z=state.K,state.Z; plt.figure()
        plt.imshow(np.array(heat.T),origin='lower',aspect='auto',
                   extent=[float(K.min()),float(K.max()),float(Z.min()),float(Z.max())])
        plt.colorbar(label="D/P(K,z)"); plt.xlabel("K"); plt.ylabel("z"); plt.title("D/P cross‑section"); plt.tight_layout()
    # NEW economic plot: short rate r_t along the simulated path
    plt.figure(); plt.plot(t,RF); plt.xlabel("t"); plt.ylabel("short rate r_t"); plt.title("Economic: short rate path"); plt.tight_layout()
    # NEW diagnostic: martingale error (should be flat)
    plt.figure(); plt.plot(t, MTE); plt.xlabel("t"); plt.ylabel("E_Q[e^{-∫r}P_t]"); plt.title("Diagnostic: discounted price (martingale)"); plt.tight_layout()

def mini_table_II(model,price,state,cfg):
    """Print and return mean and standard deviation of the short rate and aggregate D/P over ``cfg.nT`` steps.

    Aggregates at each step are weighted with the distribution the fields were
    computed on (the state before ``gen_core`` advances it). Previously the
    already-advanced distribution was used.

    Returns:
        dict with keys ``rf_mean``, ``rf_vol``, ``DP_mean``, ``DP_vol``.
    """
    dt=cfg.T/cfg.nT; x=0.0; st=state; meta={"dt":dt,"C_prev":0.1}; rf=[]; dp=[]
    for _ in range(cfg.nT):
        phi_t=st.phi  # distribution matching the fields computed in this step
        KK,ZZ=jnp.meshgrid(st.K,st.Z,indexing="ij")
        g,st_next,mt,aux,bundle,_=gen_core(model,st,x,meta,cfg)
        rf.append(float(mt["r"]))
        Lpde,Pvec=price_pde_res(price,bundle,st,cfg)
        D_field=payout(KK,ZZ,x,aux[0],cfg); P_field=Pvec.reshape(KK.shape)
        dp.append(float(jnp.sum(D_field*phi_t)/clamp(jnp.sum(P_field*phi_t))))
        x = x + cfg.kappa_x*(cfg.mu_x-x)*dt
        st=st_next
    out=dict(rf_mean=float(np.mean(rf)),rf_vol=float(np.std(rf)),DP_mean=float(np.mean(dp)),DP_vol=float(np.std(dp)))
    print("[Mini Table‑II]", out); return out


In [None]:
# Summary statistics first, then the figures; all use the trained objects from the training cell.
mini_table_II(M, PNET, ST, cfg)
plot_policy(M, ST, cfg)
plot_policy_sections(M, ST, cfg)
dp_timeseries_and_heat(M, PNET, ST, heat, cfg)
plt.tight_layout()
plt.show()

In [1]:
# Quick sanity check: forward pass shapes without training
# NOTE: this cell relies on the imports, cfg and the network/FV definitions from the
# first cell; running it on a fresh kernel before that cell fails with a NameError.
key=jax.random.PRNGKey(0); mkey,k1,k2=jax.random.split(key,3)
_test_model=SplitNet(k1,3+cfg.m_dim,width=8,depth=2)
_test_price=PriceCritic(k2,3+cfg.m_dim,width=8,depth=1)
_test_state=make_fv(cfg,mkey)
K=jnp.linspace(cfg.K_min,cfg.K_min+0.1,4); Z=jnp.linspace(cfg.Z_min,cfg.Z_min+0.1,4); X=jnp.zeros(4)
Xin=pack_features(K,Z,X,_test_state.m())
Y,Zh,Ah,Bh=_test_model(Xin); Pv=_test_price(Xin)
print("Y",Y.shape,"Z",Zh.shape,"A",Ah.shape,"B",Bh.shape,"P",Pv.shape)
# Fail loudly if the shapes drift: one scalar per row for Y and P, two columns for each head.
assert Xin.shape==(4,3+cfg.m_dim), Xin.shape
assert Y.shape==(4,) and Pv.shape==(4,), (Y.shape,Pv.shape)
assert Zh.shape==Ah.shape==Bh.shape==(4,2), (Zh.shape,Ah.shape,Bh.shape)

NameError: name 'jax' is not defined