
# Minimal Stationary HJB–FP × PINN Solver (JAX + Equinox)

**Scope:** Two PINNs on a rectangular box for a stationary HJB–Fokker–Planck fixed point with partial irreversibility (smooth buy/sell wedge), plus a damped price-clearing loop.  
**Dependencies:** `jax`, `equinox`, `optax`, `sympy`, `matplotlib`.

**Model:**
- States: capital \(k\in[k_{\min},k_{\max}]\), productivity shock \(z\in[-z_{\max}, z_{\max}]\) with OU dynamics \(\mathrm{d}z = -\kappa_z z\,\mathrm{d}t + \sigma_z \mathrm{d}B_t\).
- Output: \(y(k,z;x)=\exp(x+z)k^\alpha\). Aggregate \(Y=\iint y\,\mu\,\mathrm{d}k\,\mathrm{d}z\). Inverse demand \(P=Y^{-\eta}\).
- Capital law: \(\dot k = i - \delta k\).
- Irreversibility cost (smoothed): instantaneous Hamiltonian term \(\sup_i \{ (V_k-1)i - \frac{\theta(i)}{2}\frac{i^2}{k} \}\) with
  a differentiable switch \(\theta(V_k) \in \{\theta_+,\theta_-\}\) around \(V_k=1\).
- HJB residual:
\[
\mathcal{R}_{\rm HJB} = rV - \Big[ P e^{x+z}k^\alpha - \delta k\,V_k + \frac{k}{2\,\theta(V_k)}(V_k-1)^2 + \frac{1}{2}\sigma_z^2 V_{zz} - \kappa_z z\,V_z \Big].
\]
- Stationary FP residual (no diffusion in \(k\)):
\[
\mathcal{R}_{\rm FP} = -\partial_k\big((i^\*-\delta k)\mu\big) - \partial_z\big((-\,\kappa_z z)\mu\big) + \tfrac12\sigma_z^2\,\partial_{zz}\mu.
\]
- Control and fluxes:
\[
i^\*=\frac{k}{\theta(V_k)}(V_k-1),\quad
j_k=(i^\*-\delta k)\mu,\quad
j_z=(-\kappa_z z)\mu - \tfrac12\sigma_z^2\,\partial_z \mu.
\]

**Losses:**
- Interior: MSE of \(\mathcal{R}_{\rm HJB}\) and \(\mathcal{R}_{\rm FP}\).
- Boundary: MSE of \(j_k\) on \(k\)-edges and \(j_z\) on \(z\)-edges.
- Mass: \( (\widehat{\int\!\!\int\mu}-1)^2\) by Monte Carlo.
- Tiny L2 on parameters.

**Training loop:** Alternating training for \(V\) and \(\mu\) inside an outer price iteration: \(\log P \leftarrow (1-\lambda)\log P + \lambda \log(Y^{-\eta})\).

> This is a small baseline. It is intentionally minimal and uses random collocation with autodiff (DGM/PINN-style). Scale `n_*` and `steps_*` to tighten residuals.


In [17]:

# (Optional) If running in a fresh environment, uncomment to install.
# !pip install -q "jax[cpu]" equinox optax sympy matplotlib


In [18]:

import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import optax
import equinox as eqx
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple

# Enable float64 for second derivatives stability
from jax import config as jax_config
jax_config.update("jax_enable_x64", True)


In [19]:

@dataclass
class Config:
    # Economics
    r: float = 0.05
    delta: float = 0.08
    alpha: float = 0.35
    eta: float = 1.5
    xbar: float = 0.0
    kappa_z: float = 0.5
    sigma_z: float = 0.5
    theta_plus: float = 0.5   # buy cost (lower)
    theta_minus: float = 1.5  # sell cost (higher)
    beta_switch: float = 20.0 # sharpness around Vk=1

    # Box domain
    k_min: float = 0.2
    k_max: float = 5.0
    z_max: float = 3.0

    # Training sizes
    n_colloc: int = 2048
    n_bndry: int  = 1024
    n_mass: int   = 4096
    n_agg: int    = 4096

    # Optimisation
    steps_V: int = 600
    steps_M: int = 600
    outer_loops: int = 5
    lr_V: float = 2e-3
    lr_M: float = 2e-3
    l2: float = 1e-6
    w_bndry: float = 5.0
    w_mass: float = 10.0

    # Price iteration
    P0: float = 1.0
    lambda_price: float = 0.7

cfg = Config()
key = jax.random.PRNGKey(0)


In [20]:

