
# Continuous-Time HJB–FP Solver for a Zhang (2005)-Style Heterogeneous-Firm Asset-Pricing Model

This notebook sets up a **continuous-time** analogue of the classical **Zhang (2005, *Journal of Finance*)** investment-based asset-pricing model with **asymmetric (partial) irreversibility** in capital adjustment costs and an **inverse industry demand curve**. We solve the firm problem via a **Hamilton–Jacobi–Bellman (HJB)** equation and the cross-sectional dynamics via the **Fokker–Planck (FP)** equation. The numerical scheme follows the **Achdou–Han–Lasry–Lions–Moll (2022, *REStud*)** finite-difference blueprint.

**Key features**  
- Firm state: capital `k`, idiosyncratic productivity `z`; aggregate state `x` affects demand and the stochastic discount factor (SDF).  
- Revenue: `exp(x + z) * P(Y) * k**alpha`, with inverse demand `P(Y) = Y**(-eta)` and `Y` the aggregate output.  
- Adjustment costs: asymmetric quadratic `theta_plus` for `i >= 0`, `theta_minus` for `i < 0`.  
- Control: investment flow `i`. Capital law: `dk = (i - delta * k) dt`.  
- SDF in continuous time: risk-free rate `r(x)` and market price of risk `lambda(x)` adjust the **risk-neutral** drift of `x` in the HJB.  
- Discretization: finite differences. The FP step uses **upwinding** for the drift (hyperbolic) terms; the HJB starter below uses central differences with a damped Picard iteration, and an implicit upwind scheme with Howard improvement is the natural upgrade.  
- Equilibrium: given `x`, solve for the log price `p_log = log P` consistent with `P(Y)`, where `Y` aggregates firm output under the stationary distribution `mu(k,z)` implied by the HJB policy.  
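
The SDF bullet above can be made concrete with a small sketch. It assumes (the notebook does not spell this out) that `x` follows an OU process under the physical measure and that the risk-neutral drift subtracts `lambda(x) * sig_x`. The sign convention and the helper name `risk_neutral_drift_x` are illustrative, and the parameter values mirror the defaults defined later.

```python
def lambda_of_x(x, lam0=0.5, lam1=-0.2, xbar=0.0):
    """Affine market price of risk; lam1 < 0 makes it countercyclical."""
    return lam0 + lam1 * (x - xbar)

def risk_neutral_drift_x(x, kappa_x=0.2, xbar=0.0, sig_x=0.1):
    """Assumed convention: physical OU drift minus lambda(x) * sig_x."""
    physical = kappa_x * (xbar - x)
    return physical - lambda_of_x(x, xbar=xbar) * sig_x

drift_at_mean = risk_neutral_drift_x(0.0)    # physical drift is zero at xbar
drift_in_bust = risk_neutral_drift_x(-0.5)   # lambda is higher when x is low
```

In the snapshot problem at fixed `x` this drift never enters the HJB; it matters only once `x` becomes a state variable.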

> This scaffold is meant to be **readable and hackable**. It uses **JAX** + **Equinox** for speed/JIT and to keep code composable.


In [1]:

# --- Imports & Config
import math, functools
import jax
import jax.numpy as jnp
from jax import jit, vmap, lax
import equinox as eqx

# If you want sparse, you can switch to jax.experimental.sparse later.
# Here we keep things dense/small to remain portable.

# Plotting optional (commented): 
# import matplotlib.pyplot as plt

# Typing helpers
Array = jnp.ndarray



## 1) Symbolic algebra (SymPy): First-order condition for investment and the reduced-form HJB

We derive the optimal investment rule under asymmetric quadratic adjustment costs. For a given marginal value of capital `V_k`, the control subproblem is quadratic, yielding a closed-form optimal `i*` and a closed-form **maximized** contribution to the HJB.


In [2]:

import sympy as sp

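# Symbols: k and the cost parameters are positive; investment i and the
# marginal value Vk are only real (i < 0 is disinvestment, Vk may fall below 1)
k, theta_p, theta_m = sp.symbols('k theta_p theta_m', positive=True)
i, Vk = sp.symbols('i Vk', real=True)
one = sp.Integer(1)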

# Quadratic cost: c(i,k) = theta/2 * (i^2)/k, with theta = theta_p if i>=0, theta_m if i<0
# Control objective part:  - i - 0.5 * theta * i^2 / k + Vk * i
theta = sp.symbols('theta', positive=True)
obj = - i - sp.Rational(1,2) * theta * i**2 / k + Vk * i

# FOC for unconstrained quadratic (given a fixed theta): d/di obj = 0
i_star = sp.simplify(sp.solve(sp.diff(obj, i), i)[0])
obj_star = sp.simplify(obj.subs(i, i_star))

i_star, obj_star


(k*(Vk - 1)/theta, k*(Vk - 1)**2/(2*theta))


**Result:** For a given `theta`, the quadratic subproblem yields
\[
i^* = \frac{k}{\theta}\,(V_k - 1),\qquad \max_i\{-i - \tfrac{\theta}{2}\tfrac{i^2}{k} + V_k i\} = \tfrac{k}{2\theta}\,(V_k - 1)^2.
\]

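With **partial irreversibility**, use `theta = theta_plus` when `i* >= 0` (i.e. when `V_k >= 1`) and `theta = theta_minus` when `i* < 0` (i.e. when `V_k < 1`). The branch choice is well defined because, for any positive `theta`, the sign of `i*` equals the sign of `V_k - 1`. Numerically we use a **smoothed** switch so that the policy is differentiable in `V_k`.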
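
To see what the smoothing does, here is a minimal plain-Python sketch comparing the exact piecewise rule with the smoothed one. The function names `i_star_exact` and `i_star_smooth` are illustrative, and the default values `theta_plus = 1`, `theta_minus = 3`, `kappa = 50` are assumptions chosen to match the parameters used later in the notebook.

```python
import math

def soft_indicator(x, kappa=50.0):
    """Smooth step approximating 1[x >= 0], written to avoid overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-kappa * x))
    e = math.exp(kappa * x)
    return e / (1.0 + e)

def i_star_exact(Vk, k, theta_plus=1.0, theta_minus=3.0):
    """Piecewise rule: cheap to invest (theta_plus), costly to disinvest (theta_minus)."""
    theta = theta_plus if Vk >= 1.0 else theta_minus
    return k * (Vk - 1.0) / theta

def i_star_smooth(Vk, k, theta_plus=1.0, theta_minus=3.0, kappa=50.0):
    """Same rule with theta blended smoothly around Vk = 1."""
    w = soft_indicator(Vk - 1.0, kappa)
    theta = w * theta_plus + (1.0 - w) * theta_minus
    return k * (Vk - 1.0) / theta

k = 2.0
# (Vk, exact policy, smoothed policy) on both sides of the kink at Vk = 1
rows = [(Vk, i_star_exact(Vk, k), i_star_smooth(Vk, k))
        for Vk in (0.4, 0.99, 1.0, 1.01, 1.5)]
```

Away from `V_k = 1` the two rules agree to many digits; within roughly `1/kappa` of the kink the smoothed `theta` is a blend of the two cost parameters, so the policies differ there.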



## 2) Parameters & model containers (Equinox)


In [3]:

@eqx.filter_jit
def soft_indicator(x, kappa=50.0):
    # Smooth step ~ 1[x>=0]; larger kappa -> sharper
    return 1.0 / (1.0 + jnp.exp(-kappa * x))

class Params(eqx.Module):
    # Preferences / pricing
    r0: float = 0.02      # baseline risk-free rate
    r1: float = 0.0       # slope wrt x (optional)
    lam0: float = 0.5     # market price of risk intercept for x shock
    lam1: float = -0.2    # slope wrt x (countercyclical risk price)

    # Technology / demand
    alpha: float = 0.7
    eta: float = 0.5      # inverse demand elasticity: P(Y) = Y^{-eta}
    delta: float = 0.08
    f: float = 0.0        # fixed operating cost (can be > 0)

    # Adjustment cost asymmetry (theta_minus > theta_plus)
    theta_plus: float  = 1.0
    theta_minus: float = 3.0

    # Aggregate shock x: OU under P; risk-neutral drift uses lam(x)
    kappa_x: float = 0.2
    xbar: float = 0.0
    sig_x: float = 0.1

    # Idiosyncratic shock z: OU
    kappa_z: float = 1.0
    sig_z: float = 0.2
    rho_xz: float = 0.0   # corr(x,z) if needed

    # Grids
    k_min: float = 1e-3
    k_max: float = 10.0
    n_k: int   = 200

    z_min: float = -3.0
    z_max: float =  3.0
    n_z: int   = 121

    # Numerics
    max_iter_hjb: int = 5000
    tol_hjb: float = 1e-6
    damp_hjb: float = 0.5

    max_iter_fp: int = 20000
    tol_fp: float = 1e-10
    dt_fp: float = 1e-3  # pseudo-time step for FP power iteration

    # Price iteration
    max_iter_price: int = 200
    tol_price: float = 1e-6
    damp_price: float = 0.7

params = Params()



## 3) Grids and primitives


In [4]:

def make_grids(p: Params):
    k = jnp.linspace(p.k_min, p.k_max, p.n_k)
    z = jnp.linspace(p.z_min, p.z_max, p.n_z)
    dk = k[1] - k[0]
    dz = z[1] - z[0]
    return k, z, dk, dz

@jit
def output_per_firm(k, z, x, p: Params):
    # physical output (before price) y = exp(x+z) * k**alpha
    return jnp.exp(x + z) * (k ** p.alpha)

@jit
def inverse_demand(Y, p: Params):
    # P(Y) = Y^{-eta}; returns price level (not log)
    return Y ** (-p.eta)

@jit
def price_from_mu(mu, k, z, x, p: Params):
    # Aggregate output Y(mu,x) then P(Y)
    # mu shape: (n_k, n_z), grids 1D
    y = output_per_firm(k[:, None], z[None, :], x, p)
    Y = jnp.sum(mu * y)
    P = inverse_demand(Y + 1e-12, p)
    return jnp.log(P), P, Y



## 4) HJB operator and policy
We solve the stationary HJB at a **fixed aggregate state** `x` and **log-price** `p_log` (consistent with inverse demand). Risk enters via the continuous-time SDF through `r(x)` and the **risk-neutral** drift of `x`. In the **snapshot** problem at fixed `x`, the `x`-derivative terms are dropped, so only `r(x)` enters and `lambda(x)` is unused; for problems with aggregate dynamics, include them.
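
For reference, the snapshot HJB implied by the ingredients above (revenue, asymmetric quadratic adjustment costs, the capital law, and the OU process for `z`) can be written out as follows. This restates the notebook's own definitions, with `p` the log price and `theta(i)` equal to `theta_plus` for `i >= 0` and `theta_minus` for `i < 0`:
\[
r(x)\,V(k,z) = \max_i\Big\{ e^{x+z+p}\,k^{\alpha} - f - i - \tfrac{\theta(i)}{2}\tfrac{i^2}{k} + V_k\,(i-\delta k)\Big\} - \kappa_z z\,V_z + \tfrac{1}{2}\sigma_z^2\,V_{zz}.
\]
Substituting the maximized control term from Section 1 gives
\[
r(x)\,V = e^{x+z+p}\,k^{\alpha} - f - \delta k\,V_k + \tfrac{k}{2\theta}\,(V_k-1)^2 - \kappa_z z\,V_z + \tfrac{1}{2}\sigma_z^2\,V_{zz}.
\]
The operator coded below evaluates this right-hand side without the `-kappa_z * z * V_z` drift term.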


In [5]:

@jit
def r_of_x(x, p: Params):
    return p.r0 + p.r1 * (x - p.xbar)

@jit
def lambda_of_x(x, p: Params):
    return p.lam0 + p.lam1 * (x - p.xbar)

@jit
def revenue(k, z, x, p_log, p: Params):
    # instantaneous revenue net of operating fixed cost f (before investment/adj.cost)
    return jnp.exp(x + z + p_log) * (k ** p.alpha) - p.f

@jit
def theta_effective(Vk, p: Params, kappa=50.0):
    # Smooth switch between theta_plus (i>=0, i.e. Vk>=1) and theta_minus (i<0, i.e. Vk<1)
    w = soft_indicator(Vk - 1.0, kappa)
    return w * p.theta_plus + (1.0 - w) * p.theta_minus

@jit
def i_star_from_Vk(Vk, k, p: Params, kappa=50.0):
    th = theta_effective(Vk, p, kappa)
    return (k / th) * (Vk - 1.0)

@jit
def hjb_rhs_maximized(Vk, k, z, x, p_log, p: Params):
    # Plug back the maximized control contribution: 0.5 * k / theta * (Vk - 1)^2
    th = theta_effective(Vk, p)
    return revenue(k, z, x, p_log, p) - p.delta * k * Vk + 0.5 * k * (Vk - 1.0)**2 / th

# Finite differences for the HJB: central first difference in k (to get Vk) and
# central second difference in z; upwinding appears only in the FP step below
def diff_k(V, dk):
    # central differences in k for interior; forward/backward at boundaries
    Vk = jnp.zeros_like(V)
    Vk = Vk.at[1:-1,:].set( (V[2:,:] - V[:-2,:]) / (2.0*dk) )
    Vk = Vk.at[0,:].set( (V[1,:] - V[0,:]) / dk )
    Vk = Vk.at[-1,:].set( (V[-1,:] - V[-2,:]) / dk )
    return Vk

def diff_zz(V, dz):
    Vzz = jnp.zeros_like(V)
    Vzz = Vzz.at[:,1:-1].set( (V[:,2:] - 2.0*V[:,1:-1] + V[:,:-2]) / (dz**2) )
    Vzz = Vzz.at[:,0].set( (V[:,1] - V[:,0]) / (dz**2) )
    Vzz = Vzz.at[:,-1].set( (V[:,-2] - V[:,-1]) / (dz**2) )
    return Vzz

@jit
def hjb_operator(V, k, z, x, p_log, p: Params):
    dk = k[1]-k[0]; dz = z[1]-z[0]
    Vk = diff_k(V, dk)
    Vzz = diff_zz(V, dz)
    # z follows an OU process: d z_t = -kappa_z * z dt + sig_z dB
    mu_z = -p.kappa_z * z[None, :]  # drift of z; not used below (see NOTE)
    # Snapshot at fixed x: the x-drift and x-diffusion terms are dropped here
    # (add V_x and V_xx terms for the problem with aggregate dynamics).
    # Full stationary HJB: r(x) * V = max_i[...] + mu_z * V_z + 0.5 * sig_z^2 * V_zz.
    # NOTE: this scaffold omits the mu_z * V_z term. The policy depends on V only
    # through V_k, but V itself changes once the drift term is restored (upwinded in z).
    rhs = hjb_rhs_maximized(Vk, k[:,None], z[None,:], x, p_log, p) + 0.5*(p.sig_z**2)*Vzz
    return rhs



### HJB fixed point

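We target the stationary fixed point `V = rhs / r(x)` using a simple damped **Picard** iteration as a starter. Howard improvement (policy iteration) is not implemented in this cell. Because `rhs` contains finite-difference derivatives of `V`, this explicit iteration can be unstable on fine grids; reduce `damp_hjb` or switch to an implicit scheme if it fails to converge.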
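
As a point of comparison, the following is a hedged one-dimensional sketch of an implicit upwind iteration in the spirit of the finite-difference blueprint cited in the introduction. It is not the notebook's 2D solver: the drift `g` is held fixed instead of being re-optimized, there is no diffusion, and the names `solve_tridiagonal`, `implicit_upwind_hjb`, and `step` are illustrative. The test problem `r V = k - delta * k * V_k` has the exact solution `V = k / (r + delta)`.

```python
def solve_tridiagonal(sub, diag, sup, rhs):
    """Thomas algorithm for a tridiagonal system (sub[0] and sup[-1] are ignored)."""
    n = len(diag)
    c, d = [0.0] * n, [0.0] * n
    c[0], d[0] = sup[0] / diag[0], rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - sub[i] * c[i - 1]
        c[i] = sup[i] / denom
        d[i] = (rhs[i] - sub[i] * d[i - 1]) / denom
    x = [0.0] * n
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x

def implicit_upwind_hjb(u, g, dk, r, step=1000.0, n_iter=50):
    """Iterate (1/step + r - A) V_new = u + V/step, with A the upwind drift operator."""
    n = len(u)
    # Upwind weights: backward difference where g < 0, forward where g > 0;
    # boundary weights are zeroed so the stencil never leaves the grid.
    lo = [(-min(g[i], 0.0) / dk if i > 0 else 0.0) for i in range(n)]
    hi = [(max(g[i], 0.0) / dk if i < n - 1 else 0.0) for i in range(n)]
    sub = [-w for w in lo]
    sup = [-w for w in hi]
    diag = [1.0 / step + r + lo[i] + hi[i] for i in range(n)]
    V = [0.0] * n
    for _ in range(n_iter):
        rhs = [u[i] + V[i] / step for i in range(n)]
        V = solve_tridiagonal(sub, diag, sup, rhs)
    return V

r, delta, n = 0.02, 0.08, 51
k = [0.2 * i for i in range(n)]      # grid on [0, 10]
dk = k[1] - k[0]
u = k[:]                             # flow payoff u(k) = k
g = [-delta * ki for ki in k]        # pure depreciation drift
V = implicit_upwind_hjb(u, g, dk, r)
max_err = max(abs(V[i] - k[i] / (r + delta)) for i in range(n))
```

The implicit step stays stable for large `step` because the system matrix is diagonally dominant; in a full solver the drift and flow payoff would be recomputed from the current policy before each linear solve.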


In [6]:

def solve_hjb(k, z, x, p_log, p: Params):
    V0 = jnp.zeros((k.size, z.size))
    rx = r_of_x(x, p)

    # Loop carry is a flat tuple (iteration, V, last sup-norm change);
    # cond_fun and body_fun must accept and return this same structure.
    def cond_fun(carry):
        it, _, diff = carry
        return jnp.logical_and(it < p.max_iter_hjb, diff > p.tol_hjb)

    def body_fun(carry):
        it, V, _ = carry
        rhs = hjb_operator(V, k, z, x, p_log, p)
        V_new = (1.0 - p.damp_hjb) * V + p.damp_hjb * (rhs / rx)  # damped Picard update
        diff = jnp.max(jnp.abs(V_new - V))
        return it + 1, V_new, diff

    init = (jnp.array(0), V0, jnp.array(jnp.inf, dtype=V0.dtype))
    _, V, _ = lax.while_loop(cond_fun, body_fun, init)
    return V



## 5) Policy function and Fokker–Planck stationary distribution

Given the HJB solution, compute `i*(k,z)` and the **drift in `k`**: `g_k = i - delta*k`. Then solve the stationary FP:
\[
0 = -\partial_k\big(g_k\,\mu\big) - \partial_z\big(\mu_z\,\mu\big) + \tfrac{1}{2}\partial_{zz}\big(\sigma_z^2\,\mu\big).
\]
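We implement a **power iteration** on an explicit Euler discretization with pseudo-time step `dt_fp`. Each step clips negative entries and renormalizes, so the iterate remains a probability mass function; the explicit step itself is only stable when `dt_fp` is small relative to the grid spacing and the drift.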
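
The drift part of the FP equation is a conservation law, so it helps to see a discretization that moves mass between cells without creating or destroying it. The sketch below is a one-dimensional, plain-Python illustration under assumed inputs (a toy drift `g = 0.5 - k` on a grid over `[0, 1]`, closed boundaries); the function name `upwind_fp_step` is illustrative and `mu` holds mass per cell.

```python
def upwind_fp_step(mu, g, dk, dt):
    """One explicit Euler step of d(mu)/dt = -d(g*mu)/dk in conservative upwind form."""
    n = len(mu)
    # Mass leaving each cell to the right / left per unit time; walls are closed.
    out_right = [max(g[i], 0.0) * mu[i] / dk if i < n - 1 else 0.0 for i in range(n)]
    out_left = [-min(g[i], 0.0) * mu[i] / dk if i > 0 else 0.0 for i in range(n)]
    new = []
    for i in range(n):
        inflow = ((out_right[i - 1] if i > 0 else 0.0)
                  + (out_left[i + 1] if i < n - 1 else 0.0))
        new.append(mu[i] + dt * (inflow - out_right[i] - out_left[i]))
    return new

n = 21
k = [0.05 * i for i in range(n)]     # grid on [0, 1]
dk = 0.05
g = [0.5 - ki for ki in k]           # drift pushes capital toward k = 0.5
mu = [1.0 / n] * n                   # uniform initial mass
for _ in range(2000):
    mu = upwind_fp_step(mu, g, dk, dt=0.01)

total_mass = sum(mu)
peak_index = max(range(n), key=lambda i: mu[i])
```

Because every unit of mass that leaves one cell enters a neighbour, total mass is preserved up to round-off without renormalization, and entries stay non-negative as long as `dt * |g| / dk < 1`.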


In [7]:

def compute_policy(V, k, z, p: Params):
    Vk = diff_k(V, k[1]-k[0])
    I = i_star_from_Vk(Vk, k[:,None], p)
    gk = I - p.delta * k[:,None]
    return I, gk

@jit
def fp_step(mu, gk, z, p: Params, dk, dz, dt):
    # Conservative upwind form of -∂k(gk * mu): each cell sends mass to its
    # right neighbour when gk > 0 and to its left neighbour when gk < 0.
    zero_row = jnp.zeros_like(mu[:1, :])
    Fk_pos = (jnp.maximum(gk, 0.0) * mu).at[-1, :].set(0.0)  # no outflow through k_max
    Fk_neg = (jnp.minimum(gk, 0.0) * mu).at[0, :].set(0.0)   # no outflow through k_min
    div_k = (Fk_pos - jnp.vstack([zero_row, Fk_pos[:-1, :]])
             + jnp.vstack([Fk_neg[1:, :], zero_row]) - Fk_neg) / dk

    # z diffusion (central, reflecting boundaries)
    mu_zz = jnp.zeros_like(mu)
    mu_zz = mu_zz.at[:,1:-1].set( (mu[:,2:] - 2*mu[:,1:-1] + mu[:,:-2]) / (dz**2) )
    mu_zz = mu_zz.at[:,0].set( (mu[:,1] - mu[:,0]) / (dz**2) )
    mu_zz = mu_zz.at[:,-1].set( (mu[:,-2] - mu[:,-1]) / (dz**2) )

    # z drift term -∂z(mu_z * mu), with mu_z = -kappa_z * z, same conservative upwinding
    muz = -p.kappa_z * z[None, :]
    zero_col = jnp.zeros_like(mu[:, :1])
    Fz_pos = (jnp.maximum(muz, 0.0) * mu).at[:, -1].set(0.0)
    Fz_neg = (jnp.minimum(muz, 0.0) * mu).at[:, 0].set(0.0)
    div_z = (Fz_pos - jnp.hstack([zero_col, Fz_pos[:, :-1]])
             + jnp.hstack([Fz_neg[:, 1:], zero_col]) - Fz_neg) / dz

    rhs = - div_k - div_z + 0.5*(p.sig_z**2) * mu_zz
    mu_new = mu + dt * rhs
    mu_new = jnp.maximum(mu_new, 0.0)  # guard against round-off and CFL violations
    mu_new = mu_new / (jnp.sum(mu_new) + 1e-16)
    return mu_new

def solve_fp(gk, k, z, p: Params):
    dk = k[1] - k[0]; dz = z[1] - z[0]
    mu0 = jnp.ones((k.size, z.size)); mu0 = mu0 / jnp.sum(mu0)

    # Loop carry is a flat tuple (iteration, mu, last sup-norm change)
    def cond_fun(carry):
        it, _, diff = carry
        return jnp.logical_and(it < p.max_iter_fp, diff > p.tol_fp)

    def body_fun(carry):
        it, mu, _ = carry
        mu_new = fp_step(mu, gk, z, p, dk, dz, p.dt_fp)
        diff = jnp.max(jnp.abs(mu_new - mu))
        return it + 1, mu_new, diff

    init = (jnp.array(0), mu0, jnp.array(jnp.inf, dtype=mu0.dtype))
    _, mu, _ = lax.while_loop(cond_fun, body_fun, init)
    return mu



## 6) Price clearing loop at a fixed aggregate `x`

Iterate on the **log price** `p_log` until `p_log = log P(Y(mu))` under the stationary distribution implied by the HJB policy given `p_log`.
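
The damped update can be illustrated on a toy problem where the fixed point is known in closed form. In this sketch the HJB and FP solves are replaced by a hypothetical supply response `toy_aggregate_output` (an assumption made purely for illustration), so that `log P = -eta * log Y` has the solution `p* = -eta * log(A) / (1 + eta * s)`.

```python
import math

def toy_aggregate_output(p_log, A=2.0, s=1.0):
    """Hypothetical supply response: Y rises with the log price (stand-in for HJB + FP)."""
    return A * math.exp(s * p_log)

def clear_price(eta=0.5, damp=0.7, tol=1e-10, max_iter=200):
    """Damped fixed-point iteration on the log price; returns (p_log, iterations used)."""
    p_log = 0.0
    for it in range(max_iter):
        Y = toy_aggregate_output(p_log)
        p_log_new = -eta * math.log(Y)      # log P(Y) with P(Y) = Y**(-eta)
        gap = abs(p_log_new - p_log)        # distance from a fixed point
        p_log = damp * p_log_new + (1.0 - damp) * p_log
        if gap < tol:
            return p_log, it + 1
    return p_log, max_iter

p_star, n_iter = clear_price()
p_closed_form = -0.5 * math.log(2.0) / (1.0 + 0.5 * 1.0)
```

Damping matters because output responds to the price with the opposite sign of the demand curve; when the undamped map overshoots, a `damp` below 1 shrinks each step and restores convergence.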


In [8]:

def equilibrium_at_x(x, p: Params):
    k, z, dk, dz = make_grids(p)
    # initialize price log as 0.0 (price level 1)
    p_log = 0.0
    for it in range(p.max_iter_price):
        V = solve_hjb(k, z, x, p_log, p)
        I, gk = compute_policy(V, k, z, p)
        mu = solve_fp(gk, k, z, p)
        p_log_new, P_level, Y = price_from_mu(mu, k, z, x, p)
        # Gap between the price firms took as given and the price their output implies
        gap = float(jnp.abs(p_log_new - p_log))
        p_log = p.damp_price * p_log_new + (1.0 - p.damp_price) * p_log
        if gap < p.tol_price:
            break
    return p_log, V, I, mu



## 7) Minimal run (snapshot at the steady state `x = xbar`)

Uncomment to run on your machine with JAX properly installed. This will iterate to a snapshot equilibrium and return the log price, value function, investment policy, and stationary distribution.


In [9]:

# x0 = params.xbar
# p_log_star, V_star, I_star, mu_star = equilibrium_at_x(x0, params)
# print("log price at equilibrium:", p_log_star)
# # Optional: visualize slices or aggregates
# # import matplotlib.pyplot as plt
# # plt.plot(jnp.linspace(params.k_min, params.k_max, params.n_k), I_star[:, I_star.shape[1]//2])
# # plt.title("Investment policy at median z")
# # plt.show()
