# Train MacroFinanceNet — Demo (FAST)

This demo runs a short training loop (200 steps on a small batch over a short simulation horizon in FAST mode) to show the loss decreasing, then prints the network's q at the symmetric state.

In [None]:
import os, time, json
import numpy as np
import jax, jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
from bsde_dsgE.models.macro_solver import Config as NetCfg, MacroFinanceNet, evaluate_symmetric
from bsde_dsgE.models.probab01_equations import Config as EqCfg, compute_dynamics, q_symmetric_analytic
FAST = True

# Problem size and simulation grid; FAST shrinks the horizon (`steps`) and the
# number of simulated paths per batch (`paths`) so the demo runs quickly.
J = 5
if FAST:
    steps, paths = 8, 512
else:
    steps, paths = 64, 4096
dt = 0.001

# The network and the equation configs are built for the same J.
net_cfg = NetCfg(J=J)
eq_cfg = EqCfg(J=J)

# Fixed seed so the demo is reproducible.
key = jax.random.PRNGKey(0)
model = MacroFinanceNet(net_cfg, key)

# Adam optimizer; its state is initialised over the model's floating-point
# array leaves only (non-array leaves are filtered out).
opt = optax.adam(1e-4)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
def batch_loss(model, key):
    """Return the mean one-step residual loss of ``model`` on simulated paths.

    Samples ``paths`` initial states, then simulates them forward for ``steps``
    Euler–Maruyama steps of size ``dt``, using the drift and volatility from
    ``compute_dynamics``. At every step it penalises the squared gap between:

    - the network's prediction ``q`` at the new state, and
    - the one-step update ``q - h*dt + Z·dW`` of the prediction at the
      current state.

    Args:
        model: callable mapping a state batch ``Om`` to ``(q, s, r)``.
        key: JAX PRNG key, split into keys for eta, z and the Brownian
            increments.

    Returns:
        Scalar loss. Per step, the squared error is summed over q components
        and averaged over the batch; the result is then averaged over steps.
    """
    B = paths
    key_eta, key_dir, key_dW = jax.random.split(key, 3)
    eta = jax.random.uniform(key_eta, (B, eq_cfg.N_ETA), minval=0.2, maxval=0.8)
    # NOTE(review): normalising i.i.d. uniforms lands on the simplex but is not
    # uniform on it (that would be Dirichlet(1, ..., 1), e.g.
    # jax.random.dirichlet). Confirm this sampling distribution is intended.
    raw = jax.random.uniform(key_dir, (B, eq_cfg.J))
    z_full = raw / jnp.sum(raw, axis=1, keepdims=True)
    # Only the first N_ZETA shares enter the state (presumably the remainder is
    # implied by the sum-to-one constraint -- TODO confirm).
    z = z_full[:, : eq_cfg.N_ZETA]
    Om0 = jnp.hstack([eta, z])
    dWs = jax.random.normal(key_dW, (steps, B, eq_cfg.J)) * jnp.sqrt(dt)

    def scan_fn(carry, dW_i):
        # The carry holds the network output at the current state. Each step
        # therefore needs one forward pass (at the new state) instead of two;
        # values and gradients are identical to re-evaluating model(Om).
        Om, q, s, r = carry
        drift, vol, h, Z = compute_dynamics(eq_cfg, Om, q, s, r)
        Om1 = Om + jnp.einsum('bij,bi->bj', vol, dW_i) + drift * dt
        q1 = q - h * dt + jnp.einsum('bij,bi->bj', Z, dW_i)
        qh, sh, rh = model(Om1)
        step_loss = jnp.mean(jnp.sum((qh - q1) ** 2, axis=1))
        return (Om1, qh, sh, rh), step_loss

    q0, s0, r0 = model(Om0)
    _, losses = jax.lax.scan(scan_fn, (Om0, q0, s0, r0), dWs)
    return jnp.mean(losses)
@eqx.filter_jit
def train_step(model, opt_state, key):
    """Take one ``opt`` step on ``batch_loss``; return ``(model, opt_state, loss)``.

    Gradients come from ``eqx.filter_value_and_grad``, which differentiates
    with respect to the model's floating-point array leaves. The matching
    filtered tree (the same filter used when ``opt_state`` was initialised) is
    passed as ``params`` to ``opt.update``. Optimizers that read ``params``
    therefore see a tree consistent with their state and with the gradients,
    rather than the full module with its non-array leaves.
    """
    loss, grads = eqx.filter_value_and_grad(batch_loss)(model, key)
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = opt.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
# Training schedule: FAST runs fewer epochs and logs more often.
n_epochs = 200 if FAST else 1000
log_every = 50 if FAST else 100
loss_hist = []

# Closed-form symmetric-state q, printed for reference.
# NOTE(review): a, psi, rho are hard-coded here; confirm they match eq_cfg.
print({'analytic_q': q_symmetric_analytic(a=0.1, psi=5.0, rho=0.03)})

for ep in range(1, n_epochs + 1):
    # Fresh PRNG key for every training step.
    key, k = jax.random.split(key)
    model, opt_state, loss = train_step(model, opt_state, k)
    loss_value = float(loss)
    loss_hist.append(loss_value)
    if ep % log_every == 0:
        print({'ep': ep, 'loss': loss_value})

# Evaluate the trained network at a few symmetric states.
etas = (0.3, 0.4, 0.5, 0.6, 0.7)
q, s, r = evaluate_symmetric(net_cfg, model, etas)
print('q@symmetric (demo):', np.array(q))

# Plot loss history
epoch_axis = range(1, len(loss_hist) + 1)
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(epoch_axis, loss_hist)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Training loss (demo)')
fig.tight_layout()
fig  # last expression of the cell: the notebook displays the figure
