# In [None]:
import math
import time
import copy
from typing import Dict, Tuple, List
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt


# ============================================================
# 0) Matplotlib / Font Settings
# ============================================================
# Publication-style plot defaults: serif text (Liberation Serif first, then
# fallbacks), STIX math glyphs, size-10 body/axis/tick text, size-9 legends.
_RC_STYLE = {
    "pdf.fonttype": 42,
    "font.family": "serif",
    "font.serif": ["Liberation Serif", "FreeSerif", "serif"],
    "font.size": 10,
    "axes.labelsize": 10,
    "legend.fontsize": 9,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "mathtext.fontset": "stix",
}
plt.rcParams.update(_RC_STYLE)


# ============================================================
# 1) Unified Configuration
# ============================================================
def _default_device(preferred_cuda_index: int = 2) -> str:
    """Pick the device string used for the whole run.

    Returns ``cuda:<preferred_cuda_index>`` when that GPU exists, ``cuda:0``
    when CUDA is available but there are fewer GPUs (a hard-coded ordinal
    beyond ``device_count()`` is not a usable device), and ``cpu`` otherwise.
    """
    if not torch.cuda.is_available():
        return "cpu"
    if torch.cuda.device_count() > preferred_cuda_index:
        return f"cuda:{preferred_cuda_index}"
    return "cuda:0"


CFG: Dict = {
    # reproducibility / device / dtype
    "seed": 42,
    "device": _default_device(),
    "dtype": torch.float32,

    # simulation grid
    "T": 2,
    "N": 20,

    # delay
    "tau": 1.0,

    # epidemiology params
    "Lambda": 2.0,
    "beta": 0.2,
    "alpha": 0.1,
    "gamma": 0.2,
    "mu": 0.1,
    "p": 0.6,

    # cost weights & bounds
    "k1": 5.0,
    "k2": 5.0,
    "k3": 50.0,
    "k4": 50.0,
    "u_min": 0.0,
    "u_max": 1.0,

    # initial state
    "S0": 100.0,
    "I0": 20.0,
    "C0": 10.0,
    "R0": 20.0,

    # noise intensity (diag)
    "eta": [0.1, 0.1, 0.1, 0.1],

    # --- FBSM benchmark ---
    "FBSM_TOL": 1e-7,
    "FBSM_LR": 0.05,
    "FBSM_MAX_ITER": 10000,

    # --- PG-DPO training ---
    "PG_iter": 5000,
    "pg_lr": 1e-4,
    "pg_clip_grad": 1.0,
    "pg_hidden": 128,
    "pg_sched_step": 2000000,
    "pg_sched_gamma": 0.2,
    "pg_batch_size": 256,

    # --- Stepwise MPC ---
    "mpc_num_mc": 128,
    "mpc_seed_base": 42,

    # --- FBSDE  ---
    "fbsde_seed": 42,
    "fbsde_batch": 256,
    "fbsde_epochs": 30000,
    "fbsde_lr": 1e-4,
    "fbsde_step": 100000,
    "fbsde_hidden": 96,

    # FBSDE loss weights
    "w_L1": 1.0,
    "w_terminal": 1.0,
    "w_L2": 0.4,

    # curriculum & EMA target net
    "L2_warmup_epochs": 800,
    "L2_ramp_epochs": 1200,
    "ema_decay": 0.995,

    # --- PPO ---
    "ppo_hidden": 64,
    "ppo_batch": 256,           # num_envs per update
    "ppo_episodes": 512000,     # total number of episodes
    "ppo_epochs": 10,           # PPO epochs per update
    "ppo_clip_eps": 0.2,
    "ppo_lr": 5e-5,
    "ppo_gamma_rl": 1.0,
    "ppo_gae_lambda": 0.95,
    "ppo_entropy_coef": 0.01,
    "ppo_max_grad_norm": 0.5,
    "ppo_minibatch": 1024,
}


# ============================================================
# 2) Global device / dtype / derived params
# ============================================================
# Global tensor defaults: float tensors created without an explicit dtype use
# CFG["dtype"]; DEVICE is the device used for models and states throughout.
torch.set_default_dtype(CFG["dtype"])
DEVICE = torch.device(CFG["device"])


def set_seed(seed: int) -> None:
    """Seed NumPy's global RNG and PyTorch (plus every CUDA device, if present)."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def compute_derived_cfg(cfg: Dict) -> None:
    """Add derived grid quantities to ``cfg`` in place.

    Sets:
        dt:      step size T / N.
        K_delay: delay length in steps, floor(tau / dt).
        N_pop:   total initial population S0 + I0 + C0 + R0.
    """
    cfg["dt"] = cfg["T"] / cfg["N"]
    # Keep the original floor rule, but add a tiny tolerance: a ratio that is
    # mathematically an integer can land just below it in floating point
    # (e.g. 0.3 / 0.1 == 2.9999999999999996), and a plain int() would then
    # drop one whole delay step.
    cfg["K_delay"] = int(math.floor(cfg["tau"] / cfg["dt"] + 1e-9))
    cfg["N_pop"] = cfg["S0"] + cfg["I0"] + cfg["C0"] + cfg["R0"]


compute_derived_cfg(CFG)  # adds dt, K_delay and N_pop to CFG in place
print("Using device: " + str(DEVICE))


# ============================================================
# 3) Shared physics helpers (delay / dynamics / cost)
# ============================================================
def get_delayed_val_np(arr: np.ndarray, t_idx: int, k_delay: int, init_val: float) -> float:
    """Return ``arr[t_idx - k_delay]``, or ``init_val`` when that index is before 0.

    ``init_val`` plays the role of the constant pre-history before t = 0.
    """
    lagged_idx = t_idx - k_delay
    if lagged_idx >= 0:
        return arr[lagged_idx]
    return init_val


def to_numpy_noise(noise_tensor, N: int) -> np.ndarray:
    """Convert a noise path into a NumPy array with one leading entry per step.

    Accepts a torch tensor (detached and copied to CPU) or anything that
    ``np.array`` accepts. A singleton second axis, i.e. shape (N, 1, d), is
    squeezed away.

    Raises:
        ValueError: if the leading axis does not have length ``N``.
    """
    noise_np = (
        noise_tensor.detach().cpu().numpy()
        if torch.is_tensor(noise_tensor)
        else np.array(noise_tensor)
    )
    has_singleton_axis = noise_np.ndim == 3 and noise_np.shape[1] == 1
    if has_singleton_axis:
        noise_np = noise_np.squeeze(1)
    steps = noise_np.shape[0]
    if steps != N:
        raise ValueError(f"Noise length {steps} does not match N={N}.")
    return noise_np


def running_cost_torch(cfg: Dict, r: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Per-sample running cost k1*I + k2*C + 0.5*k3*u1^2 + 0.5*k4*u2^2.

    ``r`` is indexed as (batch, [S, I, C, R]) and ``u`` as (batch, [u1, u2]);
    the result has one entry per batch row.
    """
    infected, compartment_c = r[:, 1], r[:, 2]
    state_cost = cfg["k1"] * infected + cfg["k2"] * compartment_c
    return (state_cost
            + 0.5 * cfg["k3"] * u[:, 0] ** 2
            + 0.5 * cfg["k4"] * u[:, 1] ** 2)


def step_state_stoch(
    cfg: Dict,
    r: torch.Tensor,
    r_hist_stack: torch.Tensor,
    u: torch.Tensor,
    w: torch.Tensor,
    eta_vec: torch.Tensor
) -> torch.Tensor:
    """Advance the delayed S-I-C-R system one Euler-Maruyama step.

    Args:
        cfg: configuration (uses dt, K_delay, N_pop and model rates).
        r: current states, indexed (batch, [S, I, C, R]).
        r_hist_stack: stacked history (time, batch, 4) whose last entry is
            the current state; the delayed state is read K_delay entries
            back, floored at the first entry.
        u: controls (batch, [u1, u2]).
        w: unit-scale noise sample, multiplied by sqrt(dt) here.
        eta_vec: per-compartment multiplicative noise intensities.

    Returns:
        The next states, clamped to [0, 1.5 * N_pop].
    """
    lag_row = max(0, r_hist_stack.size(0) - 1 - cfg["K_delay"])
    lagged = r_hist_stack[lag_row]
    delayed_source = lagged[:, 1] + lagged[:, 2]

    S, I, C, R = r.unbind(-1)
    ctrl_1, ctrl_2 = u[:, 0], u[:, 1]
    infection = cfg["beta"] * S * delayed_source * (1 - ctrl_1)

    drift = torch.stack(
        [
            cfg["Lambda"] - infection - cfg["alpha"] * S,
            infection - (cfg["alpha"] + cfg["gamma"] + ctrl_2) * I,
            (cfg["p"] * cfg["gamma"]) * I - (cfg["alpha"] + cfg["mu"]) * C,
            ((1 - cfg["p"]) * cfg["gamma"] + ctrl_2) * I - cfg["alpha"] * R,
        ],
        dim=-1,
    )
    noise_term = math.sqrt(cfg["dt"]) * (eta_vec[None, :] * r) * w

    stepped = r + cfg["dt"] * drift + noise_term
    return stepped.clamp(min=0.0, max=cfg["N_pop"] * 1.5)


