
# Minimal Stationary **HJB–FP × PINN** (JAX + Equinox)

Compact, single-notebook implementation of a stationary coupled **Hamilton–Jacobi–Bellman (HJB)** and **Fokker–Planck (FP)**
fixed point on a rectangle. Two tiny MLPs (value `V(k,z)` and density `μ(k,z)`) trained by residual losses.
**Reflecting (no-flux) boundaries** are penalized via **probability currents**. **Price clears** via a damped, log-space update.

*Backbone:* Achdou–Han–Lasry–Lions–Moll (REStud, 2022).  
*Irreversibility wedge:* Abel & Eberly (REStud, 1996).  
*PINNs & DGM:* Raissi et al. (JCP, 2019); Sirignano & Spiliopoulos (JCP, 2018).  
*Neural FP positivity:* Zhai–Dobson–Li (PMLR, 2022).  
*No‑flux boundaries:* standard FP current conditions.

**Non-goals:** No time dependence, no grids/FEM, no fancy samplers. Keep it tiny.


In [1]:
# Ensure deps (CPU) are available; install if missing
try:
    import jax, equinox, optax, sympy, matplotlib  # noqa: F401
except Exception:  # pragma: no cover
    import sys, subprocess
    pkgs = ["jax", "equinox>=0.11", "optax>=0.2", "sympy", "matplotlib"]
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", *pkgs])


In [2]:
import math, functools
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import jit, vmap, jacrev, hessian, random

import equinox as eqx
import optax

import matplotlib.pyplot as plt
from functools import partial

import sympy as sp

# Use float64 for stability (must be set at startup)
jax.config.update("jax_enable_x64", True)

key = random.PRNGKey(0)


In [3]:

@dataclass
class Params:
    # Domain
    k_min: float = 0.2
    k_max: float = 5.0
    z_max: float = 2.5

    # Technology / preferences (stylized)
    alpha: float = 0.35
    delta: float = 0.08
    r: float = 0.05

    # OU shock in z
    kappa_z: float = 0.5
    sigma_z: float = 0.2

    # Smooth costly (ir)reversibility
    theta_plus: float = 2.0
    theta_minus: float = 6.0
    beta_theta: float = 5.0

    # Price aggregator
    eta: float = 1.0         # P = Y^{-eta}
    x_log: float = 0.0       # y = exp(x + z) k^alpha
    lam_price: float = 0.7   # damping in log space

    # Training sizes
    n_colloc: int = 2000
    n_bndry: int = 1000     # total; split across edges
    n_mass: int  = 4000
    n_agg: int   = 4000

    # Optimization
    steps_V: int = 500
    steps_M: int = 500
    outer_steps: int = 8
    lr_V: float = 3e-3
    lr_M: float = 3e-3
    grad_clip: float = 1.0

    # Loss weights
    w_bndry: float = 1.0
    w_mass: float = 1.0

    # Early stop on price
    tol_logP: float = 5e-4

par = Params()
area = (par.k_max - par.k_min) * (2.0 * par.z_max)


In [4]:

class AffineBox(eqx.Module):
    k_min: float
    k_max: float
    z_max: float

    def __call__(self, k: jnp.ndarray, z: jnp.ndarray):
        # map box -> [-1,1]^2
        khat = 2.0 * (k - self.k_min) / (self.k_max - self.k_min) - 1.0
        zhat = z / self.z_max  # since z \in [-zmax,zmax]
        return jnp.stack([khat, zhat], axis=-1)

class ValueNet(eqx.Module):
    box: AffineBox
    mlp: eqx.nn.MLP

    def __init__(self, key, width=64, depth=3):
        box = AffineBox(par.k_min, par.k_max, par.z_max)
        mlp = eqx.nn.MLP(
            in_size=2, out_size=1, width_size=width, depth=depth,
            key=key, activation=jnp.tanh, final_activation=lambda y: y,
        )
        self.box, self.mlp = box, mlp

    def __call__(self, k: jnp.ndarray, z: jnp.ndarray):
        x = self.box(k, z)                 # inside map; autodiff sees physical inputs
        return self.mlp(x).squeeze(-1)     # scalar V

class DensityNet(eqx.Module):
    box: AffineBox
    mlp: eqx.nn.MLP

    def __init__(self, key, width=64, depth=3):
        box = AffineBox(par.k_min, par.k_max, par.z_max)
        mlp = eqx.nn.MLP(
            in_size=2, out_size=1, width_size=width, depth=depth,
            key=key, activation=jnp.tanh, final_activation=lambda y: y,
        )
        self.box, self.mlp = box, mlp

    def __call__(self, k: jnp.ndarray, z: jnp.ndarray):
        x = self.box(k, z)
        s = self.mlp(x).squeeze(-1)
        # positivity for μ
        return jax.nn.softplus(s) ** 2

