# Multicountry Model — Probab_01 (Table 1)

This notebook wires a multicountry (vector) continuous-time model to the BSDE solver.
It is a template for reproducing Table 1 of `Probab_01.pdf`; supply the paper's calibration
(read from `data/probab01_table1.json` below) to match the reported results.


In [None]:
# Visual style: accessible color cycle, high-DPI, subtle grids
try:
    from bsde_dsgE.utils.nb_style import apply_notebook_style
    apply_notebook_style()
except Exception as e:
    print('Style setup skipped:', e)


## Math → Code map

Equations (19)–(22) from `Tex/Model.tex` define the macro BSDE: (19) analytic symmetric-state `q = (a·ψ+1)/(ρ·ψ+1)`, (20) BSDE driver `h`, (21) η drift/vol, and (22) ζ drift/vol. The code mapping lives in `bsde_dsgE/models/probab01_equations.py::compute_dynamics`, and the symmetric analytic check is available via `q_symmetric_analytic(a, ψ, ρ)`.
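
For quick reference, eq. (19) can be evaluated directly. The function below is a hypothetical stand-in for `q_symmetric_analytic` (the real helper's signature and validation may differ), and the parameter values are illustrative, not the paper's calibration.

```python
def q_symmetric_sketch(a: float, psi: float, rho: float) -> float:
    """Eq. (19) in the symmetric state: q = (a*psi + 1) / (rho*psi + 1).

    Hypothetical stand-in for probab01_equations.q_symmetric_analytic.
    """
    denom = rho * psi + 1.0
    if denom == 0.0:
        raise ValueError("rho*psi + 1 must be non-zero")
    return (a * psi + 1.0) / denom


# Illustrative values only (not the paper's calibration)
q = q_symmetric_sketch(a=0.1, psi=5.0, rho=0.05)
print(round(q, 4))  # → 1.2
```

A useful sanity property: when `a` equals `ρ`, the formula gives `q = 1` for any `ψ`.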

For the forward-path visuals (mean ± 2SE, rolling correlations), we simulate a simple vector SDE separately from BSDE training (see `bsde_dsgE/models/multicountry.py`). The deep solver and its driver use the paper's `h`, drifts, and diffusions; the symmetric-state `q` and `σ_q` values come from the paper's Table 1 (parsed by the JSON/TeX helpers). Terminal conditions for the deep solver follow the paper's scheme; the backward-Euler regression variant requires no explicit terminal condition.
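
A backward-Euler regression scheme can be illustrated on a toy problem. The sketch below is not the notebook's deep solver and does not use the paper's driver `h`: it assumes a scalar state `X_t = x0 + σ W_t`, terminal value `Y_T = X_T`, and driver `f(y) = -r·y`, so the exact answer `Y_t = e^{-r(T-t)} X_t` is known. Conditional expectations are estimated by least-squares regression on the basis {1, x}, a least-squares Monte Carlo choice made here purely for illustration.

```python
import numpy as np


def lsmc_backward_euler(x0, sigma, r, T=1.0, n_steps=50, n_paths=20_000, seed=0):
    """Explicit backward-Euler BSDE scheme with least-squares regression.

    Toy problem (assumption): X_t = x0 + sigma * W_t, terminal Y_T = X_T,
    driver f(y) = -r * y, so the exact solution is Y_t = exp(-r (T - t)) X_t.
    """
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    dW = rng.normal(0.0, np.sqrt(dt), size=(n_steps, n_paths))
    X = np.empty((n_steps + 1, n_paths))
    X[0] = x0
    X[1:] = x0 + sigma * np.cumsum(dW, axis=0)

    def cond_exp(target, x):
        # Project target onto span{1, x}: a linear-basis estimate of E[target | X_n = x]
        basis = np.column_stack([np.ones_like(x), x])
        coef, *_ = np.linalg.lstsq(basis, target, rcond=None)
        return basis @ coef

    Y = X[-1].copy()  # terminal condition g(X_T) = X_T
    Z = np.empty((n_steps, n_paths))
    for n in range(n_steps - 1, -1, -1):
        Z[n] = cond_exp(Y * dW[n] / dt, X[n])   # Z_n = E_n[Y_{n+1} dW_n] / dt
        Y = cond_exp(Y - r * Y * dt, X[n])      # Y_n = E_n[Y_{n+1} + f(Y_{n+1}) dt]
    return float(Y.mean()), Z


y0, Z = lsmc_backward_euler(x0=1.0, sigma=0.3, r=0.05)
print(y0, np.exp(-0.05))  # estimate vs exact Y_0
```

Because the regression basis contains a constant, each projection preserves the sample mean, so `y0` equals `(1 - r·dt)^N` times the sample mean of `X_T`; the remaining error is plain Monte Carlo noise.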

### Equations (19)–(22) linkage

The deep-solver variant follows the equations in `Try.md` and the equation labels in `Tex/Model.tex`, numbered (19)–(22) as above; the code mapping is in `bsde_dsgE/models/probab01_equations.py`. To compare against the symmetric-state table parsed directly from LaTeX, run `scripts/compare_table1_solver.py --from-tex`.

In [None]:
import os, math
import jax, jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from bsde_dsgE.models.multicountry import multicountry_probab01
from bsde_dsgE.core import load_solver_nd
from bsde_dsgE.utils.calibration import load_probab01_calibration
from bsde_dsgE.metrics.table1 import summary_stats, compare_to_targets

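# Fast mode for CI/smoke runs: any value except '', '0', 'false', 'no' enables it
FAST = os.environ.get('NOTEBOOK_FAST', '').strip().lower() not in ('', '0', 'false', 'no')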
key = jax.random.PRNGKey(0)
print({'FAST': FAST, 'jax': jax.__version__, 'device': jax.default_backend()})


In [None]:
# Parameters — load from calibration file if available, else defaults
calib_path = Path('data/probab01_table1.json')
if calib_path.exists():
    calib = load_probab01_calibration(calib_path)
    dim, rho, gamma, kappa, theta, sigma = (calib.dim, calib.rho, calib.gamma, calib.kappa, calib.theta, calib.sigma)
else:
    dim = 2  # countries
    rho = 0.05
    gamma = 8.0
    kappa = 0.2
    theta = 1.0
    sigma = 0.3
t0, t1 = 0.0, 1.0
dt = 0.05 if FAST else 0.02
batch = 16 if FAST else 64
depth, width = (2, 32) if FAST else (4, 128)

# Build the problem and the solver; calling the solver on a batch returns the training loss
prob = multicountry_probab01(dim=dim, rho=rho, gamma=gamma, kappa=kappa, theta=theta, sigma=sigma, t0=t0, t1=t1)
solver = load_solver_nd(prob, dim=dim, dt=dt, depth=depth, width=width)
x0 = jnp.ones((batch, dim)) * theta  # start every path at theta
loss = solver(x0, key)
float(loss)


## Diffusion heatmap

Visualise the diffusion matrix in use: `σ·I` for a scalar `σ`, `diag(σ)` for per-country values, or the full `Σ` if the calibration provides one.
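
If a calibration supplied per-country volatilities and a correlation matrix rather than a full `Σ`, one way to build a full diffusion matrix is the Cholesky factorization sketched below. The input format is an assumption; the notebook only shows how `σ` is displayed, not how the model consumes a full `Σ`.

```python
import numpy as np


def diffusion_from_vol_corr(vols, corr):
    """Build a full diffusion matrix Sigma = diag(vols) @ L with L @ L.T == corr.

    Assumption: per-country vols and a correlation matrix are available;
    the notebook itself only displays a diagonal sigma.
    """
    vols = np.asarray(vols, dtype=float)
    corr = np.asarray(corr, dtype=float)
    L = np.linalg.cholesky(corr)  # requires corr to be symmetric positive definite
    return np.diag(vols) @ L


Sigma = diffusion_from_vol_corr([0.3, 0.2], [[1.0, 0.5], [0.5, 1.0]])
cov = Sigma @ Sigma.T  # instantaneous covariance of the shocks: vol_i * vol_j * corr_ij
print(cov)
```

Any `Σ` with the same `Σ Σᵀ` produces the same shock covariance, so the Cholesky factor is one convenient choice, not the only one.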

In [None]:
# Build the diffusion matrix for display: scalar σ -> σ·I, vector σ -> diag(σ), matrix -> full Σ
if sigma is None:
    Sigma_disp = np.eye(dim)
else:
    sig = np.asarray(sigma, dtype=float)
    if sig.ndim == 0:
        Sigma_disp = float(sig) * np.eye(dim)
    elif sig.ndim == 1:
        Sigma_disp = np.diag(sig)
    else:
        Sigma_disp = sig  # full Σ supplied by the calibration
figH, axH = plt.subplots(figsize=(3.5, 3))
im = axH.imshow(Sigma_disp, cmap='viridis', aspect='equal')
cb = figH.colorbar(im, ax=axH); cb.set_label('Σ entry')
axH.set_title('Diffusion matrix Σ')
axH.set_xlabel('dim'); axH.set_ylabel('dim')
figH.tight_layout(); figH


## Sample paths and diagnostics

Below we simulate the forward state over [t0, t1] with `prob.step` (the `BSDEProblem.step` method)
to visualise its evolution on four sample paths. This is for intuition only; the solver
itself learns the backward components (Y, Z).
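
For intuition about what a single `step` call might do, here is a NumPy sketch of an Euler-Maruyama step for a mean-reverting vector SDE. The drift form `κ(θ - x)` with diagonal `σ` is an assumption suggested by the `kappa`, `theta`, `sigma` parameters and the θ reference line in the IRF plot; the actual `multicountry_probab01` dynamics may differ.

```python
import numpy as np


def ou_step(x, dt, dW, kappa, theta, sigma):
    """One Euler-Maruyama step of dX = kappa (theta - X) dt + sigma dW (assumed form)."""
    return x + kappa * (theta - x) * dt + sigma * dW


def simulate(x0, n_steps, dt, kappa=0.2, theta=1.0, sigma=0.3, seed=0):
    """Simulate paths from x0 of shape (paths, dim); returns (n_steps + 1, paths, dim)."""
    rng = np.random.default_rng(seed)
    xs = [np.asarray(x0, dtype=float)]
    for _ in range(n_steps):
        dW = rng.normal(0.0, np.sqrt(dt), size=xs[-1].shape)  # Brownian increments
        xs.append(ou_step(xs[-1], dt, dW, kappa, theta, sigma))
    return np.stack(xs)


paths = simulate(np.ones((4, 2)), n_steps=50, dt=0.02)
print(paths.shape)  # → (51, 4, 2)
```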

In [None]:
# Simulate the forward state on [t0, t1] to visualise sample paths
N = int(round((t1 - t0) / dt))  # number of Euler steps; round() guards against float truncation
x = x0[:4]  # 4 sample paths
ts = [t0]
xs = [np.array(x)]
for i in range(N):
    t = t0 + i*dt
    dW = jax.random.normal(jax.random.fold_in(key, i), x.shape) * math.sqrt(dt)
    x = prob.step(x, t, dt, dW)
    ts.append(t+dt); xs.append(np.array(x))
xs = np.stack(xs, axis=0)  # (N+1, paths, dim)
fig, ax = plt.subplots(1, dim, figsize=(4*dim, 3), sharex=True, sharey=True)
ax = np.atleast_1d(ax)
for d in range(dim):
    for p in range(xs.shape[1]):
        ax[d].plot(ts, xs[:, p, d], alpha=0.7)
    ax[d].set_title(f'Country {d+1}')
    ax[d].set_xlabel('t'); ax[d].set_ylabel('state')
fig.suptitle('State sample paths')
fig.tight_layout()
fig


## Impulse responses (deterministic)

With the noise switched off, the paths show the deterministic mean-reversion dynamics after a one-time shock of 0.2 to the first country at t = 0; the dashed line marks θ.
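
Under the assumed linear drift `κ(θ - x)`, the deterministic Euler response has a closed form: each step scales the deviation from θ by `(1 - κ·dt)`, which tracks the continuous-time decay `e^{-κt}`. The sketch compares the two and reports the half-life `ln 2 / κ`; the parameter values mirror the notebook's fallback defaults.

```python
import numpy as np


def deterministic_irf(shock, kappa, dt, n_steps):
    """Euler IRF deviation from theta with noise off, assuming drift kappa * (theta - x).

    Each Euler step multiplies the deviation by (1 - kappa * dt).
    """
    steps = np.arange(n_steps + 1)
    return shock * (1.0 - kappa * dt) ** steps


kappa, dt, n_steps, shock = 0.2, 0.02, 50, 0.2
t = dt * np.arange(n_steps + 1)
irf_euler = deterministic_irf(shock, kappa, dt, n_steps)
irf_exact = shock * np.exp(-kappa * t)   # continuous-time response
half_life = np.log(2.0) / kappa          # time for the deviation to halve
print(float(np.max(np.abs(irf_euler - irf_exact))), half_life)
```

Under this assumption, with `κ = 0.2` and `t1 = 1` the half-life is about 3.5, so over the plotted horizon the shock decays by only about 18% (e^{-0.2} ≈ 0.82); the IRF plot should not be expected to return to θ.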

In [None]:
shock = jnp.zeros((4, dim)); shock = shock.at[:,0].set(0.2)  # shock first country
x_irf = x0[:4] + shock
ts_irf = [t0]; xs_irf = [np.array(x_irf)]
for i in range(N):
    t = t0 + i*dt
    dW0 = jnp.zeros_like(x_irf)
    x_irf = prob.step(x_irf, t, dt, dW0)
    ts_irf.append(t+dt); xs_irf.append(np.array(x_irf))
xs_irf = np.stack(xs_irf, axis=0)
figI, axI = plt.subplots(1, dim, figsize=(4*dim, 3), sharex=True, sharey=True)
axI = np.atleast_1d(axI)
for d in range(dim):
    for p in range(xs_irf.shape[1]):
        axI[d].plot(ts_irf, xs_irf[:, p, d], alpha=0.8)
    axI[d].axhline(float(theta), color='k', lw=1, ls='--', alpha=0.5)
    axI[d].set_title(f'IRF: Country {d+1}')
figI.suptitle('Impulse responses (noise off)'); figI.tight_layout(); figI


## Animation (optional)

In [None]:
# Matplotlib animation of evolving states (display as HTML in VS Code/Jupyter)
from matplotlib import animation

fig2, ax2 = plt.subplots(1, dim, figsize=(4*dim, 3), sharex=True, sharey=True)
ax2 = np.atleast_1d(ax2)
lines = [[ax2[d].plot([], [], lw=2)[0] for _ in range(xs.shape[1])] for d in range(dim)]
for d in range(dim):
    ax2[d].set_xlim(ts[0], ts[-1])
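    lo, hi = float(xs.min()), float(xs.max()); pad = 0.1 * (hi - lo) or 0.1
    ax2[d].set_ylim(lo - pad, hi + pad)  # pad by the data range so positive minima are not clipped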
    ax2[d].set_title(f'Country {d+1}')

def init():
    for d in range(dim):
        for p in range(xs.shape[1]):
            lines[d][p].set_data([], [])
    return sum(lines, [])

def animate(i):
    for d in range(dim):
        for p in range(xs.shape[1]):
            lines[d][p].set_data(ts[:i+1], xs[:i+1, p, d])
    return sum(lines, [])

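from IPython.display import HTML
ani = animation.FuncAnimation(fig2, animate, init_func=init, frames=N+1, interval=100, blit=True)
plt.close(fig2)  # avoid a duplicate static figure
HTML(ani.to_jshtml())  # render the JS animation; a bare string would display as text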


## Table 1 reproduction hooks

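The cell below computes summary statistics of the simulated states and, when the calibration file provides Table 1 targets, compares them.
Extend it with the actual Table 1 moments and an assert-allclose check once those values are available.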
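
A hypothetical helper for the assert-allclose step might look like the sketch below. It assumes both the computed statistics and the Table 1 targets are dicts mapping a statistic name to an array-like; the tolerances are placeholders, and the notebook's own `compare_to_targets` may work differently.

```python
import numpy as np


def check_against_targets(stats, targets, rtol=1e-2, atol=1e-3):
    """Compare computed moments to targets key by key (hypothetical helper).

    Returns (all_ok, report) where report maps each target name to the max
    absolute difference, or None if the statistic is missing or mis-shaped.
    """
    report, all_ok = {}, True
    for name, target in targets.items():
        want = np.asarray(target, dtype=float)
        got = np.asarray(stats[name], dtype=float) if name in stats else None
        if got is None or got.shape != want.shape:
            report[name] = None  # missing statistic or shape mismatch
            all_ok = False
            continue
        report[name] = float(np.max(np.abs(got - want)))
        all_ok = all_ok and bool(np.allclose(got, want, rtol=rtol, atol=atol))
    return all_ok, report


ok, report = check_against_targets({'mean': [1.0, 1.0]}, {'mean': [1.001, 0.999]})
print(ok, report)
```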

In [None]:
# Compute summary statistics and optionally compare to targets
stats = summary_stats(xs)
print({k: (v.tolist() if hasattr(v, 'tolist') else v) for k, v in stats.items() if k in ('mean','std')})

# If Table 1 targets are provided in calibration, compare and highlight diffs
if calib_path.exists() and calib.table1_targets and 'examples' in calib.table1_targets:
    targets = calib.table1_targets['examples']
    res = compare_to_targets(stats, targets)
    print('Comparison:', res)
    # TODO: extend to the actual Table 1 statistics once provided


In [None]:
# Optional strict check for CI: set STRICT_TABLE1=1 to enforce passing
strict = os.environ.get('STRICT_TABLE1', '').strip().lower() in ('1', 'true', 'yes')
if strict and calib_path.exists() and calib.table1_targets and 'examples' in calib.table1_targets:
    res = compare_to_targets(stats, calib.table1_targets['examples'])
    assert res.get('all_ok', False), f'Table 1 comparison failed: {res}'


## Paper Table 1 (symmetric states)

We load the transcribed Table 1 values and, for each symmetric state (η, ζ), print `q_i` and plot `σ_{q,i,j}` as a heatmap.
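
Before plotting, it can help to validate the transcribed entries. The sketch assumes each state carries `eta`, `zeta`, a per-country `q` vector and a `J × J` `sigma_q` matrix, mirroring the keys the cell reads; the example numbers are made up.

```python
import numpy as np


def validate_symmetric_states(states):
    """Check each entry has eta, zeta, a length-J q and a J x J sigma_q (sketch)."""
    problems = []
    for k, st in enumerate(states):
        missing = [f for f in ('eta', 'zeta', 'q', 'sigma_q') if f not in st]
        if missing:
            problems.append(f'state {k}: missing {missing}')
            continue
        q = np.asarray(st['q'], dtype=float).ravel()
        S = np.asarray(st['sigma_q'], dtype=float)
        if S.shape != (q.size, q.size):
            problems.append(f'state {k}: sigma_q shape {S.shape} != ({q.size}, {q.size})')
    return problems


example = [{'eta': 0.5, 'zeta': [0.5, 0.5], 'q': [1.1, 1.1],
            'sigma_q': [[0.01, -0.01], [-0.01, 0.01]]}]  # made-up numbers
print(validate_symmetric_states(example))  # → []
```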

In [None]:
import json
cal = json.loads(calib_path.read_text()) if calib_path.exists() else {}
paper = cal.get('table1_values', {})
sym = paper.get('symmetric_states', [])
# Optional country names for J (sigma_q heatmaps)
namesJ = cal.get('country_names_J') or cal.get('probab01_params', {}).get('country_names', [])
for st in sym:
    print('eta=', st['eta'], ', zeta=', st['zeta'])
    print('q:', st['q'])
    # show sigma heatmap (diverging cmap, symmetric about zero)
    mat = np.array(st['sigma_q'], dtype=float)
    mabs = float(np.max(np.abs(mat))) if mat.size else 0.0
    mabs = mabs or 1.0  # avoid a degenerate colour scale when all entries are zero
    figS, axS = plt.subplots(figsize=(3.6,3.2))
    im = axS.imshow(mat, cmap='coolwarm', vmin=-mabs, vmax=mabs)
    axS.set_aspect('equal')
    cb = figS.colorbar(im, ax=axS); cb.set_label(r'$\sigma_q$')
    axS.set_title(r'$\sigma_{q,i,j}$ (symmetric scale)')
    axS.set_xlabel('j'); axS.set_ylabel('i')
    axS.set_xticks(range(len(mat))); axS.set_yticks(range(len(mat)))
    labels = namesJ if isinstance(namesJ, list) and len(namesJ)==len(mat) else [f'C{j+1}' for j in range(len(mat))]
    axS.set_xticklabels(labels); axS.set_yticklabels(labels)
    figS.tight_layout()
    plt.show()


## Mean ± 2SE (by country)

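Naive error bars for the per-dimension sample mean: ±2 standard errors computed by pooling all time points and paths as if they were independent draws. This ignores serial correlation, so the bands are typically too narrow; treat them as a quick diagnostic, not an inference statement.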
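
A less optimistic alternative, swapped in here for illustration, relies on the independence of sample paths instead: average each path over time, then take the standard error across paths. The sketch below assumes `xs` has shape `(T+1, P, dim)` with independent paths, as in the simulation cell above; the random-walk demo data is synthetic.

```python
import numpy as np


def mean_2se_across_paths(xs):
    """Mean and 2*SE per dimension using one time-average per independent path.

    xs has shape (T+1, P, dim); requires P >= 2.
    """
    path_means = xs.mean(axis=0)                  # (P, dim): one value per path
    m = path_means.mean(axis=0)                   # (dim,)
    se = path_means.std(axis=0, ddof=1) / np.sqrt(path_means.shape[0])
    return m, 2.0 * se


rng = np.random.default_rng(0)
demo = rng.normal(size=(51, 64, 2)).cumsum(axis=0)  # strongly autocorrelated toy paths
m, e2 = mean_2se_across_paths(demo)
naive = 2.0 * demo.reshape(-1, 2).std(axis=0, ddof=1) / np.sqrt(demo.shape[0] * demo.shape[1])
print(e2, naive)  # the pooled (naive) band is typically much narrower here
```

With only four paths, as in the earlier simulation cell, the across-path band is wide and itself noisy; it becomes informative with a larger batch.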

In [None]:
# numpy is already imported as np in the setup cell
flat = xs.reshape(xs.shape[0]*xs.shape[1], xs.shape[2])
m = flat.mean(axis=0)
s = flat.std(axis=0, ddof=1)
se = s / np.sqrt(flat.shape[0])
# Optional names for dim from calibration
names_dim = cal.get('country_names_dim') or cal.get('country_names', [])
labels = names_dim if isinstance(names_dim, list) and len(names_dim)==dim else [f'C{j+1}' for j in range(dim)]
if dim > 6:
    figM, axM = plt.subplots(figsize=(2.2*max(6, dim/1.2), 0.35*dim + 2.5))
    ypos = np.arange(dim)
    bars = axM.barh(ypos, m, xerr=2*se, alpha=0.85, error_kw={'capsize':3})
    axM.set_yticks(ypos); axM.set_yticklabels(labels)
    axM.set_title('Mean ± 2SE by country')
    axM.set_xlabel('value'); axM.set_ylabel('country')
    xmax = float(np.max(m + 2*se)) if m.size else 1.0
    for i, b in enumerate(bars):
        x = b.get_width(); axM.annotate(f'{m[i]:.2f} ± {2*se[i]:.2f}',
            xy=(x, b.get_y() + b.get_height()/2), xytext=(4, 0), textcoords='offset points',
            ha='left', va='center', fontsize=9)
    axM.set_xlim(right=1.1*xmax)
else:
    figM, axM = plt.subplots(figsize=(1.2*max(6, dim), 3.0))
    xpos = np.arange(dim)
    bars = axM.bar(xpos, m, yerr=2*se, alpha=0.85, error_kw={'capsize':3})
    axM.set_title('Mean ± 2SE by country')
    axM.set_xlabel('country'); axM.set_ylabel('value')
    axM.set_xticks(xpos); axM.set_xticklabels(labels)
    ymax = float(np.max(m + 2*se)) if m.size else 1.0
    for i, b in enumerate(bars):
        y = b.get_height(); axM.annotate(f'{m[i]:.2f} ± {2*se[i]:.2f}',
            xy=(b.get_x() + b.get_width()/2, y), xytext=(0, 4), textcoords='offset points',
            ha='center', va='bottom', fontsize=9)
    axM.set_ylim(top=1.1*ymax)
figM.tight_layout(); plt.show()


## Rolling Correlation Heatmap (animation)

Compute a rolling correlation of states across countries and animate it over time windows to visualise co-movement dynamics.
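
A plain NumPy version of the sliding-window correlation is sketched below for reference. It pools paths within each window, as `_corr_from_block` does, and returns one matrix per window, i.e. `T - W + 1` windows for `T` time points; it is not the `rolling_corr` helper from `bsde_dsgE.utils.figures`, whose conventions may differ. The demo data is synthetic.

```python
import numpy as np


def rolling_corr_windows(xs, window):
    """Correlation matrices over sliding time windows, pooling paths.

    xs: (T, P, dim). Returns (T - window + 1, dim, dim); window k uses xs[k:k+window].
    """
    T, P, dim = xs.shape
    if not 2 <= window <= T:
        raise ValueError('window must be between 2 and the number of time points')
    out = np.empty((T - window + 1, dim, dim))
    for k in range(T - window + 1):
        block = xs[k:k + window].reshape(-1, dim)   # (window * P, dim)
        out[k] = np.corrcoef(block, rowvar=False)
    return out


rng = np.random.default_rng(1)
common = rng.normal(size=(40, 8, 1))
demo = np.concatenate([common + 0.1 * rng.normal(size=(40, 8, 1)), common], axis=2)
C = rolling_corr_windows(demo, window=10)
print(C.shape)  # → (31, 2, 2)
```

If a state is constant within a window, `np.corrcoef` returns NaN for its entries and emits a runtime warning.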

In [None]:
from matplotlib import animation as _anim
def _corr_from_block(block):
    B = block.reshape(-1, block.shape[-1])
    cov = np.cov(B, rowvar=False)
    sd = np.sqrt(np.clip(np.diag(cov), 1e-12, None))
    denom = np.outer(sd, sd)
    return np.where(denom>0, cov/denom, 0.0)
W = min(20, xs.shape[0]-1)
frames = min(50, xs.shape[0]-W+1) if FAST else xs.shape[0]-W+1  # window i covers xs[i:i+W], so the last start is T-W
figC, axC = plt.subplots(figsize=(3.5,3.0))
img = axC.imshow(_corr_from_block(xs[:W]), vmin=-1, vmax=1, cmap='coolwarm')
axC.set_aspect('equal')
cb = figC.colorbar(img, ax=axC); cb.set_label('corr')
# correlation labels
D = _corr_from_block(xs[:W]).shape[0]
ticks = range(D)
names_dim = cal.get('country_names_dim') or cal.get('country_names', [])
labs = names_dim if isinstance(names_dim, list) and len(names_dim)==D else [f'C{j+1}' for j in ticks]
axC.set_xticks(ticks); axC.set_xticklabels(labs)
axC.set_yticks(ticks); axC.set_yticklabels(labs)
axC.set_title('Rolling corr (window={})'.format(W))
def _init():
    img.set_data(_corr_from_block(xs[:W])); return (img,)
def _animate(i):
    block = xs[i:i+W]
    img.set_data(_corr_from_block(block))
    return (img,)
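from IPython.display import HTML
aniC = _anim.FuncAnimation(figC, _animate, init_func=_init, frames=frames, interval=120, blit=True)
plt.close(figC)  # avoid a duplicate static figure
HTML(aniC.to_jshtml())  # render the JS animation instead of showing the raw HTML string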


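## Diagnostics via utils/figures

This cell uses helper functions from `bsde_dsgE.utils.figures` to simulate paths (optionally with Sobol draws), compute the mean ± 2SE, and compute rolling correlations. Note that it overwrites `xs`, `m`, and `W` from the cells above.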
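
One common way to produce quasi-random Brownian increments is to map scrambled Sobol points through the normal inverse CDF, as sketched below with SciPy's `scipy.stats.qmc.Sobol`; whether `simulate_paths` does it this way is an open assumption, and the function name here is hypothetical.

```python
import numpy as np
from scipy.stats import norm, qmc


def sobol_brownian_increments(n_paths, n_steps, dim, dt):
    """Quasi-random Brownian increments from scrambled Sobol points (sketch).

    One Sobol coordinate per (time step, state dimension) pair; n_paths must be a power of 2.
    Returns an array of shape (n_steps, n_paths, dim).
    """
    m = int(np.log2(n_paths))
    if 2 ** m != n_paths:
        raise ValueError('n_paths must be a power of 2 for a balanced Sobol sample')
    sampler = qmc.Sobol(d=n_steps * dim, scramble=True)
    u = sampler.random_base2(m=m)                 # (n_paths, n_steps * dim) in [0, 1)
    u = np.clip(u, 1e-12, 1 - 1e-12)              # keep the inverse CDF finite
    z = norm.ppf(u).reshape(n_paths, n_steps, dim)
    return np.sqrt(dt) * z.transpose(1, 0, 2)


inc = sobol_brownian_increments(n_paths=256, n_steps=20, dim=2, dt=0.05)
print(inc.shape)  # → (20, 256, 2)
```

For path-dependent quantities, Sobol draws are often combined with a Brownian-bridge ordering of the time steps, which this sketch omits.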

In [None]:
from bsde_dsgE.utils.figures import simulate_paths, mean_and_2se, rolling_corr, impulse_response
steps = 20 if FAST else 100
sim = simulate_paths(prob, x0, steps=steps, dt=dt, key=key)
xs = sim.xs  # (T+1, P, dim)
m, e2 = mean_and_2se(xs)
print({'mean': m.tolist(), '2SE': e2.tolist()})
W = min(10, xs.shape[0]-1) if FAST else min(30, xs.shape[0]-1)
corrs = rolling_corr(xs, window=W)
corrs.shape


In [None]:
# Impulse response on state 0
mb, ms, irf = impulse_response(prob, x0, steps=steps, dt=dt, shock_dim=0, shock_size=0.1, key=key)
fig, ax = plt.subplots(figsize=(4, 2.5))
ax.plot(irf[:, 0], label='IRF dim0'); ax.legend(); fig.tight_layout(); plt.show()


## Analytic q (eq. 19) sanity check

For symmetric states (ζ_j = 1/J), eq. (19) implies

$$q = \frac{a\,\psi + 1}{\rho\,\psi + 1}.$$

Use `bsde_dsgE.models.probab01_equations.q_symmetric_analytic(a, ψ, ρ)` to compute this reference value and compare it against the `q` values displayed above (read from `table1_values.symmetric_states` in the calibration JSON). For a direct comparison against the LaTeX-parsed values, use the CLI: `python scripts/compare_table1_solver.py --calib data/probab01_table1.json --from-tex`.
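
To automate the comparison, a sketch like the one below evaluates eq. (19) and reports the largest relative deviation of each transcribed `q_i`. It assumes the states follow the `table1_values.symmetric_states` layout read earlier and that `a`, `ψ`, `ρ` are supplied by the caller; the check is meaningful only for states with ζ_j = 1/J, and the example values are made up.

```python
import numpy as np


def check_symmetric_q(states, a, psi, rho):
    """Compare transcribed q_i against eq. (19): q = (a*psi + 1) / (rho*psi + 1).

    Returns the analytic reference and the max relative deviation per state.
    """
    q_ref = (a * psi + 1.0) / (rho * psi + 1.0)
    devs = []
    for st in states:
        q = np.asarray(st['q'], dtype=float)
        devs.append(float(np.max(np.abs(q - q_ref)) / abs(q_ref)))
    return q_ref, devs


states = [{'q': [1.2, 1.2]}, {'q': [1.2, 1.25]}]  # made-up values
q_ref, devs = check_symmetric_q(states, a=0.1, psi=5.0, rho=0.05)
print(q_ref, devs)
```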