def cumulative_objective_np(cfg: Dict, I_traj, C_traj, u1_traj, u2_traj) -> np.ndarray:
    """Running total of the discretised objective along one trajectory.

    State arrays carry one more entry than control arrays; stage k pairs
    I_traj[k], C_traj[k] with u1_traj[k], u2_traj[k]. Entry k of the result
    is the sum of stages 0..k, each multiplied by dt.
    """
    I_head, C_head = I_traj[:-1], C_traj[:-1]
    state_part = cfg["k1"] * I_head + cfg["k2"] * C_head
    stage = state_part + 0.5 * cfg["k3"] * (u1_traj ** 2) + 0.5 * cfg["k4"] * (u2_traj ** 2)
    return np.cumsum(stage * cfg["dt"])


# ============================================================
# 4) Plotting / Metrics
# ============================================================
def print_metrics(gt: np.ndarray, preds: Dict[str, np.ndarray], name: str) -> None:
    """Print an RMSE / MAE table of each prediction in ``preds`` against ``gt``.

    Every series is compared on the common prefix (the shorter length).
    """
    header = f"{'Algorithm':<15} | {'RMSE':<12} | {'MAE':<12}"
    print(f"\n[{name}]")
    print(header)
    print("-" * 45)
    for alg, pr in preds.items():
        n = min(len(gt), len(pr))
        err = pr[:n] - gt[:n]
        rmse = float(np.sqrt(np.mean(err ** 2)))
        mae = float(np.mean(np.abs(err)))
        print(f"{alg:<15} | {rmse:<12.6f} | {mae:<12.6f}")


def plot_series(
    t: np.ndarray,
    gt: np.ndarray,
    series: Dict[str, Tuple[np.ndarray, str]],
    title: str,
    xlabel: str,
    ylabel: str,
    filename_pdf: str,
    ylim: Tuple[float, float] = None,
    ppo: bool = True,
    absde: bool = True,
) -> None:
    """Plot the benchmark curve against each method's curve and show the figure.

    ``series`` maps a legend label to (values, matplotlib format string).
    Labels containing "PPO" / "ABSDE" (case-insensitive) are skipped when the
    matching ``ppo`` / ``absde`` flag is False. ``title`` is accepted but not
    drawn, and the PDF path is computed but saving is currently disabled.
    """
    hidden_tags = []
    if not ppo:
        hidden_tags.append("PPO")
    if not absde:
        hidden_tags.append("ABSDE")

    plt.figure(figsize=(7, 5))
    plt.plot(t, gt, "k-", lw=2, label="Benchmark")

    for label, (values, fmt) in series.items():
        upper_label = label.upper()
        if any(tag in upper_label for tag in hidden_tags):
            continue
        plt.plot(t, values, fmt, label=label)

    #plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if ylim is not None:
        plt.ylim(ylim)
    plt.grid(alpha=0.1)
    plt.legend()
    plt.tight_layout()

    # Output name encodes which method families were drawn; kept for the
    # (currently disabled) savefig call below.
    base = filename_pdf[:-4] if filename_pdf.lower().endswith(".pdf") else filename_pdf
    suffix = ("_ppo" if ppo else "_noPPO") + ("_absde" if absde else "_noABSDE")
    out_pdf = f"{base}{suffix}.pdf"

    #plt.savefig(out_pdf, bbox_inches="tight")
    plt.show()


# ============================================================
# 5) Benchmark: FBSM
# ============================================================
def stochastic_milstein_forward_np(cfg: Dict, u1: np.ndarray, u2: np.ndarray, noise_W: np.ndarray):
    """Simulate one path of the delayed S-I-C-R SDE with a per-compartment Milstein step.

    Each compartment x has multiplicative noise eta_j * x driven by xi_j =
    noise_W[k][j] (scaled by sqrt(dt)), plus the Milstein correction
    0.5 * eta_j^2 * x * (xi_j^2 - 1) * dt. Delayed I and C use the initial
    values as pre-history. Negative values are floored at 0; there is no upper
    clamp.

    Returns:
        (S, I, C, R), each an array of length N + 1.
    """
    n_steps = cfg["N"]
    dt = cfg["dt"]
    lag = cfg["K_delay"]
    eta = cfg["eta"]
    sqrt_dt = np.sqrt(dt)

    # One trajectory array per compartment, in S, I, C, R order.
    paths = [np.zeros(n_steps + 1) for _ in range(4)]
    for path, key in zip(paths, ("S0", "I0", "C0", "R0")):
        path[0] = cfg[key]
    S, I, C, R = paths

    for k in range(n_steps):
        I_lag = get_delayed_val_np(I, k, lag, cfg["I0"])
        C_lag = get_delayed_val_np(C, k, lag, cfg["C0"])
        xi = noise_W[k]  # (4,)

        infection = cfg["beta"] * S[k] * (I_lag + C_lag) * (1 - u1[k])
        # All drifts use time-k values only, so each compartment below can be
        # updated (and floored) independently.
        drifts = (
            cfg["Lambda"] - infection - cfg["alpha"] * S[k],
            infection - (cfg["alpha"] + cfg["gamma"] + u2[k]) * I[k],
            (cfg["p"] * cfg["gamma"]) * I[k] - (cfg["alpha"] + cfg["mu"]) * C[k],
            ((1 - cfg["p"]) * cfg["gamma"] + u2[k]) * I[k] - cfg["alpha"] * R[k],
        )

        for j, (path, drift) in enumerate(zip(paths, drifts)):
            x_k = path[k]
            diffusion = eta[j] * x_k
            milstein = 0.5 * (eta[j] ** 2) * x_k * (xi[j] ** 2 - 1) * dt
            path[k + 1] = max(0, x_k + drift * dt + diffusion * xi[j] * sqrt_dt + milstein)

    return S, I, C, R


def solve_adjoint_deterministic_np(cfg: Dict, S, I, C, R, u1, u2):
    """Integrate the four costates backward along a given state/control path.

    Starts from zero costates at index N and steps back with
    c[k] = c[k+1] - dc(k) * dt, where dc(k) uses the costates at k+1 and the
    states/controls at k. Delayed I and C come from ``get_delayed_val_np``
    with the initial values as pre-history. No diffusion terms enter this
    recursion.

    NOTE(review): the delay contributions in dc2/dc3
    ((c1 - c2) * beta * S * (1 - u1)) are evaluated at the current index k.
    Delayed maximum-principle formulations usually evaluate this term at the
    advanced time t + tau (and drop it for t > T - tau). The Deep ABSDE branch
    in this file uses that shifted form: G_i is matched to the target at
    i + D, and G is zeroed past the horizon. Confirm which convention the
    benchmark intends.

    Args:
        cfg: configuration dict (needs N, dt, K_delay and the model rates/weights).
        S, I, C, R: state paths of length N + 1 (R is accepted but unused).
        u1, u2: control paths indexed by step.

    Returns:
        (c1, c2, c3, c4), arrays of length N + 1.
    """
    N = cfg["N"]
    dt = cfg["dt"]
    K = cfg["K_delay"]

    c1 = np.zeros(N + 1); c2 = np.zeros(N + 1); c3 = np.zeros(N + 1); c4 = np.zeros(N + 1)
    # Zero terminal conditions (no terminal cost).
    c1[-1] = 0; c2[-1] = 0; c3[-1] = 0; c4[-1] = 0

    for k in range(N - 1, -1, -1):
        Sk, Ik = S[k], I[k]
        Itau = get_delayed_val_np(I, k, K, cfg["I0"])
        Ctau = get_delayed_val_np(C, k, K, cfg["C0"])

        # Costate for S: incidence and natural-death terms.
        dc1 = (c1[k + 1] - c2[k + 1]) * (cfg["beta"] * (Itau + Ctau) * (1 - u1[k])) + c1[k + 1] * cfg["alpha"]

        # Costate for I: running-cost weight k1, outflows of I, and the
        # delayed-incidence term (see NOTE in the docstring).
        dc2 = (-cfg["k1"]
               + (c1[k + 1] - c2[k + 1]) * cfg["beta"] * Sk * (1 - u1[k])
               + c2[k + 1] * (cfg["alpha"] + cfg["gamma"] + u2[k])
               - c3[k + 1] * (cfg["p"] * cfg["gamma"])
               - c4[k + 1] * ((1 - cfg["p"]) * cfg["gamma"] + u2[k]))

        # Costate for C: running-cost weight k2, C outflow, delayed incidence.
        dc3 = (-cfg["k2"]
               + (c1[k + 1] - c2[k + 1]) * cfg["beta"] * Sk * (1 - u1[k])
               + c3[k + 1] * (cfg["alpha"] + cfg["mu"]))

        # Costate for R: only natural death (stays 0 from a 0 terminal value).
        dc4 = c4[k + 1] * cfg["alpha"]

        # Explicit backward step.
        c1[k] = c1[k + 1] - dc1 * dt
        c2[k] = c2[k + 1] - dc2 * dt
        c3[k] = c3[k + 1] - dc3 * dt
        c4[k] = c4[k + 1] - dc4 * dt

    return c1, c2, c3, c4


