# In [None]:
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm


# ============================================================
# 0. Matplotlib / Font Settings (Type-3 font prevention)
# ============================================================
# Global figure style. Font type 42 makes the PDF/PS backends embed TrueType
# fonts, which is what keeps Type-3 fonts out of the exported figures.
matplotlib.rcParams.update(
    {
        # font embedding for vector output
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        # typeface selection
        "font.family": "serif",
        "font.serif": ["Liberation Serif", "FreeSerif", "serif"],
        "mathtext.fontset": "stix",
        # text sizes
        "font.size": 10,
        "axes.labelsize": 10,
        "legend.fontsize": 9,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    }
)


# ============================================================
# 1. Unified Configuration
# ============================================================
# Preferred CUDA device index. It is only used when the machine really exposes
# that many GPUs; otherwise the first GPU is used so that `.to(DEVICE)` does not
# fail with an invalid device ordinal.
_PREFERRED_CUDA_INDEX = 2

if torch.cuda.is_available():
    _cuda_index = (
        _PREFERRED_CUDA_INDEX
        if torch.cuda.device_count() > _PREFERRED_CUDA_INDEX
        else 0
    )
    _DEFAULT_DEVICE = f"cuda:{_cuda_index}"
else:
    _DEFAULT_DEVICE = "cpu"

# Single configuration dictionary shared by every algorithm in this file.
CFG: Dict = {
    # reproducibility / device / dtype
    "seed": 42,
    "device": _DEFAULT_DEVICE,
    "dtype": torch.float32,

    # dynamics / cost
    # drift is a0*x + b0*v + b1*v_delayed, diffusion is sigma,
    # running cost is R*v^2 and terminal cost is S0*x
    "a0": -0.5,
    "b0": 1.0,
    "b1": 0.8,
    "sigma": 0.2,
    "R": 1.0,
    "S0": -10.0,

    # time grid (delta is the control delay, in the same units as T)
    "T": 2.0,
    "N_steps": 20,
    "delta": 1.0,

    # initial condition
    "x0": 0.0,

    # Algo 1 (Warmup / LSTM-DPO)
    "batch_size": 256,
    "warmup_iters": 10000,
    "lr": 1e-4,

    # Algo 2 (MV-FABSDE)
    "fbsde_iters": 10000,
    "fbsde_lr": 5e-4,
    "fbsde_decay_step": 10000,

    # Stage2 projection (P-PGDPO projection)
    "N_mc_stage2": 8,
    "N_branch": 10,
    "N_compare": 51,

    # PPO
    "ppo_hidden": 64,
    "ppo_optimizer_lr": 5e-4,
    "ppo_iters": 1000,
    "ppo_epochs": 10,
    "ppo_gamma_rl": 1.0,
    "ppo_gae_lambda": 0.95,
    "ppo_clip_eps": 0.2,
    "ppo_minibatch_size": 512,
    "ppo_entropy_coef": 0.01,
    "ppo_value_coef": 0.5,
    "ppo_max_grad_norm": 0.5,

    # PPO observation / reward scaling
    "ppo_x_norm": 2.0,
    "ppo_v_norm": 5.0,
    "ppo_reward_scale": 0.1,
}

# Device object used by all tensors/models below, and the global default dtype.
DEVICE = torch.device(CFG["device"])
torch.set_default_dtype(CFG["dtype"])


# ============================================================
# 2. Utilities
# ============================================================
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators with ``seed``.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def dt(cfg: Dict) -> float:
    """Return the step size ``cfg["T"] / cfg["N_steps"]`` of the uniform time grid."""
    horizon, n_steps = cfg["T"], cfg["N_steps"]
    return horizon / n_steps


def delay_steps(cfg: Dict) -> int:
    """Return the delay ``cfg["delta"]`` as a number of grid steps, never below 1."""
    n_delay = int(round(cfg["delta"] / dt(cfg)))
    if n_delay < 1:
        return 1
    return n_delay


def make_time_norm(n: int, cfg: Dict, batch: int, device: torch.device) -> torch.Tensor:
    """Return a ``(batch,)`` tensor filled with the normalised time of grid index ``n``.

    The value is ``(n * dt) / T``.
    """
    elapsed = n * dt(cfg)
    fraction = elapsed / cfg["T"]
    return torch.full((batch,), fraction, device=device)


def drift_x(x: torch.Tensor, v: torch.Tensor, v_delayed: torch.Tensor, cfg: Dict) -> torch.Tensor:
    """Return the state drift ``a0 * x + b0 * v + b1 * v_delayed``.

    The three coefficients are read from ``cfg``.
    """
    state_term = cfg["a0"] * x
    control_term = cfg["b0"] * v
    delayed_term = cfg["b1"] * v_delayed
    return state_term + control_term + delayed_term


def get_v_delayed(v_hist: torch.Tensor, n: int, D: int, device: torch.device) -> torch.Tensor:
    """Return the control applied ``D`` steps before step ``n``.

    ``v_hist`` is indexed as ``v_hist[:, step]``. If step ``n - D`` lies before
    the start of the grid, a zero tensor with one entry per row of ``v_hist``
    is returned instead.
    """
    lagged_index = n - D
    if lagged_index < 0:
        # No control has been applied yet at that time.
        return torch.zeros(v_hist.size(0), device=device, dtype=v_hist.dtype)
    return v_hist[:, lagged_index]