In [5]:

def sample_interior(key, n):
    k = random.uniform(key, (n,), minval=par.k_min, maxval=par.k_max)
    z = random.uniform(key, (n,), minval=-par.z_max, maxval=par.z_max)
    return k, z

def sample_k_edges(key, n_each):
    # k = k_min or k_max; z uniform
    z = random.uniform(key, (2*n_each,), minval=-par.z_max, maxval=par.z_max)
    k_edge = jnp.concatenate([jnp.full((n_each,), par.k_min),
                              jnp.full((n_each,), par.k_max)])
    return k_edge, z

def sample_z_edges(key, n_each):
    # z = ± z_max; k uniform
    k = random.uniform(key, (2*n_each,), minval=par.k_min, maxval=par.k_max)
    z_edge = jnp.concatenate([jnp.full((n_each,), -par.z_max),
                              jnp.full((n_each,),  par.z_max)])
    return k, z_edge

# Vectorized nets
def V_eval(Vnet, k, z):
    return vmap(Vnet)(k, z)

def mu_eval(Mnet, k, z):
    return vmap(Mnet)(k, z)

# Gradients/Hessians in physical coordinates (chain handled by inside map)
def V_grads(Vnet, k, z):
    def g(kz):
        return Vnet(kz[0], kz[1])
    X = jnp.stack([k, z], axis=1)
    dV = vmap(jacrev(g))(X)            # (n,2)
    H  = vmap(hessian(g))(X)           # (n,2,2)
    Vv = vmap(g)(X)                    # (n,)
    return Vv, dV[:,0], dV[:,1], H[:,1,1]

def mu_grads(Mnet, k, z):
    def m(kz):
        return Mnet(kz[0], kz[1])
    X = jnp.stack([k, z], axis=1)
    dmu = vmap(jacrev(m))(X)           # (n,2)
    H   = vmap(hessian(m))(X)          # (n,2,2)
    muv = vmap(m)(X)
    return muv, dmu[:,0], dmu[:,1], H[:,1,1]  # μ, μ_k, μ_z, μ_zz


In [6]:

def theta_of(V_k):
    # θ(V_k) = θ_+ + (θ_- - θ_+) σ(-β(V_k-1))
    return par.theta_plus + (par.theta_minus - par.theta_plus) * jax.nn.sigmoid(-par.beta_theta * (V_k - 1.0))

def inv_policy(V_k, k):
    th = theta_of(V_k)
    return (k / th) * (V_k - 1.0)      # i^*

def prod_y(k, z):
    return jnp.exp(par.x_log + z) * (k ** par.alpha)

def hjb_residual(Vnet, Mnet, P, k, z):
    V, V_k, V_z, V_zz = V_grads(Vnet, k, z)
    i_star = inv_policy(V_k, k)
    theta  = theta_of(V_k)

    term_profit = P * prod_y(k, z)
    term_k      = - par.delta * k * V_k
    term_cost   = (k / (2.0 * theta)) * (V_k - 1.0)**2
    term_zdiff  = 0.5 * (par.sigma_z**2) * V_zz
    term_zdrift = - par.kappa_z * z * V_z

    H = term_profit + term_k + term_cost + term_zdiff + term_zdrift
    return par.r * V - H  # residual; target 0

def fp_residual(Vnet, Mnet, k, z):
    # Build probability current J = (J_k, J_z) and take divergence
    # J_k = (i* - δk) μ
    # J_z = (-κ z) μ - (σ^2/2) ∂_z μ
    V, V_k, V_z, V_zz = V_grads(Vnet, k, z)
    mu, mu_k, mu_z, mu_zz = mu_grads(Mnet, k, z)
    i_star = inv_policy(V_k, k)

    def J_single(kz):
        k_, z_ = kz[0], kz[1]
        # recompute scalars at (k_,z_)
        V_, V_k_, _, _ = V_grads(Vnet, jnp.array([k_]), jnp.array([z_]))
        mu_, _, mu_z_, _ = mu_grads(Mnet, jnp.array([k_]), jnp.array([z_]))
        V_k_ = V_k_.squeeze()
        mu_  = mu_.squeeze()
        mu_z_= mu_z_.squeeze()
        i_   = inv_policy(V_k_, k_)
        Jk = (i_ - par.delta * k_) * mu_
        Jz = (-par.kappa_z * z_) * mu_ - 0.5 * (par.sigma_z**2) * mu_z_
        return jnp.stack([Jk, Jz])

    X = jnp.stack([k, z], axis=1)
    # Jacobian of J wrt (k,z): (n,2,2)
    dJ = vmap(jacrev(J_single))(X)
    divJ = dJ[:,0,0] + dJ[:,1,1]
    # Stationary FP: divJ = 0  (since ∂_t μ = -divJ)
    return divJ