def solve_fbsm_benchmark(cfg: Dict, noise_input, verbose: bool = True):
    """Compute benchmark controls with a relaxed forward-backward sweep (FBSM).

    Each iteration does three things:
      1. Simulates the state forward with the current controls
         (``stochastic_milstein_forward_np``).
      2. Integrates the costates backward (``solve_adjoint_deterministic_np``).
      3. Moves every control a fraction ``FBSM_LR`` toward the clamped
         closed-form controls:

        u1* = clip((c2 - c1) * beta * S * (I_tau + C_tau) / k3, u_min, u_max)
        u2* = clip((c2 - c4) * I / k4, u_min, u_max)

    This is the same form used by ``fbsde_controls`` and the stepwise MPC
    controller. The sweep stops when the summed mean absolute control change
    falls below ``FBSM_TOL``, or after ``FBSM_MAX_ITER`` iterations.

    Args:
        cfg: configuration dict with derived keys (dt, K_delay).
        noise_input: per-step noise samples of length N in any form accepted
            by ``to_numpy_noise``. They are scaled by sqrt(dt) inside the
            forward simulation.
        verbose: print progress every 100 iterations and on convergence.

    Returns:
        (u1, u2, S, I, C, R). Controls have N + 1 entries. The state paths
        come from the last forward pass, i.e. they were simulated with the
        controls from before the final relaxation update.

    Raises:
        ValueError: on a noise-length mismatch, or if ``FBSM_MAX_ITER`` < 1
            (no sweep would run and there would be no state path to return).
    """
    noise_W = to_numpy_noise(noise_input, cfg["N"])
    N = cfg["N"]
    K = cfg["K_delay"]
    lr = cfg["FBSM_LR"]

    if cfg["FBSM_MAX_ITER"] < 1:
        raise ValueError(f"FBSM_MAX_ITER must be >= 1, got {cfg['FBSM_MAX_ITER']}.")

    u1 = np.full(N + 1, 0.5)
    u2 = np.full(N + 1, 0.5)

    if verbose:
        print(f"[FBSM] Start. Noise shape={noise_W.shape}")

    for it in range(cfg["FBSM_MAX_ITER"]):
        S, I, C, R = stochastic_milstein_forward_np(cfg, u1, u2, noise_W)
        c1, c2, c3, c4 = solve_adjoint_deterministic_np(cfg, S, I, C, R, u1, u2)

        old_u1 = u1.copy()
        old_u2 = u2.copy()

        for k in range(N + 1):
            Itau = get_delayed_val_np(I, k, K, cfg["I0"])
            Ctau = get_delayed_val_np(C, k, K, cfg["C0"])

            # u1 enters the dynamics through beta * S * (I_tau + C_tau) * (1 - u1),
            # so beta belongs in its switching function (previously omitted).
            switching_fn1 = (c2[k] - c1[k]) * cfg["beta"] * S[k] * (Itau + Ctau)
            u1_star = np.clip(switching_fn1 / cfg["k3"], cfg["u_min"], cfg["u_max"])

            switching_fn2 = (c2[k] - c4[k]) * I[k]
            u2_star = np.clip(switching_fn2 / cfg["k4"], cfg["u_min"], cfg["u_max"])

            # Under-relaxed update for stability of the sweep.
            u1[k] = (1 - lr) * old_u1[k] + lr * u1_star
            u2[k] = (1 - lr) * old_u2[k] + lr * u2_star

        diff = np.mean(np.abs(u1 - old_u1)) + np.mean(np.abs(u2 - old_u2))

        if diff < cfg["FBSM_TOL"]:
            if verbose:
                print(f"[FBSM] Converged at iter={it}, diff={diff:.6e}")
            break

        if verbose and it % 100 == 0:
            print(f"[FBSM] Iter={it}, diff={diff:.6e}")

    return u1, u2, S, I, C, R


# ============================================================
# 6) PG-DPO policy (LSTM)
# ============================================================
class LSTMPolicy(nn.Module):
    """Recurrent feedback policy mapping the state to two controls in (0, 1).

    The state r_t (batch, 4) is divided by ``N_pop`` before an LSTM cell. The
    hidden state passes through a SiLU MLP, and a sigmoid squashes the output
    into (0, 1). The output layer uses small-gain Xavier weights and a -0.5
    bias.
    """

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.hidden = hidden
        self.lstm = nn.LSTMCell(4, hidden)
        layers = [nn.Linear(hidden, 128), nn.SiLU(), nn.Linear(128, 2), nn.Sigmoid()]
        self.head = nn.Sequential(*layers)
        out_layer = self.head[2]
        nn.init.xavier_uniform_(out_layer.weight, gain=0.01)
        nn.init.constant_(out_layer.bias, -0.5)

    def init_hidden(self, B: int, device: torch.device):
        """Return zero (h, c) states of shape (B, hidden) on ``device``."""
        shape = (B, self.hidden)
        return torch.zeros(*shape, device=device), torch.zeros(*shape, device=device)

    def forward(self, r_t: torch.Tensor, hc, N_pop: float):
        """Advance one step on r_t / N_pop; return (controls, (h, c))."""
        h, c = self.lstm(r_t / N_pop, hc)
        return self.head(h), (h, c)


def train_pg_dpo(cfg: Dict, policy: nn.Module) -> List[float]:
    """Train ``policy`` by backpropagating through simulated trajectories (PG-DPO).

    Each iteration performs these steps:
      * Rolls out ``pg_batch_size`` paths from the initial state for N steps.
      * Sums the batch-mean running cost times dt.
      * Backpropagates through the whole rollout.
      * Clips the gradient norm and takes one Adam step plus one StepLR step.

    Returns:
        The loss value of every iteration.
    """
    print("[PG-DPO] Training...")
    eta_vec = torch.tensor(cfg["eta"], device=DEVICE, dtype=cfg["dtype"])
    batch = int(cfg.get("pg_batch_size", 1))
    dt = cfg["dt"]

    optimizer = optim.Adam(policy.parameters(), lr=cfg["pg_lr"])
    scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg["pg_sched_step"], gamma=cfg["pg_sched_gamma"])

    initial_state = [[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]]
    history_losses: List[float] = []

    for iteration in range(cfg["PG_iter"]):
        state = torch.tensor(initial_state, device=DEVICE, dtype=cfg["dtype"]).repeat(batch, 1)
        hidden = policy.init_hidden(batch, DEVICE)
        trajectory = [state]
        objective = 0.0

        for _ in range(cfg["N"]):
            action, hidden = policy(state, hidden, cfg["N_pop"])
            noise = torch.randn(batch, 4, device=DEVICE, dtype=cfg["dtype"])

            step_cost = running_cost_torch(cfg, state, action)
            if torch.is_tensor(step_cost) and step_cost.ndim > 0:
                step_cost = step_cost.mean()
            objective = objective + step_cost * dt

            # The full history is passed so the dynamics can read the delayed state.
            state = step_state_stoch(cfg, state, torch.stack(trajectory), action, noise, eta_vec)
            trajectory.append(state)

        optimizer.zero_grad()
        objective.backward()
        nn.utils.clip_grad_norm_(policy.parameters(), cfg["pg_clip_grad"])
        optimizer.step()
        scheduler.step()

        history_losses.append(float(objective.item()))
        if iteration % 500 == 0:
            print(f"[PG-DPO] Iter {iteration}, Loss {objective.item():.6f}")

    return history_losses


# ============================================================
# 7) Rollouts: (A) Pure policy, (B) Stepwise MPC
# ============================================================
@torch.no_grad()
def rollout_policy_on_noise(cfg: Dict, policy: nn.Module, noise_data: torch.Tensor):
    """Run ``policy`` open-loop on a fixed noise path (batch size 1).

    Uses ``step_state_stoch`` for the dynamics with ``noise_data[t]`` as the
    step-t noise sample.

    Returns:
        (I_traj, C_traj, u1_traj, u2_traj) as NumPy arrays. The state arrays
        include the initial value (N + 1 entries); the control arrays have
        N entries.
    """
    eta_vec = torch.tensor(cfg["eta"], device=DEVICE, dtype=cfg["dtype"])
    state = torch.tensor([[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]], device=DEVICE, dtype=cfg["dtype"])
    hidden = policy.init_hidden(1, DEVICE)

    visited = [state]
    I_path, C_path = [cfg["I0"]], [cfg["C0"]]
    u1_path, u2_path = [], []

    for step in range(cfg["N"]):
        action, hidden = policy(state, hidden, cfg["N_pop"])
        state = step_state_stoch(cfg, state, torch.stack(visited), action, noise_data[step], eta_vec)
        visited.append(state)

        I_path.append(float(state[0, 1].item()))
        C_path.append(float(state[0, 2].item()))
        u1_path.append(float(action[0, 0].item()))
        u2_path.append(float(action[0, 1].item()))

    return tuple(np.array(seq) for seq in (I_path, C_path, u1_path, u2_path))


