In [19]:
# Two Trees (log-utility) via Deep-FBSDE, with Heun–Stratonovich BSDE loss
# Math anchors: Eq. (21)–(23), (48), (50) in Two Trees; BSDE↔PDE (Pardoux–Peng).
# Citations in accompanying notes.

import jax, jax.numpy as jnp
import equinox as eqx, optax

jax.config.update("jax_enable_x64", True)

# -------------------- Model parameters and share SDE --------------------
class Params(eqx.Module):
    delta: float
    mu1: float; mu2: float
    sig1: float; sig2: float
    rho: float

def eta2(p):  # Var of σ1 dZ1 − σ2 dZ2 with corr ρ
    return p.sig1**2 + p.sig2**2 - 2.0*p.rho*p.sig1*p.sig2

def b_ito(s, p):
    # Itô drift from Eq. (48)
    return s*(1.0 - s)*(p.mu1 - p.mu2 - s*p.sig1**2 + (1.0 - s)*p.sig2**2
                        + 2.0*(s - 0.5)*p.rho*p.sig1*p.sig2)

def sig_s(s, p):                  # scalar diffusion magnitude
    return s*(1.0 - s)*jnp.sqrt(eta2(p))

def b_strat(s, p):               # Itô → Stratonovich: b_S = b_I - 0.5 * Σ G ∂G
    # Here G(s) = s(1-s)[σ1, -σ2] with corr matrix [[1,ρ],[ρ,1]] ⇒ correction = 0.5*η²*s(1-s)(1-2s)
    e2 = eta2(p)
    return b_ito(s, p) - 0.5 * e2 * s*(1.0 - 1.0*s)*(1.0 - 2.0*s)


def heun_strat_step(s, dt, dW, p):
    s = jnp.clip(s, 1e-8, 1.0 - 1e-8)
    bS = b_strat(s, p);  sig = sig_s(s, p)
    s_pred = jnp.clip(s + bS*dt + sig*dW, 1e-8, 1.0 - 1e-8)
    return jnp.clip(s + 0.5*(bS + b_strat(s_pred, p))*dt
                      + 0.5*(sig + sig_s(s_pred, p))*dW, 1e-8, 1.0 - 1.0*1e-8)

# -------------------- Single head with hard boundary conditions --------------------
class Head(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, key, width=32, depth=2):
        self.mlp = eqx.nn.MLP(1, 1, width, depth, key=key, activation=jax.nn.tanh)
    def __call__(self, s, delta):
        # Enforces Y(0)=0 and Y(1)=1/delta
        h = self.mlp(s[..., None]).squeeze(-1)
        return (s + s*(1.0 - s)*h) / delta

# -------------------- Path simulation (antithetic) --------------------
def sim_paths(key, p, n_paths=4096, n_steps=64, T=1.0):
    dt = T / n_steps
    k = jax.random.split(key, 1)[0]
    dW = jax.random.normal(k, (n_paths, n_steps)) * jnp.sqrt(dt)
    dW = jnp.concatenate([dW, -dW], axis=0)  # antithetic → shape (2*n_paths, n_steps)
    s0 = jnp.full((dW.shape[0],), 0.5)
    def step(carry, dw):
        new_s = heun_strat_step(carry, dt, dw, p)
        return new_s, new_s  # emit state to collect full path
    last_s, s_hist = jax.lax.scan(step, s0, dW.T)   # s_hist: [steps, paths]
    S = s_hist.T                                    # [paths, steps]
    return S, dW, dt

# Helpers that robustly handle 0D/1D/2D inputs by vectorizing along axis 0 when present

def _flatten_vmap_reshape(fun, X):
    """Apply fun elementwise preserving (B,T) layout; supports 0D/1D/2D.
    fun: maps scalar→scalar. X: () | (B,) | (B,T). Returns same shape as X.
    """
    ndim = jnp.ndim(X)
    if ndim == 0:
        return fun(X)
    if ndim == 1:
        return jax.vmap(fun)(X)
    if ndim == 2:
        B, T = X.shape
        Y = jax.vmap(fun)(X.reshape(B*T,))
        return Y.reshape(B, T)
    raise ValueError(f"_flatten_vmap_reshape: unsupported ndim {ndim} for shape {jnp.shape(X)}")

def f_apply(model, p, x):
    f = lambda z: model(z, p.delta)
    return _flatten_vmap_reshape(f, x)

def fprime_batch(model, p, X):
    g = jax.grad(lambda z: model(z, p.delta))
    return _flatten_vmap_reshape(g, X)

# -------------------- Stratonovich trapezoid BSDE loss --------------------
@eqx.filter_value_and_grad
def bsde_loss(model, key, p, n_paths=4096, n_steps=64, T=1.0):
    S, dW, dt = sim_paths(key, p, n_paths, n_steps, T)
    # All below should be (paths, steps)
    Y    = f_apply(model, p, S)
    fp   = fprime_batch(model, p, S)
    Z    = fp * sig_s(S, p)

    # shape-safe "next in time"
    S_n  = jnp.concatenate([S[..., 1:], S[..., -1:]], axis=-1)
    Y_n  = jnp.concatenate([Y[..., 1:], Y[..., -1:]], axis=-1)
    fp_n = fprime_batch(model, p, S_n)
    Z_n  = fp_n * sig_s(S_n, p)

    # Stratonovich trapezoid one-step: average drift/diffusion; stop-grad on next step
    drift = 0.5*((p.delta*(Y + jax.lax.stop_gradient(Y_n))) - (S + jax.lax.stop_gradient(S_n))) * dt
    diff  = 0.5*(Z + jax.lax.stop_gradient(Z_n)) * dW
    resid = jax.lax.stop_gradient(Y_n) - (Y + drift + diff)

    # Control variate on ΔW (OLS slope; stop-grad to avoid chasing noise)
    beta  = (jnp.mean(resid*dW) / (jnp.mean(dW**2) + 1e-16))
    resid = resid - jax.lax.stop_gradient(beta)*dW
    return jnp.mean(resid**2)