def boundary_currents(Vnet, Mnet, k_edges, z_for_k, k_for_z, z_edges):
    # j_k = (i* - δk) μ on k-edges
    V, V_k, _, _ = V_grads(Vnet, k_edges, z_for_k)
    mu, _, _, _ = mu_grads(Mnet, k_edges, z_for_k)
    i_star = inv_policy(V_k, k_edges)
    j_k = (i_star - par.delta * k_edges) * mu

    # j_z = (-κ z) μ - (σ^2/2) ∂_z μ on z-edges
    mu2, _, mu_z, _ = mu_grads(Mnet, k_for_z, z_edges)
    j_z = (-par.kappa_z * z_edges) * mu2 - 0.5 * (par.sigma_z**2) * mu_z

    return j_k, j_z


In [7]:

def losses(Vnet, Mnet, P, key):
    # Interior samples
    key_int, key_bk, key_bz, key_mass = random.split(key, 4)
    k_i, z_i = sample_interior(key_int, par.n_colloc)

    # Residuals
    R_hjb = hjb_residual(Vnet, Mnet, P, k_i, z_i)
    R_fp  = fp_residual(Vnet, Mnet, k_i, z_i)

    L_hjb = jnp.mean(R_hjb**2)
    L_fp  = jnp.mean(R_fp**2)

    # Boundary penalties
    n_each = par.n_bndry // 4
    k_edge, z_for_k = sample_k_edges(key_bk, n_each)
    k_for_z, z_edge = sample_z_edges(key_bz, n_each)

    j_k, j_z = boundary_currents(Vnet, Mnet, k_edge, z_for_k, k_for_z, z_edge)
    L_bndry = jnp.mean(j_k**2) + jnp.mean(j_z**2)

    # Mass penalty
    k_m, z_m = sample_interior(key_mass, par.n_mass)
    mu_m = mu_eval(Mnet, k_m, z_m)
    mass_hat = area * jnp.mean(mu_m)
    L_mass = (mass_hat - 1.0)**2

    return L_hjb, L_fp, L_bndry, L_mass, mass_hat


In [8]:

def aggregate_Y(Mnet, key):
    k_a, z_a = sample_interior(key, par.n_agg)
    mu_a = mu_eval(Mnet, k_a, z_a)
    y_a  = prod_y(k_a, z_a)
    # self-normalized MC: Y = (∫ y μ)/(∫ μ)
    num = jnp.mean(y_a * mu_a)
    den = jnp.mean(mu_a) + 1e-12
    return (num / den)

def update_price(P, Y):
    logP = jnp.log(P + 1e-12)
    logP_target = -par.eta * jnp.log(Y + 1e-12)
    logP_new = (1.0 - par.lam_price) * logP + par.lam_price * logP_target
    return jnp.exp(logP_new), (logP_new - logP)


In [9]:

# Initialize nets
k1, k2 = random.split(key, 2)
Vnet = ValueNet(k1, width=64, depth=3)
Mnet = DensityNet(k2, width=64, depth=3)

# Optax optimizers
optV = optax.chain(optax.clip_by_global_norm(par.grad_clip), optax.adam(par.lr_V))
optM = optax.chain(optax.clip_by_global_norm(par.grad_clip), optax.adam(par.lr_M))
optV_state = optV.init(eqx.filter(Vnet, eqx.is_array))
optM_state = optM.init(eqx.filter(Mnet, eqx.is_array))

P = jnp.array(1.0)  # initial price


In [10]:

@eqx.filter_jit
def step_V(Vnet, Mnet, P, opt_state, key):
    L_hjb, L_fp, L_bndry, L_mass, mass_hat = losses(Vnet, Mnet, P, key)
    def loss_only_V(Vnet2):
        R = hjb_residual(Vnet2, Mnet, P, *sample_interior(key, par.n_colloc))
        return jnp.mean(R**2)
    grads = eqx.filter_grad(loss_only_V)(Vnet)
    updates, new_state = optV.update(eqx.filter(grads, eqx.is_array), opt_state, eqx.filter(Vnet, eqx.is_array))
    Vnet = eqx.apply_updates(Vnet, updates)
    return Vnet, new_state, dict(L_HJB=L_hjb, L_FP=L_fp, L_B=L_bndry, L_M=L_mass, mass=mass_hat)