def run_stepwise_mpc_simulation(cfg: Dict, policy: nn.Module, noise_data: torch.Tensor, num_mc: int, seed_base: int):
    """Roll out a stepwise Monte-Carlo-gradient controller on a fixed noise path.

    At every real step t the function:
      1. Simulates ``num_mc`` futures from the current state to the horizon.
         Simulated actions come from ``policy``; simulated noise comes from a
         generator seeded with ``seed_base + t``.
      2. Differentiates the Monte Carlo mean of the simulated running cost
         with respect to the current state, giving (lamS, lamI, lamC, lamR).
      3. Forms clamped closed-form controls from those gradients. This is the
         same form as ``fbsde_controls``, with lam in place of y.
      4. Advances the real state with ``step_state_stoch`` using
         ``noise_data[t]``.

    The policy's recurrent state is also advanced along the real path, so
    later simulations start from it; the policy's action on the real path is
    discarded.

    Returns:
        (I_traj, C_traj, u1_traj, u2_traj) as NumPy arrays. The state arrays
        have N + 1 entries; the control arrays have N entries.
    """
    eta_vec = torch.tensor(cfg["eta"], device=DEVICE, dtype=cfg["dtype"])
    r_curr = torch.tensor([[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]], device=DEVICE, dtype=cfg["dtype"])
    hc_real = policy.init_hidden(1, DEVICE)
    real_history = [r_curr]

    I_traj = [cfg["I0"]]
    C_traj = [cfg["C0"]]
    u1_traj, u2_traj = [], []

    print(f"[MPC] Running Stepwise MPC (MC={num_mc})...")
    for t in range(cfg["N"]):
        # Leaf copy of the current real state; the simulated future cost is
        # differentiated with respect to this tensor.
        r_now_grad = r_curr.detach().clone().requires_grad_(True)

        # Fresh generator per real step so simulated futures are reproducible.
        gen = torch.Generator(device=DEVICE)
        gen.manual_seed(seed_base + t)

        r_sim = r_now_grad.repeat(num_mc, 1)
        hc_sim = (hc_real[0].repeat(num_mc, 1), hc_real[1].repeat(num_mc, 1))

        future_cost = 0.0
        sim_future_states: List[torch.Tensor] = []  # simulated states at times t+1, t+2, ...

        for k in range(t, cfg["N"]):
            u_sim, hc_sim = policy(r_sim, hc_sim, cfg["N_pop"])
            future_cost = future_cost + running_cost_torch(cfg, r_sim, u_sim).mean() * cfg["dt"]

            # Delayed state for simulated time k:
            #   index before t -> realised history (detached, floored at index 0)
            #   index == t     -> the differentiable current state
            #   index after t  -> an earlier simulated state of this rollout
            idx_tau = k - cfg["K_delay"]
            if idx_tau < t:
                safe_idx = max(0, idx_tau)
                r_delayed = real_history[safe_idx].detach().repeat(num_mc, 1)
            elif idx_tau == t:
                r_delayed = r_now_grad.repeat(num_mc, 1)
            else:
                future_idx = idx_tau - (t + 1)
                r_delayed = sim_future_states[future_idx]

            S_s, I_s, C_s, R_s = r_sim.unbind(-1)
            I_tau, C_tau = r_delayed[:, 1], r_delayed[:, 2]
            u1_s, u2_s = u_sim[:, 0], u_sim[:, 1]

            # Same drift as step_state_stoch, with the delay taken from above.
            inc = cfg["beta"] * S_s * (I_tau + C_tau) * (1 - u1_s)

            dS = cfg["Lambda"] - inc - cfg["alpha"] * S_s
            dI = inc - (cfg["alpha"] + cfg["gamma"] + u2_s) * I_s
            dC = (cfg["p"] * cfg["gamma"]) * I_s - (cfg["alpha"] + cfg["mu"]) * C_s
            dR = ((1 - cfg["p"]) * cfg["gamma"] + u2_s) * I_s - cfg["alpha"] * R_s

            drift = torch.stack([dS, dI, dC, dR], dim=-1)

            xi = torch.randn(num_mc, 4, device=DEVICE, generator=gen, dtype=cfg["dtype"])
            diffusion = (eta_vec[None, :] * r_sim) * math.sqrt(cfg["dt"]) * xi

            r_next_sim = torch.clamp(r_sim + cfg["dt"] * drift + diffusion, 0.0, cfg["N_pop"] * 1.5)
            sim_future_states.append(r_next_sim)
            r_sim = r_next_sim

        # Gradient of the MC-mean future cost w.r.t. the current state.
        grads = torch.autograd.grad(future_cost, r_now_grad, create_graph=False)[0]
        lamS, lamI, lamC, lamR = grads[0]

        # Delayed real state for the current step (index floored at 0).
        idx_tau_curr = max(0, t - cfg["K_delay"])
        r_tau_curr = real_history[idx_tau_curr]
        I_tau_curr = r_tau_curr[0, 1]
        C_tau_curr = r_tau_curr[0, 2]

        S_curr_val = r_curr[0, 0]
        I_curr_val = r_curr[0, 1]

        # Closed-form controls from the gradients, clamped to [u_min, u_max].
        val1 = (lamI - lamS) * cfg["beta"] * S_curr_val * (I_tau_curr + C_tau_curr) / cfg["k3"]
        val2 = (lamI - lamR) * I_curr_val / cfg["k4"]

        u1_star = torch.clamp(val1, cfg["u_min"], cfg["u_max"])
        u2_star = torch.clamp(val2, cfg["u_min"], cfg["u_max"])
        u_star = torch.stack([u1_star, u2_star], dim=0).unsqueeze(0)

        w_t = noise_data[t]
        r_hist_stack = torch.stack(real_history)

        # Advance only the policy's recurrent state along the real path.
        with torch.no_grad():
            _, hc_real = policy(r_curr, hc_real, cfg["N_pop"])

        r_next_real = step_state_stoch(cfg, r_curr, r_hist_stack, u_star, w_t, eta_vec)

        r_curr = r_next_real
        real_history.append(r_curr)

        I_traj.append(float(r_curr[0, 1].item()))
        C_traj.append(float(r_curr[0, 2].item()))
        u1_traj.append(float(u1_star.item()))
        u2_traj.append(float(u2_star.item()))

        if t % 10 == 0:
            print(f"[MPC] Step {t}/{cfg['N']} done.")

    return np.array(I_traj), np.array(C_traj), np.array(u1_traj), np.array(u2_traj)


#============================================================
# 8) Deep ABSDE (FBSDE-Net)
#   - Net outputs (Y, Z, G)
#       G(t_i) ≈ E_t[ ∂_{x_delay} H(t_i + tau) ]  (here delay is (I_tau, C_tau) -> dim=2)
#   - L1: BSDE consistency using tildeY
#   - L2: shift-matching   G_i ≈ (∂_{x_delay}H)_{i+D}
#============================================================
# Dimensions shared by the Deep ABSDE network and its driver: four state
# compartments (S, I, C, R), each with its own Brownian driver (diagonal noise).
STATE_DIM, BROWN_DIM = 4, 4

class DeepABSDE_Net(nn.Module):
    """LSTM network producing the (Y, Z, G) processes of the Deep ABSDE solver.

    Per step, the input is the state x (batch, STATE_DIM) concatenated with
    time t (batch, 1). The outputs are:
      * y of shape (batch, STATE_DIM);
      * z of shape (batch, STATE_DIM, BROWN_DIM);
      * g of shape (batch, 2).
    """

    def __init__(self, hidden: int):
        super().__init__()
        self.hidden = hidden
        self.lstm = nn.LSTMCell(input_size=STATE_DIM + 1, hidden_size=hidden)

        def mlp(out_features: int) -> nn.Sequential:
            # hidden -> hidden -> out_features with a Tanh in between.
            return nn.Sequential(
                nn.Linear(hidden, hidden), nn.Tanh(), nn.Linear(hidden, out_features)
            )

        self.head_Y = mlp(STATE_DIM)
        self.head_Z = mlp(STATE_DIM * BROWN_DIM)
        self.head_G = mlp(2)

    def init_hidden(self, batch: int, device: torch.device):
        """Return zero (h, c) LSTM states of shape (batch, hidden)."""
        return tuple(torch.zeros(batch, self.hidden, device=device) for _ in range(2))

    def step(self, x, t, hc):
        """Advance one time step; return (y, z, g, (h, c))."""
        h, c = self.lstm(torch.cat([x, t], dim=1), hc)
        y_out = self.head_Y(h)
        z_out = self.head_Z(h).view(-1, STATE_DIM, BROWN_DIM)
        g_out = self.head_G(h)
        return y_out, z_out, g_out, (h, c)

def build_fbsde_helpers(cfg: Dict):
    """Return (D, dL_dx, eta_vec) used by the Deep ABSDE solver.

    D is the integer delay in steps (``K_delay``). dL_dx is the state gradient
    of the running cost, [0, k1, k2, 0], as a (1, 4) row. eta_vec holds the
    noise intensities as a (1, 4) row.
    """
    def _row(values):
        return torch.tensor(values, device=DEVICE, dtype=cfg["dtype"]).view(1, 4)

    delay_steps = int(cfg["K_delay"])
    cost_grad = _row([0.0, cfg["k1"], cfg["k2"], 0.0])
    noise_scale = _row(cfg["eta"])
    return delay_steps, cost_grad, noise_scale


def fbsde_drift(cfg: Dict, x, I_delay, C_delay, u1, u2):
    """Drift b(x, x_delay, u) of the delayed S-I-C-R model as a (batch, 4) tensor.

    ``x`` is (batch, 4); delayed states and controls are (batch, 1) columns.
    The C compartment's drift does not depend on u2.
    """
    S, I, C, R = x.split(1, dim=1)
    infection = cfg["beta"] * S * (I_delay + C_delay) * (1.0 - u1)

    drift_cols = [
        cfg["Lambda"] - infection - cfg["alpha"] * S,
        infection - (cfg["alpha"] + cfg["gamma"] + u2) * I,
        (cfg["p"] * cfg["gamma"]) * I - (cfg["alpha"] + cfg["mu"]) * C,
        ((1.0 - cfg["p"]) * cfg["gamma"] + u2) * I - cfg["alpha"] * R,
    ]
    return torch.cat(drift_cols, dim=1)


def fbsde_diffusion(eta_vec, x):
    """Diagonal multiplicative noise coefficient: eta * x, elementwise (broadcast)."""
    return torch.mul(eta_vec, x)


def fbsde_controls(cfg: Dict, x, I_delay, C_delay, y):
    """Closed-form controls from the adjoint estimate y, clamped to [u_min, u_max].

    u1 = beta * S * (I_delay + C_delay) * (yI - yS) / k3
    u2 = I * (yI - yR) / k4
    (A 1e-12 is added to k3/k4 to guard the division.)

    Returns:
        (u1, u2), each of shape (batch, 1).
    """
    S, I = x[:, 0:1], x[:, 1:2]
    yS, yI, yR = y[:, 0:1], y[:, 1:2], y[:, 3:4]
    lo, hi = cfg["u_min"], cfg["u_max"]

    u1 = torch.clamp((cfg["beta"] * S * (I_delay + C_delay) * (yI - yS)) / (cfg["k3"] + 1e-12), lo, hi)
    u2 = torch.clamp((I * (yI - yR)) / (cfg["k4"] + 1e-12), lo, hi)
    return u1, u2