# ============================================================
# 3. Models
# ============================================================
class LSTMPolicy(nn.Module):
    """Recurrent feedback policy producing a control from ``(t_norm, x_t)``.

    An ``nn.LSTMCell`` consumes the stacked input at each step; a linear head
    followed by softplus turns the new hidden state into the control value.
    """

    def __init__(self, input_size: int = 2, hidden_size: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.head = nn.Linear(hidden_size, 1)
        # Overwrite the randomly initialised head bias with a fixed 0.5.
        with torch.no_grad():
            self.head.bias.fill_(0.5)

    def init_hidden(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero ``(h, c)`` states, each of shape ``(batch_size, hidden_size)``."""
        state_shape = (batch_size, self.hidden_size)
        hidden = torch.zeros(state_shape, device=device)
        cell = torch.zeros(state_shape, device=device)
        return hidden, cell

    def forward_step(
        self, t_norm: torch.Tensor, x_t: torch.Tensor, h: torch.Tensor, c: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance the cell one step and return ``(control, h_next, c_next)``."""
        features = torch.stack((t_norm, x_t), dim=-1)
        new_h, new_c = self.lstm(features, (h, c))
        pre_activation = self.head(new_h).squeeze(-1)
        control = nn.functional.softplus(pre_activation)
        return control, new_h, new_c


class FBSDENet(nn.Module):
    """Recurrent network with three scalar heads (``Y``, ``Z``, ``EY``).

    An ``nn.LSTMCell`` consumes the stacked ``(t_norm, x_t)`` input; each head is
    a separate linear map of the new hidden state.
    """

    def __init__(self, input_size: int = 2, hidden_size: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTMCell(input_size, hidden_size)
        self.head_Y = nn.Linear(hidden_size, 1)
        self.head_Z = nn.Linear(hidden_size, 1)
        self.head_EY = nn.Linear(hidden_size, 1)
        # Overwrite the randomly initialised Y-head bias with a fixed -5.0.
        with torch.no_grad():
            self.head_Y.bias.fill_(-5.0)

    def init_hidden(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero ``(h, c)`` states, each of shape ``(batch_size, hidden_size)``."""
        state_shape = (batch_size, self.hidden_size)
        hidden = torch.zeros(state_shape, device=device)
        cell = torch.zeros(state_shape, device=device)
        return hidden, cell

    def forward_step(
        self, t_norm: torch.Tensor, x_t: torch.Tensor, h: torch.Tensor, c: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance the cell one step and return ``(y, z, ey, h_next, c_next)``."""
        features = torch.stack((t_norm, x_t), dim=-1)
        new_h, new_c = self.lstm(features, (h, c))
        y_out = self.head_Y(new_h).squeeze(-1)
        z_out = self.head_Z(new_h).squeeze(-1)
        ey_out = self.head_EY(new_h).squeeze(-1)
        return y_out, z_out, ey_out, new_h, new_c


class WindowPPOAgent(nn.Module):
    """Actor-critic network that reads a window of past observations.

    An ``nn.LSTM`` (batch-first) processes the window; the output at the last
    window position feeds a linear actor-mean head and a linear critic head.
    The actor's log standard deviation is a single learnable parameter.
    """

    def __init__(self, input_dim: int = 3, hidden_dim: int = 64, window_size: int = 10):
        super().__init__()
        self.window_size = window_size
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.actor_mu = nn.Linear(hidden_dim, 1)
        self.actor_log_std = nn.Parameter(torch.zeros(1))
        self.critic = nn.Linear(hidden_dim, 1)

        nn.init.constant_(self.actor_mu.bias, 1.5)
        # Orthogonal init for every parameter whose name contains "weight".
        weight_params = [p for name, p in self.named_parameters() if "weight" in name]
        for p in weight_params:
            nn.init.orthogonal_(p, gain=1.0)

    @staticmethod
    def _to_action(raw: torch.Tensor) -> torch.Tensor:
        """Map a raw (unconstrained) action to the applied action: softplus + 1e-6."""
        return torch.nn.functional.softplus(raw) + 1e-6

    def forward(self, x_window: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mu, std, value)`` computed from the last LSTM output of the window."""
        sequence_out = self.lstm(x_window)[0]
        summary = sequence_out[:, -1, :]
        mu = self.actor_mu(summary)
        value = self.critic(summary)
        std = self.actor_log_std.exp().expand_as(mu)
        return mu, std, value

    def get_action(self, x_window: torch.Tensor, deterministic: bool = False):
        """Choose an action for ``x_window``.

        Returns ``(action, value)`` when ``deterministic`` is true (the raw action
        is the mean). Otherwise the raw action is sampled from ``Normal(mu, std)``
        and ``(action, log_prob, value, raw_action)`` is returned, where
        ``log_prob`` is the Normal log-density of the raw action.
        """
        mu, std, value = self.forward(x_window)
        if not deterministic:
            policy = torch.distributions.Normal(mu, std)
            raw_action = policy.sample()
            applied = self._to_action(raw_action)
            return applied, policy.log_prob(raw_action), value, raw_action
        return self._to_action(mu), value


# ============================================================
# 4. Algorithm 1: Warmup (LSTM-DPO)
# ============================================================
def simulate_batch_episode_pg(policy: LSTMPolicy, cfg: Dict, detach_policy: bool = False) -> torch.Tensor:
    """Simulate one batch of noisy episodes under ``policy`` and return the mean cost.

    Each of the ``cfg["batch_size"]`` paths starts at ``cfg["x0"]`` and is advanced
    with ``x += drift * dt + sigma * dB``. The per-path cost is the accumulated
    ``R * v**2 * dt`` plus the terminal term ``S0 * x``.

    Args:
        policy: recurrent policy queried once per step.
        cfg: configuration dictionary.
        detach_policy: when true, the policy is evaluated with gradients disabled.

    Returns:
        Scalar tensor holding the batch-averaged cost.
    """
    n_steps = cfg["N_steps"]
    lag = delay_steps(cfg)
    n_paths = cfg["batch_size"]
    step = dt(cfg)
    noise_scale = math.sqrt(step)

    state = torch.full((n_paths,), cfg["x0"], device=DEVICE)
    controls = torch.zeros(n_paths, n_steps + 1, device=DEVICE)
    hidden, cell = policy.init_hidden(n_paths, DEVICE)
    running_cost = torch.zeros(n_paths, device=DEVICE)

    for k in range(n_steps):
        t_norm = make_time_norm(k, cfg, n_paths, DEVICE)
        # Gradients stay as they are unless detach_policy asks to switch them off.
        track_grad = torch.is_grad_enabled() and not detach_policy
        with torch.set_grad_enabled(track_grad):
            v, hidden, cell = policy.forward_step(t_norm, state, hidden, cell)

        controls[:, k] = v
        v_lagged = get_v_delayed(controls, k, lag, DEVICE)

        dB = noise_scale * torch.randn(n_paths, device=DEVICE)
        state = state + drift_x(state, v, v_lagged, cfg) * step + cfg["sigma"] * dB
        running_cost = running_cost + cfg["R"] * v**2 * step

    total_cost = running_cost + cfg["S0"] * state
    return total_cost.mean()


def warmup_train(policy: LSTMPolicy, cfg: Dict) -> LSTMPolicy:
    """Train ``policy`` by gradient descent on the simulated mean cost.

    Runs ``cfg["warmup_iters"]`` Adam steps (learning rate ``cfg["lr"]``), each on a
    freshly simulated batch, and prints the objective every 200 iterations.
    The policy is moved to ``DEVICE``, put in train mode and returned.
    """
    policy.to(DEVICE)
    policy.train()
    optimizer = optim.Adam(policy.parameters(), lr=cfg["lr"])

    log_every = 200
    n_iters = cfg["warmup_iters"]
    for it in range(n_iters):
        optimizer.zero_grad(set_to_none=True)
        J = simulate_batch_episode_pg(policy, cfg, detach_policy=False)
        J.backward()
        optimizer.step()
        if (it + 1) % log_every != 0:
            continue
        print(f"[Warmup] iter={it+1}, J={J.item():.4f}")

    return policy


def rollout_policy_pg_deterministic(policy: LSTMPolicy, cfg: Dict) -> Dict[str, List[torch.Tensor]]:
    """Roll out ``policy`` on a single noise-free path and record its trajectory.

    The policy is switched to eval mode and run without gradients. Before each
    step the current state, LSTM states and control history are snapshotted.

    Returns:
        Dict with lists ``"xs"`` (N + 1 states, including the final one), ``"vs"``
        (N controls), ``"hs"`` / ``"cs"`` (LSTM states before each step) and
        ``"v_hists"`` (control-history buffer before each step).
    """
    horizon = cfg["T"]
    n_steps, lag = cfg["N_steps"], delay_steps(cfg)
    step = dt(cfg)
    policy.eval()

    state = torch.full((1,), cfg["x0"], device=DEVICE)
    controls = torch.zeros(1, n_steps + 1, device=DEVICE)
    hidden, cell = policy.init_hidden(1, DEVICE)

    trace: Dict[str, List[torch.Tensor]] = {"xs": [], "vs": [], "hs": [], "cs": [], "v_hists": []}
    with torch.no_grad():
        for k in range(n_steps):
            # Snapshot everything needed to restart a simulation from step k.
            trace["xs"].append(state.clone())
            trace["hs"].append(hidden.clone())
            trace["cs"].append(cell.clone())
            trace["v_hists"].append(controls.clone())

            t_norm = torch.full((1,), (k * step) / horizon, device=DEVICE)
            v, hidden, cell = policy.forward_step(t_norm, state, hidden, cell)
            trace["vs"].append(v.clone())

            controls[:, k] = v
            v_lagged = get_v_delayed(controls, k, lag, DEVICE)
            state = state + drift_x(state, v, v_lagged, cfg) * step

        trace["xs"].append(state.clone())

    return trace


# ============================================================
# 5. Algorithm 2: MV-FABSDE
# ============================================================
def train_fbsde(net: FBSDENet, cfg: Dict) -> FBSDENet:
    """Train ``net`` on simulated paths with a two-part mean-squared-error loss.

    Each iteration simulates ``cfg["batch_size"]`` paths over ``N_steps`` steps.
    The network is queried at every grid time and yields ``(Y, Z, EY)``. The
    control is ``relu(-(b0 * Y + b1 * EY) / (2 * R))``, with ``EY`` replaced by
    zero once ``i > N - D``. The loss is ``L1 + L2``:

    * ``L1`` matches ``Y`` at step ``i + 1`` to the one-step propagated value
      ``tildeY[i + 1]`` and matches ``Y`` at step ``N`` to ``S0``.
    * ``L2`` matches ``EY`` at step ``i`` to ``tildeY[i + D]``.

    Args:
        net: network to train; moved to ``DEVICE`` and put in train mode.
        cfg: configuration dictionary.

    Returns:
        The same ``net`` object after ``cfg["fbsde_iters"]`` optimisation steps.
    """
    net.to(DEVICE)
    net.train()

    # Adam with a step schedule: the learning rate is multiplied by 0.1 every
    # cfg["fbsde_decay_step"] scheduler steps.
    opt = optim.Adam(net.parameters(), lr=cfg["fbsde_lr"])
    scheduler = optim.lr_scheduler.StepLR(opt, step_size=cfg["fbsde_decay_step"], gamma=0.1)
    mse = nn.MSELoss()

    N, D = cfg["N_steps"], delay_steps(cfg)
    batch = cfg["batch_size"]
    T = cfg["T"]

    for it in range(cfg["fbsde_iters"]):
        # Noise increments for the whole batch, scaled by sqrt(dt).
        dB = torch.randn(batch, N, device=DEVICE) * math.sqrt(dt(cfg))

        x = torch.full((batch,), cfg["x0"], device=DEVICE)
        v_hist = torch.zeros(batch, N + 1, device=DEVICE)
        h, c = net.init_hidden(batch, DEVICE)

        # Per-step network outputs and propagated targets (index 0..N).
        # tildeY[0] is never filled; only indices 1..N are used below.
        Y_pred = [None] * (N + 1)
        Z_pred = [None] * (N + 1)
        EY_pred = [None] * (N + 1)
        tildeY = [None] * (N + 1)

        # Network outputs at the initial time (normalised time 0).
        t0 = torch.zeros(batch, device=DEVICE)
        y, z, ey, h, c = net.forward_step(t0, x, h, c)
        Y_pred[0], Z_pred[0], EY_pred[0] = y, z, ey

        for i in range(N):
            y_i, z_i, ey_i = Y_pred[i], Z_pred[i], EY_pred[i]
            # The EY term is dropped once step i + D would lie past the horizon.
            ey_used = ey_i if i <= N - D else torch.zeros_like(ey_i)

            # Control from the current Y / EY predictions, clipped at zero.
            v_val = - (cfg["b0"] * y_i + cfg["b1"] * ey_used) / (2.0 * cfg["R"])
            v = torch.relu(v_val)
            v_hist[:, i] = v

            # Forward state step using the current and the delayed control.
            v_del = get_v_delayed(v_hist, i, D, DEVICE)
            x = x + drift_x(x, v, v_del, cfg) * dt(cfg) + cfg["sigma"] * dB[:, i]

            # One-step propagated value of Y; it is not detached, so gradients
            # flow through both sides of the matching losses below.
            tildeY[i + 1] = y_i - cfg["a0"] * y_i * dt(cfg) + z_i * dB[:, i]

            # Network outputs at the next grid time and the updated state.
            t_next = torch.full((batch,), ((i + 1) * dt(cfg)) / T, device=DEVICE)
            y, z, ey, h, c = net.forward_step(t_next, x, h, c)
            Y_pred[i + 1], Z_pred[i + 1], EY_pred[i + 1] = y, z, ey

        # L1: consistency of Y with its propagated value, plus Y_N = S0.
        L1 = 0.0
        for i in range(N):
            L1 = L1 + mse(Y_pred[i + 1], tildeY[i + 1])
        L1 = L1 + mse(Y_pred[N], torch.full((batch,), cfg["S0"], device=DEVICE))

        # L2: EY at step i should match the propagated Y at step i + D.
        L2 = 0.0
        for i in range(0, N - D + 1):
            L2 = L2 + mse(EY_pred[i], tildeY[i + D])

        # Normalise by the number of steps (L1 is divided by N although it
        # sums N + 1 terms).
        L1 = L1 / max(1, N)
        L2 = L2 / max(1, N - D + 1)
        loss = L1 + L2

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        scheduler.step()

        if (it + 1) % 2000 == 0:
            print(f"[Alg2 MV-FABSDE] iter={it+1}, L1={L1.item():.6f}, L2={L2.item():.6f}, total={loss.item():.6f}")

    return net


def rollout_fbsde_deterministic(net: FBSDENet, cfg: Dict) -> List[torch.Tensor]:
    """Roll out the control implied by ``net`` on one noise-free path.

    The network is put in eval mode and run without gradients. At each step the
    control is ``relu(-(b0 * y + b1 * ey) / (2 * R))``, where ``ey`` is replaced by
    zero once the step index exceeds ``N - D``.

    Returns:
        List of ``N`` control tensors, one per step.
    """
    net.eval()

    n_steps, lag = cfg["N_steps"], delay_steps(cfg)
    horizon = cfg["T"]
    step = dt(cfg)
    last_ey_step = n_steps - lag

    state = torch.full((1,), cfg["x0"], device=DEVICE)
    controls = torch.zeros(1, n_steps + 1, device=DEVICE)
    hidden, cell = net.init_hidden(1, DEVICE)

    vs = []
    with torch.no_grad():
        for k in range(n_steps):
            t_norm = torch.full((1,), (k * step) / horizon, device=DEVICE)
            y, _, ey, hidden, cell = net.forward_step(t_norm, state, hidden, cell)
            if k > last_ey_step:
                ey = torch.zeros_like(ey)
            raw_control = -(cfg["b0"] * y + cfg["b1"] * ey) / (2.0 * cfg["R"])
            v = torch.relu(raw_control)

            vs.append(v.clone())
            controls[:, k] = v
            v_lagged = get_v_delayed(controls, k, lag, DEVICE)
            state = state + drift_x(state, v, v_lagged, cfg) * step

    return vs


# ============================================================
# 6. Algorithm 3: PPO
# ============================================================
def train_ppo_window(cfg: Dict) -> WindowPPOAgent:
    """Train a ``WindowPPOAgent`` with a clipped-surrogate PPO loop.

    Each iteration (1) rolls out ``cfg["batch_size"]`` noisy episodes with the
    current agent, (2) computes GAE-based returns from the stored value
    predictions, and (3) runs ``cfg["ppo_epochs"]`` passes of minibatch updates on
    the collected transitions.

    The observation is a window of ``D + 2`` rows of
    ``(t / T, x / ppo_x_norm, v / ppo_v_norm)``. The per-step reward is
    ``-R * v**2 * dt``, plus ``-S0 * x`` on the last step, multiplied by
    ``ppo_reward_scale`` before being stored.

    Args:
        cfg: configuration dictionary.

    Returns:
        The trained agent (on ``DEVICE``).
    """
    N = cfg["N_steps"]
    D = delay_steps(cfg)
    # Observation window length: the delay plus two rows.
    window = D + 2

    x_norm = cfg["ppo_x_norm"]
    v_norm = cfg["ppo_v_norm"]
    reward_scale = cfg["ppo_reward_scale"]

    agent = WindowPPOAgent(input_dim=3, hidden_dim=cfg["ppo_hidden"], window_size=window).to(DEVICE)

    optimizer = optim.Adam(agent.parameters(), lr=cfg["ppo_optimizer_lr"])

    ppo_iters = cfg["ppo_iters"]
    ppo_epochs = cfg["ppo_epochs"]
    gamma_rl = cfg["ppo_gamma_rl"]
    lam_gae = cfg["ppo_gae_lambda"]
    clip_eps = cfg["ppo_clip_eps"]
    minibatch_size = cfg["ppo_minibatch_size"]

    for it in range(ppo_iters):
        # ---- rollout buffers: one entry per time step ----
        batch_windows = []
        batch_raw_acts = []
        batch_log_probs = []
        batch_vals = []
        batch_rews = []

        B = cfg["batch_size"]
        x = torch.full((B, 1), cfg["x0"], device=DEVICE)
        v_hist = torch.zeros(B, N + 1, device=DEVICE)

        # Initial window: every row is (t=0, x0 / x_norm, v=0).
        w = torch.zeros(B, window, 3, device=DEVICE)
        w[:, :, 0] = 0.0
        w[:, :, 1] = cfg["x0"] / x_norm
        w[:, :, 2] = 0.0

        # Unscaled episode reward, used only for logging.
        episode_rew_sum = torch.zeros(B, device=DEVICE)

        for n in range(N):
            # Sample an action without tracking gradients.
            with torch.no_grad():
                v, log_prob, val, raw_act = agent.get_action(w, deterministic=False)

            batch_windows.append(w.clone())
            batch_raw_acts.append(raw_act)
            batch_log_probs.append(log_prob)
            batch_vals.append(val)

            v_flat = v.squeeze(-1)
            v_hist[:, n] = v_flat
            v_del = get_v_delayed(v_hist, n, D, DEVICE)

            # Noisy state step with the current and the delayed control.
            dW = math.sqrt(dt(cfg)) * torch.randn(B, device=DEVICE)
            x_next = x.squeeze(-1) + drift_x(x.squeeze(-1), v_flat, v_del, cfg) * dt(cfg) + cfg["sigma"] * dW
            x = x_next.unsqueeze(-1)

            # Reward is the negative running cost; the negative terminal cost
            # is added on the last step.
            step_reward = -cfg["R"] * (v_flat**2) * dt(cfg)
            if n == N - 1:
                step_reward = step_reward + (-cfg["S0"] * x_next)
            batch_rews.append(step_reward * reward_scale)
            episode_rew_sum += step_reward

            # Slide the window: drop the oldest row, append the new observation.
            if n < N - 1:
                next_t_norm = ((n + 1) * dt(cfg)) / cfg["T"]
                new_in = torch.cat(
                    [
                        torch.full((B, 1, 1), next_t_norm, device=DEVICE),
                        (x / x_norm).unsqueeze(1),
                        (v / v_norm).unsqueeze(1),
                    ],
                    dim=2,
                )
                w = torch.cat([w[:, 1:, :], new_in], dim=1)

        # ---- GAE, computed backwards in time ----
        returns = []
        gae = torch.zeros(B, device=DEVICE)
        for t in reversed(range(N)):
            curr_val = batch_vals[t].squeeze(-1)
            if t == N - 1:
                # Last step: no bootstrap value.
                next_val = torch.zeros_like(curr_val)
                nonterm = 0.0
            else:
                next_val = batch_vals[t + 1].squeeze(-1)
                nonterm = 1.0

            delta = batch_rews[t] + gamma_rl * next_val * nonterm - curr_val
            gae = delta + gamma_rl * lam_gae * nonterm * gae
            returns.insert(0, gae + curr_val)

        # ---- flatten (time, batch) into one dataset dimension ----
        b_windows = torch.stack(batch_windows).view(-1, window, 3)
        b_raw_acts = torch.stack(batch_raw_acts).view(-1, 1)
        b_log_probs = torch.stack(batch_log_probs).view(-1, 1)
        b_rets = torch.stack(returns).view(-1, 1)
        b_vals = torch.stack(batch_vals).view(-1, 1)

        # Advantages, standardised over the whole dataset.
        adv = b_rets - b_vals
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

        dataset = b_windows.size(0)
        # NOTE: the permutation is drawn once per iteration, so every epoch
        # visits the same minibatch partition.
        perm = torch.randperm(dataset, device=DEVICE)

        for _ in range(ppo_epochs):
            for start in range(0, dataset, minibatch_size):
                idx = perm[start: start + minibatch_size]

                # Re-evaluate the stored raw actions under the current policy.
                mu, std, val = agent.forward(b_windows[idx])
                dist = torch.distributions.Normal(mu, std)
                new_lp = dist.log_prob(b_raw_acts[idx])
                entropy = dist.entropy().mean()

                # Probability ratio and clipped surrogate objective.
                ratio = torch.exp(new_lp - b_log_probs[idx])
                surr1 = ratio * adv[idx]
                surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv[idx]

                # Policy loss + weighted value loss - weighted entropy bonus.
                loss = (
                    -torch.min(surr1, surr2).mean()
                    + cfg["ppo_value_coef"] * ((val - b_rets[idx]) ** 2).mean()
                    - cfg["ppo_entropy_coef"] * entropy
                )

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(agent.parameters(), cfg["ppo_max_grad_norm"])
                optimizer.step()

        if (it + 1) % 100 == 0:
            print(f"[PPO] iter={it+1}, AvgReward={episode_rew_sum.mean().item():.4f}")

    return agent


def rollout_ppo_deterministic(agent: WindowPPOAgent, cfg: Dict) -> np.ndarray:
    """Roll out the deterministic PPO action on one noise-free path.

    The agent is put in eval mode and queried without gradients. The observation
    window holds rows of ``(t / T, x / ppo_x_norm, v / ppo_v_norm)`` and is
    shifted by one row after every step except the last.

    Returns:
        Array with the ``N`` controls applied along the path.
    """
    agent.eval()

    n_steps, lag = cfg["N_steps"], delay_steps(cfg)
    window = lag + 2
    x_norm, v_norm = cfg["ppo_x_norm"], cfg["ppo_v_norm"]
    step = dt(cfg)

    state = torch.full((1, 1), cfg["x0"], device=DEVICE)
    controls = torch.zeros(1, n_steps + 1, device=DEVICE)

    # Initial window: time and control columns stay zero, state column is x0 / x_norm.
    obs = torch.zeros(1, window, 3, device=DEVICE)
    obs[:, :, 1] = cfg["x0"] / x_norm

    applied = []
    with torch.no_grad():
        for k in range(n_steps):
            v, _ = agent.get_action(obs, deterministic=True)
            v_scalar = v.item()
            applied.append(v_scalar)

            controls[:, k] = v_scalar
            v_lagged = get_v_delayed(controls, k, lag, DEVICE)

            state = state + drift_x(state, v, v_lagged, cfg) * step

            if k >= n_steps - 1:
                continue
            next_t_norm = ((k + 1) * step) / cfg["T"]
            newest_row = torch.cat(
                [
                    torch.full((1, 1, 1), next_t_norm, device=DEVICE),
                    (state / x_norm).unsqueeze(1),
                    (v / v_norm).unsqueeze(1),
                ],
                dim=2,
            )
            obs = torch.cat([obs[:, 1:, :], newest_row], dim=1)

    return np.array(applied)


# ============================================================
# 7. Benchmark
# ============================================================
def analytic_solution(cfg: Dict) -> Tuple[np.ndarray, np.ndarray]:
    """Return the benchmark control on the uniform time grid.

    With ``y(t) = S0 * exp(a0 * (T - t))``, the control at grid index ``i`` is
    ``max(0, -(b0 * y[i] + b1 * y[i + D]) / (2 * R))``; the delayed term is
    dropped when ``i + D`` lies beyond the last grid index.

    Returns:
        ``(t_grid, u)``, both of length ``N_steps + 1``.
    """
    horizon, n_steps = cfg["T"], cfg["N_steps"]
    lag = delay_steps(cfg)

    t_grid = np.linspace(0, horizon, n_steps + 1)
    y = cfg["S0"] * np.exp(cfg["a0"] * (horizon - t_grid))

    u = np.zeros_like(y)
    for i, y_now in enumerate(y):
        ahead = i + lag
        direct_term = cfg["b0"] * y_now
        delayed_term = cfg["b1"] * y[ahead] if ahead <= n_steps else 0.0
        candidate = -(direct_term + delayed_term) / (2.0 * cfg["R"])
        # Clip at zero: the control cannot be negative.
        u[i] = candidate if candidate > 0.0 else 0.0

    return t_grid, u


def mae_rmse(y_hat: np.ndarray, y_ref: np.ndarray) -> Tuple[float, float]:
    """Return ``(MAE, RMSE)`` between ``y_hat`` and ``y_ref`` after flattening both."""
    predicted = np.asarray(y_hat).reshape(-1)
    reference = np.asarray(y_ref).reshape(-1)
    residual = predicted - reference
    mean_abs = np.abs(residual).mean()
    root_mean_sq = np.sqrt((residual**2).mean())
    return float(mean_abs), float(root_mean_sq)


# ============================================================
# 8. P-PGDPO Projection
# ============================================================
def estimate_costate_mc(
    policy: LSTMPolicy,
    cfg: Dict,
    n0: int,
    x0_val: torch.Tensor,
    h0: torch.Tensor,
    c0: torch.Tensor,
    v_hist0: torch.Tensor,
    M: int,
) -> torch.Tensor:
    """Estimate the gradient of the cost-to-go with respect to the state at step ``n0``.

    For each of ``M`` samples, one noisy path is simulated from step ``n0`` to the
    horizon under ``policy`` (in eval mode), starting from the given state, LSTM
    states and control history. The remaining cost (running ``R * v**2 * dt`` plus
    terminal ``S0 * x``) is differentiated with respect to the starting state by
    autograd.

    Returns:
        The average of the ``M`` per-path gradients (same shape as ``x0_val``).
    """
    policy.eval()

    n_steps, lag = cfg["N_steps"], delay_steps(cfg)
    horizon = cfg["T"]
    step = dt(cfg)
    noise_scale = math.sqrt(step)

    def _pathwise_gradient() -> torch.Tensor:
        # One simulated path; returns d(cost)/d(x_start), detached.
        x_start = x0_val.clone().detach().requires_grad_(True)
        state = x_start

        hidden = h0.clone().detach()
        cell = c0.clone().detach()
        controls = v_hist0.clone().detach()

        cost = torch.zeros(1, device=DEVICE)
        for k in range(n0, n_steps):
            t_norm = torch.full((1,), (k * step) / horizon, device=DEVICE)
            v, hidden, cell = policy.forward_step(t_norm, state, hidden, cell)

            controls[:, k] = v
            v_lagged = get_v_delayed(controls, k, lag, DEVICE)

            dB = noise_scale * torch.randn(1, device=DEVICE)
            state = state + drift_x(state, v, v_lagged, cfg) * step + cfg["sigma"] * dB
            cost = cost + cfg["R"] * v**2 * step

        cost = cost + cfg["S0"] * state
        (grad_x,) = torch.autograd.grad(cost, x_start, retain_graph=False, create_graph=False)
        return grad_x.detach()

    samples = [_pathwise_gradient() for _ in range(M)]
    return torch.stack(samples, dim=0).mean(dim=0)


def project_p_pgdpo(
    policy_pg: LSTMPolicy,
    cfg: Dict,
    t_grid: np.ndarray,
    rollout_pg: Dict[str, List[torch.Tensor]],
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the projected control at a subset of grid steps.

    At each selected step ``n`` the control is
    ``max(0, -(b0 * lam_n + b1 * lam_ahead) / (2 * R))`` where ``lam_n`` is the
    Monte-Carlo costate estimate at the recorded rollout state, and ``lam_ahead``
    refers to step ``n + D``:

    * ``S0`` when ``n + D`` equals the last step, ``0`` when it lies beyond it;
    * otherwise the mean over ``cfg["N_branch"]`` noisy branches, each simulated
      for ``D`` steps from the recorded state and followed by a costate estimate.

    Returns:
        ``(t_cmp, v_pmp)``: the selected times and the projected controls.
    """
    n_steps = cfg["N_steps"]
    lag = delay_steps(cfg)
    step = dt(cfg)
    n_compare = cfg["N_compare"]
    n_mc = cfg["N_mc_stage2"]

    # Unique grid indices, spread (approximately) evenly over 0 .. N - 1.
    selected = sorted({int(round(i * (n_steps - 1) / (n_compare - 1))) for i in range(n_compare)})

    xs, hs, cs = rollout_pg["xs"], rollout_pg["hs"], rollout_pg["cs"]
    v_hists = rollout_pg["v_hists"]

    times = []
    projected = []

    for n in selected:
        lam_now = estimate_costate_mc(
            policy_pg, cfg, n, xs[n], hs[n], cs[n], v_hists[n], M=n_mc
        ).item()

        arrival = n + lag
        if arrival == n_steps:
            lam_ahead = cfg["S0"]
        elif arrival > n_steps:
            lam_ahead = 0.0
        else:
            branch_estimates = []
            for _ in range(cfg["N_branch"]):
                x_b = xs[n].clone()
                h_b = hs[n].clone()
                c_b = cs[n].clone()
                controls_b = v_hists[n].clone()

                # Simulate D noisy steps forward from the recorded state.
                for k in range(n, arrival):
                    t_k_norm = torch.full((1,), (k * step) / cfg["T"], device=DEVICE)
                    with torch.no_grad():
                        v, h_b, c_b = policy_pg.forward_step(t_k_norm, x_b, h_b, c_b)

                    controls_b[:, k] = v
                    v_lagged = get_v_delayed(controls_b, k, lag, DEVICE)
                    dB = math.sqrt(step) * torch.randn(1, device=DEVICE)
                    x_b = x_b + drift_x(x_b, v, v_lagged, cfg) * step + cfg["sigma"] * dB

                branch_estimates.append(
                    estimate_costate_mc(
                        policy_pg, cfg, arrival, x_b, h_b, c_b, controls_b, M=n_mc
                    ).item()
                )

            lam_ahead = float(np.mean(branch_estimates))

        candidate = -(cfg["b0"] * lam_now + cfg["b1"] * lam_ahead) / (2.0 * cfg["R"])
        projected.append(max(0.0, candidate))
        times.append(t_grid[n])

    return np.array(times), np.array(projected)


# ============================================================
# 9. Plotting
# ============================================================
def plot_controls(
    t_grid: np.ndarray,
    gt: np.ndarray,
    c_pg: np.ndarray,
    c_fbsde: np.ndarray,
    c_ppo: np.ndarray,
    c_proj: np.ndarray,
    ppo: bool = True,
    fbsde: bool = True,
) -> None:
    """Plot the benchmark control against the learned controls and save a PDF.

    Args:
        t_grid: time values shared by all curves.
        gt: benchmark control.
        c_pg: LSTM-DPO control.
        c_fbsde: FBSDE control; drawn only when ``fbsde`` is true.
        c_ppo: PPO control; drawn only when ``ppo`` is true.
        c_proj: projected (PGDPO) control.
        ppo: include the PPO curve; also selects the output file name
            (``Benchmark2_control_ppo.pdf`` vs ``Benchmark2_control.pdf``).
        fbsde: include the FBSDE curve.
    """
    plt.figure(figsize=(7, 5))
    plt.plot(t_grid, gt, "k-", lw=2, label="Benchmark")

    plt.plot(t_grid, c_pg, "b:", label="LSTM-DPO")
    if fbsde:
        plt.plot(t_grid, c_fbsde, "g-.", label="DEEP ABSDE")
    if ppo:
        plt.plot(t_grid, c_ppo, "m-", label="PPO")
    plt.plot(t_grid, c_proj, "r--", label="PGDPO")
    plt.xlabel("Time")
    plt.ylabel("Advertising Expenditure")
    plt.legend()
    plt.grid(alpha=0.1)

    # Adjust the layout BEFORE saving so the saved file and the displayed
    # figure use the same layout.
    plt.tight_layout()

    out_name = "Benchmark2_control_ppo.pdf" if ppo else "Benchmark2_control.pdf"
    plt.savefig(out_name, bbox_inches="tight")
    plt.show()


# ============================================================
# 10. Metrics + Plot driver
# ============================================================
def compute_metrics_and_plot(
    policy_pg: LSTMPolicy,
    fbsde_net: Optional[FBSDENet],
    policy_ppo: WindowPPOAgent,
    cfg: Dict,
) -> None:
    """Evaluate every method against the benchmark, print RMSE/MAE and plot.

    Args:
        policy_pg: trained LSTM-DPO policy.
        fbsde_net: trained FBSDE network, or ``None`` to skip that method. When
            ``None``, no FBSDE rollout is run, no FBSDE metric line is printed and
            the FBSDE curve is omitted from the plots.
        policy_ppo: trained PPO agent.
        cfg: configuration dictionary.
    """
    has_fbsde = fbsde_net is not None

    policy_pg.eval()
    if has_fbsde:
        fbsde_net.eval()
    policy_ppo.eval()

    # --- benchmark ---
    t_grid_full, v_star_full = analytic_solution(cfg)
    # Controls exist for steps 0 .. N-1 only, so drop the final grid point.
    t_plot = t_grid_full[:-1]
    v_star = v_star_full[:-1]

    # --- Algo 1 rollout (deterministic) ---
    rollout_pg = rollout_policy_pg_deterministic(policy_pg, cfg)
    v_pg = np.array([v.item() for v in rollout_pg["vs"]])  # (N,)

    # --- Algo 2 rollout (optional) ---
    if has_fbsde:
        v_fbsde_list = rollout_fbsde_deterministic(fbsde_net, cfg)
        v_fbsde = np.array([v.item() for v in v_fbsde_list])  # (N,)
    else:
        # Placeholder of matching shape; it is never plotted or scored.
        v_fbsde = np.full_like(v_pg, np.nan)

    # --- Algo 3 PPO rollout (deterministic) ---
    v_ppo = rollout_ppo_deterministic(policy_ppo, cfg)  # (N,)

    # --- Projection (P-PGDPO) ---
    t_cmp, v_pmp = project_p_pgdpo(policy_pg, cfg, t_grid_full, rollout_pg)

    # Bring every curve onto the projection's comparison times.
    v_star_cmp = np.interp(t_cmp, t_grid_full, v_star_full)
    v_pg_cmp = np.interp(t_cmp, t_grid_full[:-1], v_pg)
    v_ppo_cmp = np.interp(t_cmp, t_grid_full[:-1], v_ppo)

    mae_pg, rmse_pg = mae_rmse(v_pg_cmp, v_star_cmp)
    mae_ppo, rmse_ppo = mae_rmse(v_ppo_cmp, v_star_cmp)
    mae_pmp, rmse_pmp = mae_rmse(v_pmp, v_star_cmp)

    print(f"RMSE & MAE (Algo 1 PG)      : {rmse_pg:.6f}, {mae_pg:.6f}")
    if has_fbsde:
        v_fbsde_cmp = np.interp(t_cmp, t_grid_full[:-1], v_fbsde)
        mae_f, rmse_f = mae_rmse(v_fbsde_cmp, v_star_cmp)
        print(f"RMSE & MAE (Algo 2 FBSDE)   : {rmse_f:.6f}, {mae_f:.6f}")
    print(f"RMSE & MAE (Algo 3 PPO)     : {rmse_ppo:.6f}, {mae_ppo:.6f}")
    print(f"RMSE & MAE (P-PGDPO)        : {rmse_pmp:.6f}, {mae_pmp:.6f}")

    # --- plots: projection interpolated back onto the N control times ---
    v_pmp_on_N = np.interp(t_plot, t_cmp, v_pmp)

    plot_controls(t_plot, v_star, v_pg, v_fbsde, v_ppo, v_pmp_on_N, ppo=True, fbsde=False)
    plot_controls(t_plot, v_star, v_pg, v_fbsde, v_ppo, v_pmp_on_N, ppo=False, fbsde=has_fbsde)



# ============================================================
# 11. Main
# ============================================================
def main() -> None:
    """Seed the RNGs, train the three methods in sequence, then evaluate and plot."""
    set_seed(CFG["seed"])
    hidden = CFG["ppo_hidden"]

    # Algo 1: LSTM-DPO warm-up
    policy_pg = warmup_train(LSTMPolicy(hidden_size=hidden), CFG)

    # Algo 2: deep FBSDE solver
    fbsde_net = train_fbsde(FBSDENet(hidden_size=hidden), CFG)

    # Algo 3: PPO
    policy_ppo = train_ppo_window(CFG)

    compute_metrics_and_plot(policy_pg, fbsde_net, policy_ppo, CFG)



if __name__ == "__main__":
    main()