@eqx.filter_jit
def step_M(Vnet, Mnet, P, opt_state, key):
    def loss_M(Mnet2):
        L_hjb, L_fp, L_bndry, L_mass, mass_hat = losses(Vnet, Mnet2, P, key)
        return L_fp + par.w_bndry * L_bndry + par.w_mass * L_mass, (L_hjb, L_fp, L_bndry, L_mass, mass_hat)
    (L_total, aux), grads = eqx.filter_value_and_grad(loss_M, has_aux=True)(Mnet)
    updates, new_state = optM.update(eqx.filter(grads, eqx.is_array), opt_state, eqx.filter(Mnet, eqx.is_array))
    Mnet = eqx.apply_updates(Mnet, updates)
    L_hjb, L_fp, L_bndry, L_mass, mass_hat = aux
    return Mnet, new_state, dict(L_HJB=L_hjb, L_FP=L_fp, L_B=L_bndry, L_M=L_mass, mass=mass_hat, L_total=L_total)

In [None]:

print("Training...")
key_loop = key
prev_logP = jnp.log(P + 1e-12)

for outer in range(par.outer_steps):
    # value steps
    for t in range(par.steps_V):
        key_loop, sub = random.split(key_loop)
        Vnet, optV_state, diagV = step_V(Vnet, Mnet, P, optV_state, sub)

    # density steps
    for t in range(par.steps_M):
        key_loop, sub = random.split(key_loop)
        Mnet, optM_state, diagM = step_M(Vnet, Mnet, P, optM_state, sub)

    # aggregate and update price
    key_loop, sub = random.split(key_loop)
    Y = aggregate_Y(Mnet, sub)
    P_new, dlogP = update_price(P, Y)

    # Diagnostics
    msg = (f"[outer {outer+1}] "
           f"L_HJB={float(diagM['L_HJB']):.3e} "
           f"L_FP={float(diagM['L_FP']):.3e} "
           f"L_B={float(diagM['L_B']):.3e} "
           f"L_M={float(diagM['L_M']):.3e} "
           f"mass={float(diagM['mass']):.4f} "
           f"Y={float(Y):.4f} P={float(P):.4f} ΔlogP={float(dlogP):+.2e}")
    print(msg)

    # early stop
    if jnp.abs(dlogP) < par.tol_logP:
        P = P_new
        print("Early stop: price stabilized.")
        break

    P, prev_logP = P_new, prev_logP + dlogP

print("Done.")


Training...
[outer 1] L_HJB=1.001e-01 L_FP=2.921e-03 L_B=1.030e-02 L_M=1.600e-03 mass=0.9600 Y=0.4580 P=1.0000 ΔlogP=+5.47e-01
[outer 2] L_HJB=1.504e-02 L_FP=2.582e-03 L_B=6.998e-03 L_M=3.954e-03 mass=1.0629 Y=0.1314 P=1.7275 ΔlogP=+1.04e+00
[outer 3] L_HJB=1.031e+02 L_FP=2.357e-03 L_B=7.993e-03 L_M=2.328e-02 mass=0.8474 Y=0.0706 P=4.8779 ΔlogP=+7.47e-01
[outer 4] L_HJB=1.072e+03 L_FP=1.020e-01 L_B=3.096e-02 L_M=6.290e-03 mass=1.0793 Y=0.0597 P=10.2925 ΔlogP=+3.40e-01
[outer 5] L_HJB=2.753e+03 L_FP=1.639e-01 L_B=9.068e-03 L_M=4.306e-01 mass=0.3438 Y=0.0530 P=14.4676 ΔlogP=+1.86e-01


In [None]:

# Verify the simple conjugacy when θ is constant:
Vk, i, th, k_sym = sp.symbols('V_k i theta k', positive=True)
H = (Vk - 1)*i - (th/2) * i**2 / k_sym
dHdi = sp.diff(H, i)
i_star_sym = sp.solve(sp.Eq(dHdi, 0), i)[0]
H_star = sp.simplify(H.subs(i, i_star_sym))
sp.simplify(i_star_sym), sp.simplify(H_star)


In [None]:

# Optional: quick slice plot of μ(k, z=0) after training
ks = jnp.linspace(par.k_min, par.k_max, 200)
zs = jnp.zeros_like(ks)
mu_slice = mu_eval(Mnet, ks, zs)

plt.figure()
plt.plot(ks, mu_slice)
plt.xlabel("k (capital)")
plt.ylabel("μ(k, z=0)")
plt.title("Density slice after training")
plt.show()


In [None]:

# Final diagnostics
k_a, z_a = sample_interior(random.PRNGKey(123), par.n_agg)
mu_a = mu_eval(Mnet, k_a, z_a)
mass_hat = area * jnp.mean(mu_a)
Y = aggregate_Y(Mnet, random.PRNGKey(456))
print(f"Final mass ≈ {float(mass_hat):.6f},  Final Y ≈ {float(Y):.6f},  Final P ≈ {float(P):.6f}")