def dH_d_delay_target(cfg: Dict, X, U, Y):
    """Derivative of the Hamiltonian w.r.t. the delayed (I, C) states, shape (batch, 2).

    Both delayed states enter only through the incidence term, so both columns
    equal beta * S * (1 - u1) * (yI - yS).
    """
    shared = cfg["beta"] * X[:, 0:1] * (1.0 - U[:, 0:1]) * (Y[:, 1:2] - Y[:, 0:1])
    return shared.repeat(1, 2)


def fbsde_driver(cfg: Dict, dL_dx, eta_vec, x, I_delay, C_delay, u1, u2, y, z, g):
    """BSDE driver F = dL/dx + (db/dx)^T y + diag(eta) . diag(z) + anticipation, shape (batch, 4).

    Terms:
      * (db/dx)^T y: Jacobian-transpose of the drift w.r.t. the current state,
        applied to y.
      * eta_i * z_ii: derivative of the diagonal diffusion eta_i * x_i,
        contracted with Z.
      * Anticipation: g = (g_I, g_C) is added to the I and C components
        (S and R get 0).
    """
    yS, yI, yC, yR = y.split(1, dim=1)
    blocked = cfg["beta"] * (I_delay + C_delay) * (1.0 - u1)  # beta * Qd * (1 - u1)

    # Row S: bS and bI depend on S.
    grad_S = yS * (-cfg["beta"] * (I_delay + C_delay) * (1.0 - u1) - cfg["alpha"]) + yI * blocked
    # Row I: bI, bC and bR depend on I.
    grad_I = (yI * (-(cfg["alpha"] + cfg["gamma"] + u2))
              + yC * (cfg["p"] * cfg["gamma"])
              + yR * ((1.0 - cfg["p"]) * cfg["gamma"] + u2))
    grad_C = yC * (-(cfg["alpha"] + cfg["mu"]))
    grad_R = yR * (-cfg["alpha"])
    jac_t_y = torch.cat([grad_S, grad_I, grad_C, grad_R], dim=1)

    sigma_term = eta_vec * torch.diagonal(z, dim1=1, dim2=2)
    anticipation = torch.nn.functional.pad(g, (1, 1))  # [0, g_I, g_C, 0]

    return dL_dx + jac_t_y + sigma_term + anticipation


# ------------------------------------------------------------
# Training step
# ------------------------------------------------------------
def fbsde_train_step(cfg: Dict, net, opt, scheduler, epoch: int):
    """Run one optimisation step of the Deep ABSDE solver on a fresh batch of paths.

    Forward pass: simulate ``fbsde_batch`` state paths with Euler-Maruyama
    (x + b*dt + sigma*dW, floored at 0). At every step the network's
    (y, z, g) are read and y is used to form the clamped controls
    (``fbsde_controls``). The loss is

        loss = w_L1 * L1 + w_terminal * Lterm + w_L2 * L2

    where:
      * L1 is the mean squared mismatch between Y_{i+1} and the one-step BSDE
        prediction Y_i - F_i * dt + Z_i dW_i (F from ``fbsde_driver``);
      * Lterm is the mean squared Y at the final time (pushes Y_N toward 0);
      * L2 is the mean squared mismatch between G_i and the delay-derivative
        target (``dH_d_delay_target``) taken D steps later, with a zero target
        at index N.

    NOTE(review): ``epoch`` is not used inside this function, and w_L2 is
    applied as a constant (CFG's L2 warm-up/ramp and EMA settings are not read
    here).

    Returns:
        (loss, L1, Lterm, L2, w_L2) as Python floats.
    """
    net.train()
    opt.zero_grad()

    D, dL_dx, eta_vec = build_fbsde_helpers(cfg)
    M, N, dt = cfg["fbsde_batch"], cfg["N"], cfg["dt"]

    # Brownian increments with variance dt.
    dW = torch.randn(M, N, 4, device=DEVICE, dtype=cfg["dtype"]) * math.sqrt(dt)

    x0 = torch.tensor(
        [cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]],
        device=DEVICE, dtype=cfg["dtype"]
    ).view(1, 4).repeat(M, 1)

    hc = net.init_hidden(M, DEVICE)
    hist0 = x0  # pre-history: delayed state before t=0 is the initial state

    x_list = [x0]
    y_list, z_list, g_list, u_list = [], [], [], []

    for i in range(N):
        t_i = torch.full((M, 1), i * dt, device=DEVICE, dtype=cfg["dtype"])
        x_i = x_list[i]

        y_i, z_i, g_i, hc = net.step(x_i, t_i, hc)
        # G_i stands for the delay derivative at step i + D; beyond the
        # horizon (i + D > N) it is fixed to zero.
        if i > N - D:
            g_i = torch.zeros_like(g_i)

        x_delay = x_list[i - D] if i - D >= 0 else hist0
        I_d, C_d = x_delay[:, 1:2], x_delay[:, 2:3]

        u1_i, u2_i = fbsde_controls(cfg, x_i, I_d, C_d, y_i)
        b_i = fbsde_drift(cfg, x_i, I_d, C_d, u1_i, u2_i)
        sig_i = fbsde_diffusion(eta_vec, x_i)

        x_next = torch.clamp(x_i + b_i * dt + sig_i * dW[:, i, :], min=0.0)
        x_list.append(x_next)

        y_list.append(y_i)
        z_list.append(z_i)
        g_list.append(g_i)
        u_list.append(torch.cat([u1_i, u2_i], dim=1))

    # Extra network step at t_N to obtain Y at the terminal time.
    t_N = torch.full((M, 1), N * dt, device=DEVICE, dtype=cfg["dtype"])
    y_N, _, _, _ = net.step(x_list[N], t_N, hc)
    y_list.append(y_N)

    X = torch.stack(x_list, dim=1)  # (M, N+1, 4)
    Y = torch.stack(y_list, dim=1)  # (M, N+1, 4)
    Z = torch.stack(z_list, dim=1)  # (M, N, 4, 4)
    G = torch.stack(g_list, dim=1)  # (M, N, 2)
    U = torch.stack(u_list, dim=1)  # (M, N, 2)

    f_list = []
    for i in range(N):
        x_delay = X[:, i - D, :] if i - D >= 0 else hist0
        f_i = fbsde_driver(
            cfg, dL_dx, eta_vec,
            X[:, i, :],
            x_delay[:, 1:2], x_delay[:, 2:3],
            U[:, i, 0:1], U[:, i, 1:2],
            Y[:, i, :], Z[:, i, :, :],
            G[:, i, :]
        )
        f_list.append(f_i)
    F = torch.stack(f_list, dim=1)

    # One-step BSDE propagation: tildeY_{i+1} = Y_i - F_i dt + Z_i dW_i.
    ZdW = torch.einsum("bnij,bnj->bni", Z, dW)
    tildeY = torch.zeros_like(Y)
    tildeY[:, 0, :] = Y[:, 0, :]
    tildeY[:, 1:, :] = Y[:, :-1, :] - F * dt + ZdW

    L1 = torch.mean((Y[:, 1:, :] - tildeY[:, 1:, :]) ** 2)
    Lterm = torch.mean(Y[:, -1, :] ** 2)

    # Targets for G, computed from detached X, U, Y so L2 does not backprop
    # through them.
    g_targets = []
    for k in range(N):
        gk = dH_d_delay_target(
            cfg,
            X[:, k, :].detach(),
            U[:, k, :].detach(),
            Y[:, k, :].detach(),
        )
        g_targets.append(gk)

    # Zero target at the final index N.
    g_targets.append(
        torch.zeros(M, 2, device=DEVICE, dtype=cfg["dtype"])
    )
    G_targ = torch.stack(g_targets, dim=1)  # (M, N+1, 2)

    # Shift matching: G_i against the target at index i + D, for i = 0..N-D.
    if D < N:
        pred = G[:, 0:(N - D + 1), :]
        targ = G_targ[:, D:(N + 1), :].detach()
        L2 = torch.mean((pred - targ) ** 2)
    else:
        L2 = torch.tensor(0.0, device=DEVICE, dtype=cfg["dtype"])

    wL2 = cfg["w_L2"]
    loss = cfg["w_L1"] * L1 + cfg["w_terminal"] * Lterm + wL2 * L2

    loss.backward()
    nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    opt.step()
    if scheduler is not None:
        scheduler.step()

    return float(loss), float(L1), float(Lterm), float(L2), float(wL2)