# Normalise inputs to [-1,1]^2 for network stability
def norm_inputs(k, z, cfg: Config):
    k_n = 2.0*(k - cfg.k_min)/(cfg.k_max - cfg.k_min) - 1.0
    z_n = z/cfg.z_max
    return jnp.stack([k_n, z_n], axis=-1)

def unnorm_k(kn, cfg: Config):
    return ( (kn + 1.0)*0.5 )*(cfg.k_max - cfg.k_min) + cfg.k_min

# Sampling
def sample_interior(key, n, cfg: Config):
    k = jax.random.uniform(key, (n,), minval=cfg.k_min, maxval=cfg.k_max)
    z = jax.random.uniform(jax.random.split(key, 2)[1], (n,), minval=-cfg.z_max, maxval=cfg.z_max)
    return k, z

def sample_edge_k(key, n, k_edge, cfg: Config):
    z = jax.random.uniform(key, (n,), minval=-cfg.z_max, maxval=cfg.z_max)
    k = jnp.full((n,), k_edge)
    return k, z

def sample_edge_z(key, n, z_edge, cfg: Config):
    k = jax.random.uniform(key, (n,), minval=cfg.k_min, maxval=cfg.k_max)
    z = jnp.full((n,), z_edge)
    return k, z

def area(cfg: Config):
    return (cfg.k_max - cfg.k_min) * (2.0*cfg.z_max)


In [21]:

# Two small MLPs
def make_mlp(key, width=64, depth=3):
    # Equinox MLP signature: MLP(in_size, out_size, width_size, depth, *, activation, key)
    return eqx.nn.MLP(in_size=2, out_size=1, width_size=width, depth=depth, activation=jax.nn.tanh, key=key)

key, k1, k2 = jax.random.split(key, 3)
V_net = make_mlp(k1, width=64, depth=3)   # value
S_net = make_mlp(k2, width=64, depth=3)   # density pre-activation; mu = softplus(S)^2

def mu_from_S(S):
    return jax.nn.softplus(S)**2 + 1e-12  # positivity and avoid exact zero


In [22]:

# Smooth switch for theta(V_k): small theta when Vk>1 (buy), large theta when Vk<1 (sell)
def theta_of_Vk(Vk, cfg: Config):
    # sigmoid(s) ~ 0 when Vk >> 1; ~1 when Vk << 1
    s = -cfg.beta_switch*(Vk - 1.0)
    sig = jax.nn.sigmoid(s)
    return cfg.theta_plus + (cfg.theta_minus - cfg.theta_plus)*sig

def i_star(k, Vk, cfg: Config):
    theta = theta_of_Vk(Vk, cfg)
    return (k/theta) * (Vk - 1.0)

def price_from_Y(Y, cfg: Config):
    return Y ** (-cfg.eta)


In [23]:

# Autodiff helpers
def V_and_derivs(params_V, k, z, cfg: Config):
    x = norm_inputs(k, z, cfg)
    V = eqx.filter_eval_shape(lambda m, x: m(x).squeeze(), params_V, x)  # shape hint
    V = V_net(x).squeeze()
    # First derivatives
    dV_dk = grad(lambda kk: V_net(norm_inputs(kk, z, cfg)).squeeze())(k)
    dV_dz = grad(lambda zz: V_net(norm_inputs(k, zz, cfg)).squeeze())(z)
    # Second derivative w.r.t z
    d2V_dzz = grad(lambda zz: grad(lambda t: V_net(norm_inputs(k, t, cfg)).squeeze())(zz))(z)
    return V, dV_dk, dV_dz, d2V_dzz

def S_mu_and_derivs(params_S, k, z, cfg: Config):
    x = norm_inputs(k, z, cfg)
    S = S_net(x).squeeze()
    mu = mu_from_S(S)
    # z-derivatives for diffusion flux and FP term
    dmu_dz = grad(lambda zz: mu_from_S(S_net(norm_inputs(k, zz, cfg)).squeeze()))(z)
    d2mu_dzz = grad(lambda zz: grad(lambda t: mu_from_S(S_net(norm_inputs(k, t, cfg)).squeeze()))(zz))(z)
    return S, mu, dmu_dz, d2mu_dzz

def hjb_residual(k, z, P, cfg: Config):
    V, Vk, Vz, Vzz = V_and_derivs(V_net, k, z, cfg)
    flow = P * jnp.exp(cfg.xbar + z) * (k**cfg.alpha)
    term_inv = (k/(2.0*theta_of_Vk(Vk, cfg))) * (Vk - 1.0)**2
    rhs = flow - cfg.delta * k * Vk + 0.5 * (cfg.sigma_z**2) * Vzz - cfg.kappa_z * z * Vz + term_inv
    return cfg.r*V - rhs