def make_opt(): return optax.chain(optax.clip(1.0), optax.adam(3e-3))

@eqx.filter_jit
def train_step(model, opt, opt_state, key, p, n_paths, n_steps):
    val, grads = bsde_loss(model, key, p, n_paths, n_steps)
    updates, opt_state = opt.update(grads, opt_state, eqx.filter(model, eqx.is_inexact_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, val

# -------------------- Independent acceptance checks --------------------

def generator_residual(model, p, grid):
    f   = lambda s: model(s, p.delta)
    fp  = jax.vmap(jax.grad(f))(grid)
    fpp = jax.vmap(lambda s: jax.grad(lambda u: jax.grad(f)(u))(s))(grid)
    b   = jax.vmap(lambda s: b_ito(s, p))(grid)
    v   = jax.vmap(lambda s: sig_s(s, p)**2)(grid)
    return -grid + p.delta*f_apply(model, p, grid) - b*fp - 0.5*v*fpp

def fd_bvp_solution(p, N=400):
    s = jnp.linspace(0.0, 1.0, N+1).at[0].set(1e-12).at[-1].set(1.0-1e-12)
    h = s[1]-s[0]
    b = b_ito(s, p); v = sig_s(s, p)**2
    a_lo = -0.5*v[1:-1]/h**2 + b[1:-1]/(2*h)
    a_di =  p.delta*jnp.ones_like(s[1:-1]) + v[1:-1]/h**2
    a_up = -0.5*v[1:-1]/h**2 - b[1:-1]/(2*h)
    rhs  = s[1:-1]
    rhs  = rhs.at[-1].add(-a_up[-1]*(1.0/p.delta))  # Y(1)=1/δ
    from jax.lax.linalg import tridiagonal_solve
    rhs2 = rhs[:, None]  # make rhs rank-2 for current JAX API
    f_inner = tridiagonal_solve(a_lo, a_di, a_up, rhs2)[:, 0]
    f = jnp.concatenate([jnp.array([0.0]), f_inner, jnp.array([1.0/p.delta])])
    return s, f

def f_closed_form_sym(s, delta):
    # Eq. (22) holds in the symmetric case with δ=σ^2
    V = (1.0/(2.0*delta*s))*(1.0 + ((1.0 - s)/s)*jnp.log1p(-s) - (s/(1.0 - s))*jnp.log(s))
    return s*V

 

In [20]:
# -------------------- Tiny driver (example: symmetric special case) --------------------
p = Params(delta=0.04, mu1=0.02, mu2=0.02, sig1=0.2, sig2=0.2, rho=0.0)  # δ=σ^2 for Eq.(22)
model = Head(jax.random.PRNGKey(0))
opt   = make_opt(); opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
key = jax.random.PRNGKey(1)
for _ in range(1500):
    key, sub = jax.random.split(key)
    model, opt_state, _ = train_step(model, opt, opt_state, sub, p, n_paths=4096, n_steps=64)

# Acceptance checks
grid = jnp.linspace(1e-6, 1.0-1e-6, 401)
A_res_inf = jnp.max(jnp.abs(generator_residual(model, p, grid)))
s_fd, f_fd = fd_bvp_solution(p)
f_net = f_apply(model, p, s_fd)
B_bvp_err = jnp.max(jnp.abs(f_net - f_fd))
f_cf  = f_closed_form_sym(s_fd, p.delta)
C_cf_err = jnp.max(jnp.abs(f_net - f_cf))
D_range_ok = (jnp.min(f_net) >= -1e-6) & (jnp.max(f_net) <= 1.0/p.delta + 1e-6)
sym_err = jnp.max(jnp.abs(f_apply(model, p, s_fd) + f_apply(model, p, 1.0 - s_fd) - 1.0/p.delta))
print({"A_res_inf": float(A_res_inf),
       "B_bvp_err": float(B_bvp_err),
       "C_cf_err": float(C_cf_err),
       "D_range_ok": bool(D_range_ok),
       "sym_err": float(sym_err)})

{'A_res_inf': 0.1954662771478498, 'B_bvp_err': 2.569036144584988, 'C_cf_err': 2.569056022316385, 'D_range_ok': True, 'sym_err': 1.1751915472856815}


In [17]:
# Quick shape diagnostics (small sizes)
pp = Params(delta=0.04, mu1=0.02, mu2=0.02, sig1=0.2, sig2=0.2, rho=0.0)
S, dW, dt = sim_paths(jax.random.PRNGKey(123), pp, n_paths=4, n_steps=5, T=1.0)
Y    = f_apply(model, pp, S)
fp   = fprime_batch(model, pp, S)
Z    = fp * sig_s(S, pp)
S_n  = jnp.concatenate([S[..., 1:], S[..., -1:]], axis=-1)
Y_n  = jnp.concatenate([Y[..., 1:], Y[..., -1:]], axis=-1)
fp_n = fprime_batch(model, pp, S_n)
Z_n  = fp_n * sig_s(S_n, pp)
print({
    'S': S.shape,
    'dW': dW.shape,
    'Y': Y.shape,
    'fp': fp.shape,
    'Z': Z.shape,
    'S_n': S_n.shape,
    'Y_n': Y_n.shape,
    'Z_n': Z_n.shape,
})

{'S': (8, 5), 'dW': (8, 5), 'Y': (8, 5), 'fp': (8, 5), 'Z': (8, 5), 'S_n': (8, 5), 'Y_n': (8, 5), 'Z_n': (8, 5)}