# ------------------------------------------------------------
# Rollout
# ------------------------------------------------------------
@torch.no_grad()
def rollout_fbsde_on_noise(cfg: Dict, net, noise_data):
    """Roll out the Deep ABSDE feedback controls on a fixed noise path (batch size 1).

    NOTE(review): a second ``def rollout_fbsde_on_noise`` appears later in
    this file. That later definition rebinds the name when the module/cell
    runs, so this body is not the one callers reach. Consider deleting one of
    the two.

    At each step the network's y output is turned into controls via
    ``fbsde_controls``, using the state D = K_delay steps back (or the initial
    state as pre-history). The state is then advanced as
    x + b*dt + sigma*sqrt(dt)*w and floored at 0.

    Returns:
        (I_traj, C_traj, u1_traj, u2_traj) as NumPy arrays. The state arrays
        have N + 1 entries; the control arrays have N entries.
    """
    net.eval()
    D, _, eta_vec = build_fbsde_helpers(cfg)

    r_curr = torch.tensor(
        [[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]],
        device=DEVICE, dtype=cfg["dtype"]
    )
    hc = net.init_hidden(1, DEVICE)
    hist0 = r_curr.clone()
    real_history = [r_curr]

    I_traj = [cfg["I0"]]
    C_traj = [cfg["C0"]]
    u1_traj, u2_traj = [], []

    for t in range(cfg["N"]):
        t_tensor = torch.full((1, 1), t * cfg["dt"], device=DEVICE, dtype=cfg["dtype"])
        y, _, _, hc = net.step(r_curr, t_tensor, hc)

        idx = t - D
        r_delay = real_history[idx] if idx >= 0 else hist0
        I_d, C_d = r_delay[:, 1:2], r_delay[:, 2:3]

        u1, u2 = fbsde_controls(cfg, r_curr, I_d, C_d, y)

        w = noise_data[t]
        b = fbsde_drift(cfg, r_curr, I_d, C_d, u1, u2)
        sig = fbsde_diffusion(eta_vec, r_curr)

        # Noise sample scaled by sqrt(dt), matching training's dW.
        r_curr = torch.clamp(
            r_curr + b * cfg["dt"] + sig * math.sqrt(cfg["dt"]) * w,
            min=0.0
        )
        real_history.append(r_curr)

        I_traj.append(float(r_curr[0, 1]))
        C_traj.append(float(r_curr[0, 2]))
        u1_traj.append(float(u1))
        u2_traj.append(float(u2))

    return np.array(I_traj), np.array(C_traj), np.array(u1_traj), np.array(u2_traj)


# ------------------------------------------------------------
# Entry for pipeline
# ------------------------------------------------------------
def train_fbsde(cfg: Dict):
    """Build and train a DeepABSDE_Net for ``fbsde_epochs`` steps; return the network.

    Seeds torch with ``fbsde_seed``, uses Adam with a StepLR schedule
    (halving every ``fbsde_step`` steps), and logs the loss terms every 3000
    epochs.
    """
    print("\n[ABSDE] Training Started...")
    torch.manual_seed(cfg["fbsde_seed"])

    model = DeepABSDE_Net(cfg["fbsde_hidden"]).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=cfg["fbsde_lr"])
    schedule = optim.lr_scheduler.StepLR(optimizer, step_size=cfg["fbsde_step"], gamma=0.5)

    log_every = 3000
    for epoch in range(1, cfg["fbsde_epochs"] + 1):
        total, l1, lt, l2, _ = fbsde_train_step(cfg, model, optimizer, schedule, epoch)
        if epoch % log_every == 0:
            print(f"[ABSDE] Ep {epoch}, Loss={total:.4e}, L1={l1:.3e}, LT={lt:.3e}, L2={l2:.3e}")

    return model

@torch.no_grad()
def rollout_fbsde_on_noise(cfg: Dict, net, noise_data):
    """Roll out the trained Deep ABSDE feedback controls on a fixed noise path (batch size 1).

    At each step the network's y output gives the costate estimate. The
    controls come from ``fbsde_controls``, using the state D = K_delay steps
    back (or the initial state as pre-history). The state is then advanced
    with an Euler-Maruyama step and floored at 0.

    ``noise_data[t]`` is treated as a unit-variance sample and scaled by
    sqrt(dt). This matches training (``dW = randn * sqrt(dt)`` in
    ``fbsde_train_step``), ``step_state_stoch`` and the FBSM benchmark, so
    every method sees the same noise when given the same ``noise_data``.
    Without the factor, the rollout noise would be 1/sqrt(dt) times too large.

    Args:
        cfg: configuration dict with derived keys (dt, K_delay).
        net: trained DeepABSDE_Net (switched to eval mode here).
        noise_data: indexable by step; each entry presumably (1, 4) or (4,).

    Returns:
        (I_traj, C_traj, u1_traj, u2_traj) as NumPy arrays. The state arrays
        have N + 1 entries; the control arrays have N entries.
    """
    net.eval()
    D, _, eta_vec = build_fbsde_helpers(cfg)
    dt = cfg["dt"]
    sqrt_dt = math.sqrt(dt)

    r_curr = torch.tensor([[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]], device=DEVICE, dtype=cfg["dtype"])
    hc = net.init_hidden(1, DEVICE)
    hist0 = r_curr.clone()
    real_history = [r_curr]

    I_traj = [cfg["I0"]]; C_traj = [cfg["C0"]]
    u1_traj, u2_traj = [], []

    for t in range(cfg["N"]):
        t_tensor = torch.full((1, 1), float(t) * dt, device=DEVICE, dtype=cfg["dtype"])
        y_curr, _, _, hc = net.step(r_curr, t_tensor, hc)

        idx_delay = t - D
        r_delay = real_history[idx_delay] if idx_delay >= 0 else hist0
        I_d, C_d = r_delay[:, 1:2], r_delay[:, 2:3]

        u1, u2 = fbsde_controls(cfg, r_curr, I_d, C_d, y_curr)

        w_t = noise_data[t]  # (1,4)
        b = fbsde_drift(cfg, r_curr, I_d, C_d, u1, u2)
        sig = fbsde_diffusion(eta_vec, r_curr)

        r_next = torch.clamp(r_curr + b * dt + sig * sqrt_dt * w_t, min=0.0)
        real_history.append(r_next)
        r_curr = r_next

        I_traj.append(float(r_curr[0, 1].item()))
        C_traj.append(float(r_curr[0, 2].item()))
        u1_traj.append(float(u1.item()))
        u2_traj.append(float(u2.item()))

    return np.array(I_traj), np.array(C_traj), np.array(u1_traj), np.array(u2_traj)