def fp_residual(k, z, P, cfg: Config):
    _, mu, dmu_dz, d2mu_dzz = S_mu_and_derivs(S_net, k, z, cfg)
    V, Vk, _, _ = V_and_derivs(V_net, k, z, cfg)
    drift_k = i_star(k, Vk, cfg) - cfg.delta * k
    drift_z = -cfg.kappa_z * z
    # Partial derivatives of drift*mu terms
    def gk_mu(kk):
        Vloc, Vkloc, _, _ = V_and_derivs(V_net, kk, z, cfg)
        return (i_star(kk, Vkloc, cfg) - cfg.delta*kk) * mu_from_S(S_net(norm_inputs(kk, z, cfg)).squeeze())
    d_gkmu_dk = grad(gk_mu)(k)
    def gz_mu(zz):
        return (-cfg.kappa_z * zz) * mu_from_S(S_net(norm_inputs(k, zz, cfg)).squeeze())
    d_gzmu_dz = grad(gz_mu)(z)
    return - d_gkmu_dk - d_gzmu_dz + 0.5*(cfg.sigma_z**2) * d2mu_dzz

# Boundary fluxes
def flux_k(k_edge, z, cfg: Config):
    _, mu, _, _ = S_mu_and_derivs(S_net, k_edge, z, cfg)
    _, Vk, _, _ = V_and_derivs(V_net, k_edge, z, cfg)
    return (i_star(k_edge, Vk, cfg) - cfg.delta*k_edge) * mu

def flux_z(k, z_edge, cfg: Config):
    _, mu, dmu_dz, _ = S_mu_and_derivs(S_net, k, z_edge, cfg)
    return (-cfg.kappa_z * z_edge) * mu - 0.5*(cfg.sigma_z**2) * dmu_dz


In [24]:

def hjb_loss(batch_k, batch_z, P, cfg: Config):
    res = vmap(lambda kk, zz: hjb_residual(kk, zz, P, cfg))(batch_k, batch_z)
    return jnp.mean(res**2)

def fp_loss(batch_k, batch_z, P, cfg: Config):
    res = vmap(lambda kk, zz: fp_residual(kk, zz, P, cfg))(batch_k, batch_z)
    return jnp.mean(res**2)

def boundary_loss(kmin_samples, kmax_samples, zmin_samples, zmax_samples, cfg: Config):
    jk_min = vmap(lambda zz: flux_k(cfg.k_min, zz, cfg))(kmin_samples[1])
    jk_max = vmap(lambda zz: flux_k(cfg.k_max, zz, cfg))(kmax_samples[1])
    jz_min = vmap(lambda kk: flux_z(kk, -cfg.z_max, cfg))(zmin_samples[0])
    jz_max = vmap(lambda kk: flux_z(kk,  cfg.z_max, cfg))(zmax_samples[0])
    return jnp.mean(jk_min**2) + jnp.mean(jk_max**2) + jnp.mean(jz_min**2) + jnp.mean(jz_max**2)

def mass_estimate(key, cfg: Config):
    k, z = sample_interior(key, cfg.n_mass, cfg)
    mu_vals = vmap(lambda kk, zz: mu_from_S(S_net(norm_inputs(kk, zz, cfg)).squeeze()))(k, z)
    return area(cfg) * jnp.mean(mu_vals)

def aggregate_Y(key, P, cfg: Config):
    k, z = sample_interior(key, cfg.n_agg, cfg)
    mu_vals = vmap(lambda kk, zz: mu_from_S(S_net(norm_inputs(kk, zz, cfg)).squeeze()))(k, z)
    y = jnp.exp(cfg.xbar + z) * (k**cfg.alpha)
    return area(cfg) * jnp.mean(y * mu_vals)


In [25]:

@eqx.filter_value_and_grad
def loss_V(P, key, cfg: Config):
    k, z = sample_interior(key, cfg.n_colloc, cfg)
    l = hjb_loss(k, z, P, cfg)
    # L2 regularization
    l2 = sum([jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(eqx.filter(V_net, eqx.is_inexact_array))])
    return l + cfg.l2 * l2

@eqx.filter_value_and_grad
def loss_M(P, keys, cfg: Config):
    k, z = sample_interior(keys[0], cfg.n_colloc, cfg)
    l_fp = fp_loss(k, z, P, cfg)
    # boundary samples
    kz1 = sample_edge_k(keys[1], cfg.n_bndry, cfg.k_min, cfg)
    kz2 = sample_edge_k(keys[2], cfg.n_bndry, cfg.k_max, cfg)
    kz3 = sample_edge_z(keys[3], cfg.n_bndry, -cfg.z_max, cfg)
    kz4 = sample_edge_z(keys[4], cfg.n_bndry,  cfg.z_max, cfg)
    l_b = boundary_loss(kz1, kz2, kz3, kz4, cfg)
    # mass penalty
    m = mass_estimate(keys[5], cfg)
    l_mass = (m - 1.0)**2
    # L2
    l2 = sum([jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(eqx.filter(S_net, eqx.is_inexact_array))])
    return l_fp + cfg.w_bndry*l_b + cfg.w_mass*l_mass + cfg.l2*l2

opt_V = optax.adam(cfg.lr_V)
opt_M = optax.adam(cfg.lr_M)
opt_state_V = opt_V.init(eqx.filter(V_net, eqx.is_inexact_array))
opt_state_M = opt_M.init(eqx.filter(S_net, eqx.is_inexact_array))

def train_outer(P, key, cfg: Config):
    global V_net, S_net, opt_state_V, opt_state_M
    # --- Train V ---
    for t in range(cfg.steps_V):
        key, sub = jax.random.split(key)
        (lV), grads = loss_V(P, sub, cfg)
        updates, opt_state_V = opt_V.update(grads, opt_state_V, params=eqx.filter(V_net, eqx.is_inexact_array))
        V_net = eqx.apply_updates(V_net, updates)
        if (t+1) % max(1, cfg.steps_V//3) == 0:
            print(f"  V step {t+1}/{cfg.steps_V}: L_HJB={float(lV):.3e}")
    # --- Train M ---
    for t in range(cfg.steps_M):
        key, *subs = jax.random.split(key, 7)
        (lM), grads = loss_M(P, subs, cfg)
        updates, opt_state_M = opt_M.update(grads, opt_state_M, params=eqx.filter(S_net, eqx.is_inexact_array))
        S_net = eqx.apply_updates(S_net, updates)
        if (t+1) % max(1, cfg.steps_M//3) == 0:
            print(f"  M step {t+1}/{cfg.steps_M}: L_FP+pen={float(lM):.3e}")
    # --- Diagnostics & price update ---
    key, km, ka = jax.random.split(key, 3)
    mhat = float(mass_estimate(km, cfg))
    Yhat = float(aggregate_Y(ka, P, cfg))
    P_new = math.exp((1-cfg.lambda_price)*math.log(P) + cfg.lambda_price*math.log(max(1e-12, Yhat**(-cfg.eta))))
    return P_new, mhat, Yhat, key


In [None]:

# Run a short outer loop as a smoke test (increase steps_* and n_* for accuracy)
P = cfg.P0
key = jax.random.PRNGKey(42)
for it in range(cfg.outer_loops):
    print(f"\n=== Outer iteration {it+1}/{cfg.outer_loops}, current P={P:.4f} ===")
    P, mhat, Yhat, key = train_outer(P, key, cfg)
    print(f"Mass ~ {mhat:.4f}, Y ~ {Yhat:.4f}, updated P={P:.4f}")



=== Outer iteration 1/5, current P=1.0000 ===
  V step 200/600: L_HJB=7.284e+01
  V step 400/600: L_HJB=7.733e+01
  V step 600/600: L_HJB=6.781e+01


In [1]:

# Quick visual: mu(k, z=0) profile after training (coarse grid)
import numpy as onp
kgrid = onp.linspace(cfg.k_min, cfg.k_max, 200)
z0 = 0.0
mu_line = []
for kk in kgrid:
    mu_line.append(mu_from_S(S_net(norm_inputs(jnp.array(kk), jnp.array(z0), cfg)).squeeze()))
mu_line = jnp.array(mu_line)
plt.figure(figsize=(5,3))
plt.plot(kgrid, onp.asarray(mu_line))
plt.title("Stationary density slice: $\\mu(k, z=0)$")
plt.xlabel("k"); plt.ylabel("mu")
plt.tight_layout()
plt.show()


NameError: name 'cfg' is not defined

In [None]:

# Verify i* and conjugate term with SymPy for constant theta
import sympy as sp
i, Vk, k, theta = sp.symbols('i Vk k theta', positive=True, real=True)
expr = (Vk - 1)*i - (theta/2)*i**2/k
istar = sp.simplify(sp.diff(expr, i))
istar_sol = sp.solve(sp.Eq(istar, 0), i)[0]
sup_val = sp.simplify(expr.subs(i, istar_sol))
print("i* =", istar_sol)
print("Sup value =", sup_val)