# ============================================================
# 9) PPO
# ============================================================
class WindowHCVPPO(nn.Module):
    """Actor-critic over a window of observations: LSTM encoder with a Gaussian actor.

    ``forward(x)`` encodes x (batch, window, input_dim) with a batch-first
    LSTM and reads the last time step. It returns (mu, std, v):
      * mu: the actor mean, (batch, 2);
      * std: a state-independent std from a learned log-std (initialised at
        -0.5), expanded to (batch, 2);
      * v: the critic value, (batch, 1).
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.actor_mu = nn.Linear(hidden_dim, 2)
        self.actor_log_std = nn.Parameter(torch.full((2,), -0.5))
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        encoded, _ = self.lstm(x)
        last = encoded[:, -1]
        mu = self.actor_mu(last)
        std = self.actor_log_std.exp().expand_as(mu)
        return mu, std, self.critic(last)

    def act(self, x, deterministic: bool = False):
        """Return controls squashed into (0, 1) by a sigmoid.

        If ``deterministic`` is True, returns (sigmoid(mu), v).
        Otherwise samples raw ~ Normal(mu, std) and returns
        (sigmoid(raw), log_prob(raw) summed over the action dims as
        (batch, 1), v, raw).
        """
        mu, std, v = self.forward(x)
        if deterministic:
            return torch.sigmoid(mu), v
        policy_dist = torch.distributions.Normal(mu, std)
        raw = policy_dist.sample()
        log_prob = policy_dist.log_prob(raw).sum(dim=1, keepdim=True)
        return torch.sigmoid(raw), log_prob, v, raw


def ppo_train(cfg: Dict) -> "WindowHCVPPO":
    """Train the recurrent PPO baseline on the delayed stochastic epidemic model.

    Each update simulates up to ``cfg["ppo_batch"]`` episodes in parallel. The
    final batch is shrunk so that exactly ``cfg["ppo_episodes"]`` episodes are
    run. Each episode consists of ``cfg["N"]`` Euler-Maruyama steps of size
    ``cfg["dt"]``. After the rollout, the function performs
    ``cfg["ppo_epochs"]`` passes of clipped-PPO minibatch updates on the
    collected transitions.

    Observation: a sliding window of ``round(tau / dt) + 2`` rows
    ``[t, S, I, C, R, I_tau, C_tau]`` that starts as all zeros.

    Actions: sampled in raw (pre-sigmoid) space and squashed to (0, 1). The
    PPO ratio is computed from log-probabilities of the raw samples.

    Reward per step: ``-(k1*I + k2*C + k3/2*u1^2 + k4/2*u2^2) * dt``, evaluated
    on the post-step I, C and the applied controls.

    Args:
        cfg: configuration dict. Reads the model and cost entries (``S0``..``R0``,
            ``Lambda``, ``beta``, ``alpha``, ``gamma``, ``mu``, ``p``, ``eta``,
            ``k1``..``k4``, ``tau``, ``dt``, ``N``, ``dtype``). Also reads the
            PPO keys ``ppo_episodes``, ``ppo_batch``, ``ppo_hidden``,
            ``ppo_lr``, ``ppo_gamma_rl``, ``ppo_gae_lambda``,
            ``ppo_minibatch``, ``ppo_epochs``, ``ppo_clip_eps``,
            ``ppo_entropy_coef`` and ``ppo_max_grad_norm``.

    Returns:
        The trained ``WindowHCVPPO`` agent (left in train mode).
    """
    print("[PPO] Starting Training...")

    dt_local = cfg["dt"]
    sqrt_dt = np.sqrt(dt_local)
    # Delay in whole time steps; the observation window covers it plus two rows.
    delay_steps = int(round(cfg["tau"] / dt_local))
    window = delay_steps + 2
    # Features per window row: [t, S, I, C, R, I_tau, C_tau].
    input_dim = 7

    total_episodes = int(cfg["ppo_episodes"])
    # number of parallel episodes per rollout/update
    num_envs = int(cfg["ppo_batch"])

    agent = WindowHCVPPO(input_dim, cfg["ppo_hidden"]).to(DEVICE)
    with torch.no_grad():
        # Shift the initial policy mean towards sigmoid(1) ~= 0.73.
        agent.actor_mu.bias.data[:] = 1.0
    opt = optim.Adam(agent.parameters(), lr=cfg["ppo_lr"])

    gamma_rl = cfg["ppo_gamma_rl"]
    gae_lambda = cfg["ppo_gae_lambda"]

    episodes_done = 0
    update_idx = 0

    while episodes_done < total_episodes:
        update_idx += 1

        # The last batch may be smaller so exactly `total_episodes` are simulated.
        batch_envs = min(num_envs, total_episodes - episodes_done)
        episodes_done += batch_envs

        S = torch.full((batch_envs, 1), cfg["S0"], device=DEVICE, dtype=cfg["dtype"])
        I = torch.full((batch_envs, 1), cfg["I0"], device=DEVICE, dtype=cfg["dtype"])
        C = torch.full((batch_envs, 1), cfg["C0"], device=DEVICE, dtype=cfg["dtype"])
        R = torch.full((batch_envs, 1), cfg["R0"], device=DEVICE, dtype=cfg["dtype"])

        # Delay buffers are pre-filled with delay_steps + 1 copies of the initial
        # values, so the delayed term stays at I0 / C0 for the first
        # delay_steps + 1 steps.
        hist_I = [I.clone() for _ in range(delay_steps + 1)]
        hist_C = [C.clone() for _ in range(delay_steps + 1)]
        # The observation window starts as all zeros.
        win = torch.zeros(batch_envs, window, input_dim, device=DEVICE, dtype=cfg["dtype"])

        buf_W, buf_raw, buf_lp, buf_v, buf_r = [], [], [], [], []
        # Sum over steps of the batch-mean reward, i.e. the mean undiscounted
        # episode return of this rollout batch.
        ep_reward = 0.0

        agent.eval()
        with torch.no_grad():
            for n in range(cfg["N"]):
                # Controls are sampled in raw (pre-sigmoid) space; u = sigmoid(raw).
                u, logp, v, raw = agent.act(win, deterministic=False)
                u1, u2 = u[:, 0:1], u[:, 1:2]

                # Values delay_steps steps before the current (last) history entry.
                I_tau = hist_I[-(delay_steps + 1)]
                C_tau = hist_C[-(delay_steps + 1)]

                dW = torch.randn(batch_envs, 4, device=DEVICE, dtype=cfg["dtype"]) * sqrt_dt

                # Euler-Maruyama step with multiplicative noise eta_i * X_i * dW_i.
                # The tuple right-hand side is evaluated completely before any
                # assignment. Every drift and diffusion therefore uses the
                # pre-step state, so the infection flow leaving S equals the
                # flow entering I.
                # NOTE(review): states are not clamped at 0 here; confirm this
                # matches step_state_stoch used by the evaluation rollout.
                infection = cfg["beta"] * S * (I_tau + C_tau) * (1 - u1)
                S, I, C, R = (
                    S + (cfg["Lambda"] - infection - cfg["alpha"] * S) * dt_local
                    + cfg["eta"][0] * S * dW[:, 0:1],
                    I + (infection - (cfg["alpha"] + cfg["gamma"] + u2) * I) * dt_local
                    + cfg["eta"][1] * I * dW[:, 1:2],
                    C + ((cfg["p"] * cfg["gamma"]) * I - (cfg["alpha"] + cfg["mu"]) * C) * dt_local
                    + cfg["eta"][2] * C * dW[:, 2:3],
                    R + (((1 - cfg["p"]) * cfg["gamma"] + u2) * I - cfg["alpha"] * R) * dt_local
                    + cfg["eta"][3] * R * dW[:, 3:4],
                )

                hist_I.append(I)
                hist_C.append(C)

                # squeeze(-1) keeps shape (batch_envs,) even when batch_envs == 1.
                cost = (cfg["k1"] * I + cfg["k2"] * C + 0.5 * cfg["k3"] * u1 ** 2
                        + 0.5 * cfg["k4"] * u2 ** 2).squeeze(-1)
                reward = -cost * dt_local

                buf_W.append(win.clone())
                buf_raw.append(raw.clone())
                buf_lp.append(logp.clone())
                buf_v.append(v.clone())
                buf_r.append(reward.clone())

                ep_reward += reward.mean().item()

                if n < cfg["N"] - 1:
                    # Append the post-step observation (time t_{n+1}). The lagged
                    # entries are the ones that were used during step n.
                    new = torch.cat([
                        torch.full((batch_envs, 1, 1), (n + 1) * dt_local, device=DEVICE, dtype=cfg["dtype"]),
                        S.unsqueeze(1), I.unsqueeze(1), C.unsqueeze(1), R.unsqueeze(1),
                        I_tau.unsqueeze(1), C_tau.unsqueeze(1)
                    ], dim=2)
                    win = torch.cat([win[:, 1:], new], dim=1)

        # -----------------------------
        # GAE (terminal bootstrap = 0)
        # -----------------------------
        returns = []
        gae = torch.zeros(batch_envs, device=DEVICE, dtype=cfg["dtype"])
        v_next = torch.zeros_like(gae)  # value after the final step is 0
        for t_ in reversed(range(cfg["N"])):
            v_t = buf_v[t_].squeeze(-1)
            delta = buf_r[t_] + gamma_rl * v_next - v_t
            gae = delta + gamma_rl * gae_lambda * gae
            returns.append((gae + v_t).unsqueeze(-1))
            v_next = v_t
        returns.reverse()  # back to chronological order

        # -----------------------------
        # Flatten buffers (time-major: row index = n * batch_envs + env)
        # -----------------------------
        W = torch.stack(buf_W).reshape(-1, window, input_dim).detach()
        RAW = torch.stack(buf_raw).reshape(-1, 2).detach()
        OLD_LP = torch.stack(buf_lp).reshape(-1, 1).detach()
        RET = torch.stack(returns).reshape(-1, 1).detach()
        OLD_V = torch.stack(buf_v).reshape(-1, 1).detach()

        ADV = RET - OLD_V
        # Normalise advantages. std() of a single sample is NaN, so skip that case.
        if ADV.numel() > 1:
            ADV = (ADV - ADV.mean()) / (ADV.std() + 1e-8)

        # -----------------------------
        # PPO Update
        # -----------------------------
        agent.train()
        data_size = W.shape[0]
        mb = min(int(cfg["ppo_minibatch"]), data_size)

        for _ in range(int(cfg["ppo_epochs"])):
            perm = torch.randperm(data_size, device=DEVICE)
            for start in range(0, data_size, mb):
                idx = perm[start:start + mb]

                mu, std, v = agent.forward(W[idx])
                dist = torch.distributions.Normal(mu, std)

                # Ratio computed in raw (pre-sigmoid) space; the squashing
                # Jacobian would cancel in the ratio anyway.
                newlp = dist.log_prob(RAW[idx]).sum(dim=1, keepdim=True)
                ratio = torch.exp(newlp - OLD_LP[idx])

                surr1 = ratio * ADV[idx]
                surr2 = torch.clamp(ratio, 1 - cfg["ppo_clip_eps"], 1 + cfg["ppo_clip_eps"]) * ADV[idx]

                # Clipped surrogate + 0.5 * value MSE - entropy bonus.
                loss = (-torch.min(surr1, surr2).mean()
                        + 0.5 * (v - RET[idx]).pow(2).mean()
                        - cfg["ppo_entropy_coef"] * dist.entropy().mean())

                opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), cfg["ppo_max_grad_norm"])
                opt.step()

        # -----------------------------
        # Logging
        # -----------------------------
        if update_idx % 200 == 0:
            print(f"[PPO] update {update_idx} | episodes_done={episodes_done}/{total_episodes} | rollout_mean_reward={ep_reward:.4f}")

    return agent


@torch.no_grad()
def rollout_ppo_on_noise(cfg: Dict, agent, noise_data):
    """Roll out a trained PPO agent deterministically on a fixed noise path.

    The agent acts with its deterministic control ``sigmoid(mu)`` on a sliding
    observation window that starts as zeros. After step k, a row
    ``[t_{k+1}, S, I, C, R, I_lag, C_lag]`` is appended (except after the last
    step). The lagged values are taken from
    ``history[max(0, k - cfg["K_delay"])]``. The state is advanced by
    ``step_state_stoch`` using ``noise_data[k]``.

    Returns:
        ``(I, C, u1, u2)`` numpy arrays. I and C have N + 1 entries (initial
        value first); u1 and u2 have N entries.

    NOTE(review): the window length uses ``round(tau / dt)`` while the lag
    index uses ``cfg["K_delay"]``; ppo_train uses ``round(tau / dt)`` for
    both. Confirm that the two agree.
    """
    agent.eval()
    dt = cfg["dt"]
    seq_len = int(round(cfg["tau"] / dt)) + 2
    n_features = 7
    opts = {"device": DEVICE, "dtype": cfg["dtype"]}

    eta_vec = torch.tensor(cfg["eta"], **opts)
    state = torch.tensor([[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]], **opts)
    history = [state]
    obs = torch.zeros(1, seq_len, n_features, **opts)

    I_path, C_path = [cfg["I0"]], [cfg["C0"]]
    u1_path, u2_path = [], []
    last_step = cfg["N"] - 1

    for k in range(cfg["N"]):
        action, _ = agent.act(obs, deterministic=True)
        u1_path.append(float(action[0, 0].item()))
        u2_path.append(float(action[0, 1].item()))

        state = step_state_stoch(cfg, state, torch.stack(history), action, noise_data[k], eta_vec)
        history.append(state)
        I_path.append(float(state[0, 1].item()))
        C_path.append(float(state[0, 2].item()))

        lagged = history[max(0, k - cfg["K_delay"])]
        if k == last_step:
            continue  # no further decision, so no new observation row
        row = [(k + 1) * dt, *(float(x) for x in state[0, :4]),
               float(lagged[0, 1]), float(lagged[0, 2])]
        obs = torch.cat([obs[:, 1:], torch.tensor([row], **opts).unsqueeze(1)], dim=1)

    return np.array(I_path), np.array(C_path), np.array(u1_path), np.array(u2_path)


# ============================================================
# 10) Rollout benchmark fixed controls on same noise
# ============================================================
@torch.no_grad()
def rollout_fixed_controls(cfg: Dict, u1_arr, u2_arr, noise_data):
    """Replay a fixed (open-loop) control schedule on a given noise path.

    Args:
        cfg: configuration dict. Reads S0, I0, C0, R0, eta, N and dtype, plus
            whatever ``step_state_stoch`` needs.
        u1_arr, u2_arr: indexable control sequences with at least ``cfg["N"]``
            entries; entry k is applied on step k.
        noise_data: indexable noise; ``noise_data[k]`` is passed to
            ``step_state_stoch`` on step k.

    Returns:
        ``(I, C, u1, u2)`` numpy arrays. I and C have N + 1 entries (initial
        value first). u1 and u2 have N entries holding the controls as applied,
        i.e. after conversion to ``cfg["dtype"]``.
    """
    opts = {"device": DEVICE, "dtype": cfg["dtype"]}
    eta_vec = torch.tensor(cfg["eta"], **opts)
    state = torch.tensor([[cfg["S0"], cfg["I0"], cfg["C0"], cfg["R0"]]], **opts)
    history = [state]
    I_path, C_path = [cfg["I0"]], [cfg["C0"]]
    u1_path, u2_path = [], []

    for k in range(cfg["N"]):
        control = torch.tensor([[float(u1_arr[k]), float(u2_arr[k])]], **opts)
        state = step_state_stoch(cfg, state, torch.stack(history), control, noise_data[k], eta_vec)
        history.append(state)

        I_path.append(float(state[0, 1].item()))
        C_path.append(float(state[0, 2].item()))
        # Record the controls as actually applied (after the dtype cast).
        u1_path.append(float(control[0, 0].item()))
        u2_path.append(float(control[0, 1].item()))

    return np.array(I_path), np.array(C_path), np.array(u1_path), np.array(u2_path)


# ============================================================
# 11) Evaluation
# ============================================================
def evaluate_models(cfg: Dict) -> Dict:
    """Train every method and roll each one out on one shared noise path.

    Stages run in a fixed order, because each one draws from the global RNG
    seeded once by ``set_seed``:
      1. adjoint FBSM benchmark;
      2. LSTM-DPO (PG-DPO) policy training;
      3. stepwise-MPC rollout;
      4. pure-network rollout;
      5. PPO training and rollout;
      6. deep ABSDE training and rollout.
    Finally, the benchmark controls are replayed on the same noise, and
    cumulative cost curves are computed for every method.

    Returns:
        Dict with:
          * ``common_noise``: the shared noise tensor;
          * ``t_grid`` / ``t_u``: the N+1 and N point time grids;
          * ``gt_u1`` / ``gt_u2``: benchmark controls truncated to ``cfg["N"]``;
          * one sub-dict per method (``bench``, ``net``, ``mpc``, ``ppo``,
            ``absde``) with keys ``I``, ``C``, ``u1``, ``u2``, ``J``.
    """
    set_seed(cfg["seed"])

    # One noise path shared by every rollout, so all methods see the same randomness.
    shared_noise = torch.randn(cfg["N"], 1, 4, device=DEVICE, dtype=cfg["dtype"])
    t_grid = np.linspace(0, cfg["T"], cfg["N"] + 1)

    def as_record(path):
        # Unpack an (I, C, u1, u2) rollout tuple into a labelled dict.
        I_path, C_path, u1_path, u2_path = path
        return {"I": I_path, "C": C_path, "u1": u1_path, "u2": u2_path}

    records = {}

    print("\n[Validation] 1. Benchmark (Adjoint)")
    u1_bench, u2_bench, _, _, _, _ = solve_fbsm_benchmark(cfg, shared_noise, verbose=True)

    print("\n[Validation] 2. LSTM-DPO Training")
    policy = LSTMPolicy(hidden=cfg["pg_hidden"]).to(DEVICE)
    train_pg_dpo(cfg, policy)

    print("\n[Validation] 3. PGDPO (Stepwise MPC)")
    records["mpc"] = as_record(run_stepwise_mpc_simulation(
        cfg, policy, shared_noise, num_mc=cfg["mpc_num_mc"], seed_base=cfg["mpc_seed_base"]
    ))

    print("\n[Validation] 4. LSTM-DPO (Pure rollout)")
    records["net"] = as_record(rollout_policy_on_noise(cfg, policy, shared_noise))

    print("\n[Validation] 5. PPO")
    records["ppo"] = as_record(rollout_ppo_on_noise(cfg, ppo_train(cfg), shared_noise))

    print("\n[Validation] 6. Deep ABSDE (FBSDE-Net)")
    records["absde"] = as_record(rollout_fbsde_on_noise(cfg, train_fbsde(cfg), shared_noise))

    # Replay the benchmark's open-loop controls on the shared noise path.
    records["bench"] = as_record(rollout_fixed_controls(cfg, u1_bench, u2_bench, shared_noise))

    # Cumulative cost curves for every method.
    for key in ("bench", "mpc", "net", "ppo", "absde"):
        rec = records[key]
        rec["J"] = cumulative_objective_np(cfg, rec["I"], rec["C"], rec["u1"], rec["u2"])

    # Controls are length N, benchmark arrays are length N+1; align to N
    return {
        "common_noise": shared_noise,
        "t_grid": t_grid,
        "t_u": t_grid[:-1],
        "gt_u1": u1_bench[:cfg["N"]],
        "gt_u2": u2_bench[:cfg["N"]],
        **{key: records[key] for key in ("bench", "net", "mpc", "ppo", "absde")},
    }

def plot_results(cfg: Dict, res: Dict, plot_ppo: bool = True, plot_absde: bool = True) -> None:
    """Draw and save the five comparison figures built from ``evaluate_models`` output.

    Figures, in order: u1, u2, cumulative cost J, infected I, chronic C.
    Control and cost panels use the grid ``res["t_u"]``; state panels use
    ``res["t_grid"]``. The reference curve is the benchmark (``gt_u1`` /
    ``gt_u2`` for the controls, ``res["bench"]`` for J, I and C). The learned
    methods are overlaid with fixed line styles.

    Args:
        cfg: configuration dict (not read here).
        res: result dict returned by ``evaluate_models``.
        plot_ppo: forwarded to ``plot_series`` as ``ppo=``.
        plot_absde: forwarded to ``plot_series`` as ``absde=``.

    NOTE(review): the PDF file names do not depend on ``plot_ppo`` /
    ``plot_absde``. Unless ``plot_series`` derives distinct names from those
    flags, successive calls overwrite each other's files.
    """
    t_grid, t_u = res["t_grid"], res["t_u"]
    # (legend label, result key, matplotlib format); the tuple order fixes the legend order.
    methods = (
        ("LSTM-DPO", "net", "b:"),
        ("DEEP ABSDE", "absde", "g-."),
        ("PGDPO", "mpc", "r--"),
        ("PPO", "ppo", "m-"),
    )
    unit_range = (-0.1, 1.1)
    # (quantity key, x grid, reference curve, title, y label, pdf name, y limits)
    panels = (
        ("u1", t_u, res["gt_u1"], "Control u1 (Vaccination)", "Vaccination",
         "Optimal_Vaccination_u1.pdf", unit_range),
        ("u2", t_u, res["gt_u2"], "Control u2 (Treatment)", "Treatment",
         "Optimal_Vaccination_u2.pdf", unit_range),
        ("J", t_u, res["bench"]["J"], "Cumulative Cost J", "Cumulative Cost",
         "Optimal_Vaccination_J.pdf", None),
        ("I", t_grid, res["bench"]["I"], "Infected I Trajectory", "Infected I",
         "Optimal_Vaccination_I.pdf", None),
        ("C", t_grid, res["bench"]["C"], "Chronic C Trajectory", "Chronic C",
         "Optimal_Vaccination_C.pdf", None),
    )

    for key, x_vals, reference, title, ylabel, pdf_name, ylim in panels:
        plot_series(
            x_vals, reference,
            series={label: (res[src][key], fmt) for label, src, fmt in methods},
            title=title,
            xlabel="Time",
            ylabel=ylabel,
            filename_pdf=pdf_name,
            ylim=ylim,
            ppo=plot_ppo,
            absde=plot_absde,
        )


# ============================================================
# 12) Entry point
# ============================================================
if __name__ == "__main__":
    results = evaluate_models(CFG)

    # Two figure sets from the same results: first with PPO hidden and ABSDE
    # shown, then with PPO shown and ABSDE hidden.
    # NOTE(review): plot_results uses fixed PDF file names. Unless plot_series
    # varies them by flag, the second pass overwrites the first.
    for show_ppo, show_absde in ((False, True), (True, False)):
        plot_results(CFG, results, plot_ppo=show_ppo, plot_absde=show_absde)
