In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.linalg as linalg
import scipy.stats as st
from scipy.stats import ks_2samp

from matplotlib import rc
# Serif fonts, with all text rendered through LaTeX (needs a LaTeX install).
rc('font', family='serif')
rc('text', usetex=True)


# Print NumPy arrays with up to 180 characters per line.
np.set_printoptions(linewidth=180)

# Plot style (matched to weak/strong figures): one size for titles,
# axis labels and tick labels alike.
TITLE_FONTSIZE = 14
LABEL_FONTSIZE = 14
TICK_LABELSIZE = 14


def style_axes(ax, xlabel=None, ylabel=None, title=None):
    """Apply the shared font sizes to ``ax`` and return it.

    The title, x label and y label are set (in that order) only when the
    corresponding argument is not None; tick label sizes are always set.
    """
    requested = (
        (title, ax.set_title, TITLE_FONTSIZE),
        (xlabel, ax.set_xlabel, LABEL_FONTSIZE),
        (ylabel, ax.set_ylabel, LABEL_FONTSIZE),
    )
    for text, setter, size in requested:
        if text is None:
            continue
        setter(text, fontsize=size)
    ax.tick_params(labelsize=TICK_LABELSIZE)
    return ax

# Legend label and colour for each Markov-chain scheme, shared by every
# figure so that schemes look the same across plots.
SCHEME_STYLES = dict(
    nodal={'label': 'Nodal (CTMC)', 'color': 'C0'},
    bin={'label': 'Bin-integrated', 'color': 'C1'},
)
# Scheme names in insertion order: ['nodal', 'bin'].
SCHEMES = [*SCHEME_STYLES]


In [None]:
class OU1D:
    """
    1D OU process + Markov-chain approximation on [-L, L]
    with absorbing edges, *coupled* via shared Gaussian increments.

    SDE: dX = -X/tau dt + sig dW

    Parameters
    ----------
    dt : float
        Time step of both the Euler--Maruyama scheme and the chain.
    T_max : float
        Final time; paths have floor(T_max / dt) + 1 time points.
    tau : float
        Relaxation time (the drift is -x / tau).
    sig : float
        Diffusion coefficient.
    L : float
        Half-width of the absorbing domain [-L, L].
    h : float
        Spacing of the chain's state grid.
    scheme : {'nodal', 'bin'}
        'nodal': one-step kernel expm(dt * Q) of a finite-difference generator Q.
        'bin':   Euler--Maruyama Gaussian kernel integrated over cells around
                 the grid nodes.
    seed : int
        Passed to np.random.seed at the start of every simulate() call.

    Raises
    ------
    ValueError
        If scheme is neither 'nodal' nor 'bin'.
    """

    def __init__(self, dt=0.05, T_max=5.0, tau=1.0, sig=0.05,
                 L=0.1, h=0.01, scheme='nodal', seed=5):
        self.dt    = dt
        self.T_max = T_max
        self.tau   = tau
        self.sig   = sig
        self.L     = L
        self.h     = h
        self.seed  = seed
        self.scheme = scheme

        if scheme == "nodal":
            self.x_states, Q = self._make_Q1d_nodal()
            self.P1d = linalg.expm(self.dt * Q)
        elif scheme == "bin":
            self.x_states, self.P1d = self._make_P1d_bin()
        else:
            raise ValueError("scheme must be 'nodal' or 'bin'")

        # Row-wise CDFs for inverse-CDF sampling (rows: from-state, columns:
        # to-state). Rows may sum to less than 1: the missing mass
        # 1 - mass1d[i] is the one-step absorption probability from node i.
        self.cdf1d  = np.cumsum(self.P1d, axis=1)
        self.mass1d = self.cdf1d[:, -1]

    def _make_grid(self):
        """Return the symmetric node grid {k * h : |k * h| <= L} (up to rounding).

        Built from integer multiples of h rather than np.arange(-L, L + h, h).
        With a float step, whenever L / h is not an integer, arange yields an
        asymmetric grid: its last node lies beyond L and 0 is not a node. The
        bin scheme's cell edges then stop being increasing, which produces
        negative transition probabilities.
        """
        # Small tolerance so that a ratio like 11.999999999999998 gives K = 12.
        K = int(np.floor(self.L / self.h + 1e-9))
        return self.h * np.arange(-K, K + 1)

    def _make_Q1d_nodal(self):
        """Finite-difference CTMC generator on the node grid; returns (x, Q).

        Jump rates from node x_i are sig**2/(2 h**2) - x_i/(2 tau h) to the right
        neighbour and sig**2/(2 h**2) + x_i/(2 tau h) to the left one; the
        diagonal is -sig**2/h**2. Interior rows therefore sum to 0. Each end row
        lacks one neighbour, and the missing rate acts as the absorption rate.

        NOTE: these rates are negative when |x_i| > sig**2 * tau / h (central
        differencing of the drift). Q is then not a valid generator, and
        expm(dt * Q) may contain negative entries.
        """
        h   = self.h
        tau = self.tau
        sig = self.sig

        x = self._make_grid()

        diag     = -sig**2 / h**2 + 0 * x
        diag_inf = 0.5 * sig**2 / h**2 + x / (tau * 2 * h)   # rate i -> i-1
        diag_sup = 0.5 * sig**2 / h**2 - x / (tau * 2 * h)   # rate i -> i+1

        Q = np.diag(diag)
        Q += np.diag(diag_sup[:-1],  1)
        Q += np.diag(diag_inf[1:],  -1)

        return x, Q

    def _make_P1d_bin(self):
        """Bin-integrated EM kernel on [-L, L] (missing mass = absorption).

        States are the grid nodes x_i. Bins are the Voronoi cells around the
        nodes, with edges at midpoints and truncated to [-L, L]. P[i, j] is the
        probability that one Euler--Maruyama step from x_i lands in cell j. That
        step is Gaussian with mean x_i - dt * x_i / tau and std sig * sqrt(dt).
        """
        dt  = self.dt
        tau = self.tau
        L   = self.L

        x = self._make_grid()

        # Cell edges: [-L, midpoints..., L]; strictly increasing since |x_i| <= L.
        edges = np.empty(x.size + 1)
        edges[0] = -L
        edges[-1] = L
        edges[1:-1] = 0.5 * (x[:-1] + x[1:])

        # All rows at once: row i holds Phi((edges - m_i) / s).
        mean = x + dt * (-x / tau)
        std = self.sig * np.sqrt(dt)
        cdf_edges = st.norm.cdf((edges[None, :] - mean[:, None]) / std)
        P = cdf_edges[:, 1:] - cdf_edges[:, :-1]

        return x, P

    def simulate(self, N_rep=1000, x0=0.0):
        """
        Simulate N_rep coupled paths of:
          - the 1D SDE by Euler--Maruyama (killed when |X|>L)
          - the 1D Markov chain on the grid (killed by missing mass)

        Coupling uses a shared Z ~ N(0,1) per step and U=Phi(Z).

        Each call reseeds the global NumPy RNG with self.seed. Instances with the
        same seed, N_rep and number of steps therefore see the same Gaussian
        draws, whatever their scheme.

        Returns
        -------
        T : array (N_T,)
            Times k * dt.
        X_sde, X_mc : arrays (N_rep, N_T)
            Paths, NaN from the killing step on. The chain starts at the grid
            node closest to x0.
        """
        np.random.seed(self.seed)

        # Floor with a small tolerance: plain int() would truncate a ratio such
        # as 0.3 / 0.1 = 2.9999999999999996 to 2 and silently drop a step.
        N_T = int(np.floor(self.T_max / self.dt + 1e-9)) + 1
        T = np.arange(N_T) * self.dt

        X_sde = np.zeros((N_rep, N_T))
        X_mc  = np.zeros((N_rep, N_T))

        alive_sde = np.ones(N_rep, dtype=bool)
        alive_mc  = np.ones(N_rep, dtype=bool)

        X_sde[:, 0] = x0
        i0 = np.argmin(np.abs(self.x_states - x0))
        ix = np.full(N_rep, i0, dtype=int)
        X_mc[:, 0] = self.x_states[i0]

        for k in range(1, N_T):
            Z = np.random.randn(N_rep)

            # SDE (Euler--Maruyama); NaN propagates once a path is killed.
            X_prev = X_sde[:, k-1].copy()
            X_new  = X_prev + self.dt * (-X_prev / self.tau) + self.sig * np.sqrt(self.dt) * Z
            inside = np.abs(X_new) <= self.L
            alive_sde = alive_sde & inside
            X_new[~alive_sde] = np.nan
            X_sde[:, k] = X_new

            # MC (inverse-CDF sampling with absorption), driven by the same Z.
            U = st.norm.cdf(Z)
            for r in range(N_rep):
                if not alive_mc[r]:
                    X_mc[r, k] = np.nan
                    continue

                cdf = self.cdf1d[ix[r]]
                mass = cdf[-1]
                # U falls in the missing mass => absorbed this step.
                if U[r] > mass:
                    alive_mc[r] = False
                    X_mc[r, k] = np.nan
                    continue

                # First index j with cdf[j] >= U (assumes a non-decreasing row).
                new_ix = np.searchsorted(cdf, U[r])
                ix[r] = new_ix
                X_mc[r, k] = self.x_states[new_ix]

        return T, X_sde, X_mc


def plot_coupled_paths_1d(T, X_sde, X_mc, x_states=None, title=None):
    """Quick 1D sanity plot of one EM path and its coupled chain path vs time.

    When ``x_states`` is given, a faint horizontal line is drawn at every chain
    state. Returns (fig, ax).
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    for path, name in ((X_sde, 'SDE (EM)'), (X_mc, 'MCA')):
        ax.plot(T, path, '-o', ms=2, label=name)

    levels = () if x_states is None else x_states
    for level in levels:
        ax.axhline(level, alpha=0.10, lw=0.5, color='k')

    style_axes(ax, xlabel=r'$t$', ylabel=r'$x$', title=title)
    ax.legend()
    fig.tight_layout()
    return fig, ax


def plot_coupled_paths_2d(X_sde, X_mc, x_states, L, title_left=None, title_right=None):
    """Quick 2D sanity plot of one EM trajectory and its coupled chain trajectory.

    Columns 0 and 1 of ``X_sde`` / ``X_mc`` are plotted as (x, y) in the box
    [-L, L]^2, with faint lines at the chain states on both axes.
    ``title_right`` becomes the axis title; ``title_left`` is accepted but not
    used. Returns (fig, ax).
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    for path, name in ((X_sde, 'SDE (EM)'), (X_mc, 'MCA')):
        ax.plot(path[:, 0], path[:, 1], '-o', ms=2, label=name)
    ax.set_xlim(-L, L)
    ax.set_ylim(-L, L)
    ax.set_aspect('equal', 'box')

    for level in x_states:
        ax.axhline(level, alpha=0.10, lw=0.5, color='k')
        ax.axvline(level, alpha=0.10, lw=0.5, color='k')

    style_axes(ax, xlabel='x', ylabel='y', title=title_right)
    ax.legend()

    fig.tight_layout()
    return fig, ax


class OU2D:
    """
    2D OU process + Markov-chain approximation on a square [-L, L]^2
    with absorbing edges, *coupled* via shared Gaussian increments.

    SDE: dX = -X/tau dt + sig dW_x
         dY = -Y/tau dt + sig dW_y

    The coordinates evolve independently, so the 2D chain is the product of
    two copies of one 1D chain (same grid, same one-step kernel).

    Parameters
    ----------
    dt : float
        Time step of both the Euler--Maruyama scheme and the chain.
    T_max : float
        Final time; paths have floor(T_max / dt) + 1 time points.
    tau : float
        Relaxation time (the drift is -x / tau in each coordinate).
    sig : float
        Diffusion coefficient.
    L : float
        Half-width of the absorbing box [-L, L]^2.
    h : float
        Spacing of the 1D state grid.
    scheme : {'nodal', 'bin'}
        'nodal': one-step kernel expm(dt * Q1d) of a finite-difference generator.
        'bin':   Euler--Maruyama Gaussian kernel integrated over cells around
                 the grid nodes.
    seed : int
        Passed to np.random.seed at the start of every simulate() call.

    Raises
    ------
    ValueError
        If scheme is neither 'nodal' nor 'bin'.
    """

    def __init__(self, dt=0.05, T_max=5.0, tau=1.0, sig=0.05,
                 L=0.1, h=0.01, scheme='nodal', seed=5):
        self.dt    = dt
        self.T_max = T_max
        self.tau   = tau
        self.sig   = sig
        self.L     = L
        self.h     = h
        self.seed  = seed

        # 1D grid and one-step kernel (rows: from-state, columns: to-state).
        self.scheme = scheme
        if scheme == "nodal":
            # CTMC generator on the nodes, exponentiated over one step dt.
            self.x_states, self.Q1d = self._make_Q1d_nodal()
            self.P1d = linalg.expm(self.dt * self.Q1d)
        elif scheme == "bin":
            # Kernel built directly by Gaussian bin-integration.
            self.x_states, self.P1d = self._make_P1d_bin()
        else:
            raise ValueError("scheme must be 'nodal' or 'bin'")

        # CDFs per row for fast sampling, including absorption: the missing
        # mass 1 - mass1d[i] is the one-step absorption probability from node i.
        self.cdf1d  = np.cumsum(self.P1d, axis=1)   # shape (N, N)
        self.mass1d = self.cdf1d[:, -1]             # total mass per row

        # 2D grid (for plotting only)
        Xg, Yg = np.meshgrid(self.x_states, self.x_states, indexing="ij")
        self.Xg = Xg.ravel()
        self.Yg = Yg.ravel()

    def _make_grid(self):
        """Return the symmetric node grid {k * h : |k * h| <= L} (up to rounding).

        Built from integer multiples of h rather than np.arange(-L, L + h, h).
        With a float step, whenever L / h is not an integer, arange yields an
        asymmetric grid: its last node lies beyond L and 0 is not a node. The
        bin scheme's cell edges then stop being increasing, which produces
        negative transition probabilities.
        """
        # Small tolerance so that a ratio like 11.999999999999998 gives K = 12.
        K = int(np.floor(self.L / self.h + 1e-9))
        return self.h * np.arange(-K, K + 1)

    def _make_Q1d_nodal(self):
        """Finite-difference (collocation on nodes) CTMC generator; returns (x, Q).

        Jump rates from node x_i are sig**2/(2 h**2) - x_i/(2 tau h) to the right
        neighbour and sig**2/(2 h**2) + x_i/(2 tau h) to the left one; the
        diagonal is -sig**2/h**2. Interior rows therefore sum to 0. Each end row
        lacks one neighbour, and the missing rate acts as the absorption rate.

        NOTE: these rates are negative when |x_i| > sig**2 * tau / h (central
        differencing of the drift). Q is then not a valid generator, and
        expm(dt * Q) may contain negative entries.
        """
        h   = self.h
        tau = self.tau
        sig = self.sig

        x = self._make_grid()

        diag     = -sig**2 / h**2 + 0 * x
        diag_inf = 0.5 * sig**2 / h**2 + x / (tau * 2 * h)   # rate i -> i-1
        diag_sup = 0.5 * sig**2 / h**2 - x / (tau * 2 * h)   # rate i -> i+1

        Q = np.diag(diag)
        Q += np.diag(diag_sup[:-1],  1)
        Q += np.diag(diag_inf[1:],  -1)

        return x, Q

    def _make_P1d_bin(self):
        """Bin-integrated EM kernel on [-L, L] (missing mass = absorption).

        States are the grid nodes x_i. Bins are the Voronoi cells around the
        nodes, with edges at midpoints and truncated to [-L, L]. P[i, j] is the
        probability that one Euler--Maruyama step from x_i lands in cell j. That
        step is Gaussian with mean x_i - dt * x_i / tau and std sig * sqrt(dt).
        """
        dt  = self.dt
        tau = self.tau
        L   = self.L

        x = self._make_grid()

        # Cell edges: [-L, midpoints..., L]; strictly increasing since |x_i| <= L.
        edges = np.empty(x.size + 1)
        edges[0] = -L
        edges[-1] = L
        edges[1:-1] = 0.5 * (x[:-1] + x[1:])

        # All rows at once: row i holds Phi((edges - m_i) / s).
        mean = x + dt * (-x / tau)
        std = self.sig * np.sqrt(dt)
        cdf_edges = st.norm.cdf((edges[None, :] - mean[:, None]) / std)
        P = cdf_edges[:, 1:] - cdf_edges[:, :-1]

        return x, P

    def simulate(self, N_rep=1000, x0=0.0, y0=0.0):
        """
        Simulate N_rep paths of:
          - the 2D SDE by Euler–Maruyama
          - the 2D Markov chain on the grid

        Coupling:
          - Draw Z ~ N(0, I_2)
          - Use it in the SDE: dW = sqrt(dt)*Z
          - Map each Z component to U via Φ to drive the 1D MC:
                U_x = Φ(Z_x), U_y = Φ(Z_y)
          - 2D transition is product of the two 1D transitions.

        Absorbing edge:
          - SDE: once |X|>L or |Y|>L ⇒ NaN thereafter
          - MC: if U_x or U_y falls in the "absorbing tail"
                of the 1D CDF ⇒ NaN thereafter

        Each call reseeds the global NumPy RNG with self.seed.

        Returns
        -------
        T : array (N_T,)
            Times k * dt.
        X_sde, X_mc : arrays (N_rep, N_T, 2)
            Paths (last axis = (x, y)), NaN from the killing step on. The chain
            starts at the grid point closest to (x0, y0).
        """
        np.random.seed(self.seed)

        # Floor with a small tolerance: plain int() would truncate a ratio such
        # as 0.3 / 0.1 = 2.9999999999999996 to 2 and silently drop a step.
        N_T = int(np.floor(self.T_max / self.dt + 1e-9)) + 1

        # Containers: (rep, time, 2 coords)
        X_sde = np.zeros((N_rep, N_T, 2))
        X_mc  = np.zeros((N_rep, N_T, 2))

        # Alive flags for each process (separate)
        alive_sde = np.ones(N_rep, dtype=bool)
        alive_mc  = np.ones(N_rep, dtype=bool)

        # Initial condition at (x0, y0)
        X_sde[:, 0, 0] = x0
        X_sde[:, 0, 1] = y0

        # Markov initial indices: closest grid point to (x0, y0)
        i0 = np.argmin(np.abs(self.x_states - x0))
        j0 = np.argmin(np.abs(self.x_states - y0))

        ix = np.full(N_rep, i0, dtype=int)
        iy = np.full(N_rep, j0, dtype=int)

        X_mc[:, 0, 0] = self.x_states[i0]
        X_mc[:, 0, 1] = self.x_states[j0]

        for k in range(1, N_T):
            # ====== Shared Gaussian increments ======
            # Z ~ N(0, I_2)
            Z = np.random.randn(N_rep, 2)

            # ====== 2D SDE: Euler–Maruyama with absorbing box ======
            X_prev = X_sde[:, k-1, :].copy()
            dW     = np.sqrt(self.dt) * Z
            X_new  = X_prev + self.dt * (-X_prev / self.tau) + self.sig * dW

            inside = (np.abs(X_new[:, 0]) <= self.L) & (np.abs(X_new[:, 1]) <= self.L)
            alive_sde = alive_sde & inside
            X_new[~alive_sde] = np.nan
            X_sde[:, k, :] = X_new

            # ====== 2D Markov chain: product of 1D kernels, coupled via Z ======
            # U_x, U_y obtained from Z via Φ (component-wise), shape (N_rep, 2)
            U = st.norm.cdf(Z)

            for r in range(N_rep):
                if not alive_mc[r]:
                    X_mc[r, k, :] = np.nan
                    continue

                u_x, u_y = U[r, 0], U[r, 1]

                # --- X dimension ---
                cdf_x  = self.cdf1d[ix[r]]
                mass_x = cdf_x[-1]
                if u_x > mass_x:
                    alive_mc[r] = False
                    X_mc[r, k, :] = np.nan
                    continue
                new_ix = np.searchsorted(cdf_x, u_x)

                # --- Y dimension ---
                cdf_y  = self.cdf1d[iy[r]]
                mass_y = cdf_y[-1]
                if u_y > mass_y:
                    alive_mc[r] = False
                    X_mc[r, k, :] = np.nan
                    continue
                new_iy = np.searchsorted(cdf_y, u_y)

                # Survived in both dims: update indices and position
                ix[r] = new_ix
                iy[r] = new_iy
                X_mc[r, k, 0] = self.x_states[new_ix]
                X_mc[r, k, 1] = self.x_states[new_iy]

        T = np.arange(N_T) * self.dt
        return T, X_sde, X_mc



def indicator_ball(x, center=(0.0, 0.0), radius=0.1):
    """
    Indicator of the closed disc of the given center and radius.

    x : array (N, 2)
    Returns a boolean array (N,) equal to 1_{|x - center| <= radius}.
    """
    cx, cy = center[0], center[1]
    off_x = x[:, 0] - cx
    off_y = x[:, 1] - cy
    squared_distance = off_x * off_x + off_y * off_y
    return squared_distance <= radius * radius


def smooth_indicator_ball(x, center=(0.0, 0.0), radius=0.1, eps=0.02):
    """
    Regularized (logistic) version of the disc indicator.

    phi(x) = 1 / (1 + exp(-(radius**2 - |x - center|**2) / eps)), which is close
    to 1 well inside the disc and close to 0 well outside. The transition band
    has width ~eps in *squared* distance. Returns values in [0, 1].

    x : array (N, 2)
    """
    cx, cy = center[0], center[1]
    off_x = x[:, 0] - cx
    off_y = x[:, 1] - cy
    margin = (radius * radius - (off_x * off_x + off_y * off_y)) / eps
    return 1.0 / (1.0 + np.exp(-margin))

def prob_in_set_for_init(h,
                         x0, y0,
                         set_indicator,      # boolean indicator OR smooth weights in [0, 1]
                         N_rep=5000,
                         T_max=1.0,
                         dt=0.01,
                         tau=1.0,
                         sig=0.05,
                         L=0.3,
                         scheme='nodal',
                         seed=0,
                         smooth=False):
    """
    Return (P_EM, P_MC) for a set A (given by set_indicator), an initial
    condition (x0, y0) and a grid spacing h, at the final time T_max.

    - smooth=False: set_indicator must return booleans (a true indicator).
      P is the fraction of paths whose final state lies in A.
    - smooth=True:  set_indicator returns weights in [0, 1] (an approximate
      indicator). P is the sample mean of phi(X_T), for EM and MC.

    Absorbed paths (NaN final state) count as outside A (weight 0).
    """
    sim = OU2D(dt=dt, T_max=T_max, tau=tau, sig=sig, L=L, h=h, scheme=scheme, seed=seed)
    _, X_sde, X_mc = sim.simulate(N_rep=N_rep, x0=x0, y0=y0)

    weight_dtype = float if smooth else bool

    def _estimate(paths):
        # Mean weight of the final states, absorbed paths contributing 0.
        final = paths[:, -1, :]
        alive = ~np.isnan(final[:, 0])
        weights = np.zeros(N_rep, dtype=weight_dtype)
        weights[alive] = set_indicator(final[alive])
        return weights.mean()

    P_EM = _estimate(X_sde)
    P_MC = _estimate(X_mc)
    return P_EM, P_MC


def weak_error_sup_over_inits(h,
                              init_grid,     # list/array of (x0, y0)
                              set_indicator,
                              N_rep=5000,
                              smooth=False,
                              **kwargs):
    """
    For a given h and a grid of initial conditions (x0, y0), return:
      - max_err = max over (x0, y0) in the grid of |P_EM - P_MC|
      - errors  = array of the error for each initial condition
    Extra keyword arguments are forwarded to prob_in_set_for_init.
    """
    per_init = []
    for x0, y0 in init_grid:
        p_em, p_mc = prob_in_set_for_init(h, x0, y0, set_indicator,
                                          N_rep=N_rep, smooth=smooth, **kwargs)
        per_init.append(abs(p_em - p_mc))
    errors = np.array(per_init)
    return np.max(errors), errors


# ------------------------------------------------------------------
# 1D set indicators (for weak error)
# ------------------------------------------------------------------

def indicator_interval(x, a=-0.1, b=0.1):
    """Elementwise boolean indicator of the closed interval [a, b]."""
    values = np.asarray(x)
    return np.logical_and(values >= a, values <= b)


def smooth_indicator_interval(x, a=-0.1, b=0.1, eps=0.02):
    """Piecewise-linear smoothing of 1_{[a, b]} with ramps of width eps.

    Equals 1 on [a, b]. It rises linearly from 0 to 1 on (a - eps, a), falls
    from 1 to 0 on (b, b + eps), and is 0 elsewhere. Returns a float array with
    the shape of x.
    """
    values = np.asarray(x)
    out = np.zeros_like(values, dtype=float)

    lower = a - eps
    upper = b + eps
    ramp_up = (lower < values) & (values < a)
    plateau = (a <= values) & (values <= b)
    ramp_down = (b < values) & (values < upper)

    out[ramp_up] = (values[ramp_up] - lower) / eps
    out[plateau] = 1.0
    out[ramp_down] = (upper - values[ramp_down]) / eps
    return out


def prob_in_set_for_init_1d(h,
                            x0,
                            set_indicator,
                            N_rep=5000,
                            smooth=False,
                            T_max=1.0,
                            dt=0.01,
                            tau=1.0,
                            sig=0.05,
                            L=0.3,
                            scheme='nodal',
                            seed=0):
    """Return (P_EM, P_MC) for a 1D set event at the final time T_max.

    With smooth=False, set_indicator returns booleans and P is the fraction of
    paths ending in the set. With smooth=True, it returns weights in [0, 1] and
    P is their sample mean. Absorbed paths (NaN final state) count as weight 0.
    """
    sim = OU1D(dt=dt, T_max=T_max, tau=tau, sig=sig, L=L, h=h, scheme=scheme, seed=seed)
    _, X_sde, X_mc = sim.simulate(N_rep=N_rep, x0=x0)

    weight_dtype = float if smooth else bool

    def _estimate(paths):
        # Mean weight of the final states, absorbed paths contributing 0.
        final = paths[:, -1]
        alive = ~np.isnan(final)
        weights = np.zeros(N_rep, dtype=weight_dtype)
        weights[alive] = set_indicator(final[alive])
        return weights.mean()

    p_em = _estimate(X_sde)
    p_mc = _estimate(X_mc)
    return p_em, p_mc


def weak_error_sup_over_inits_1d(h,
                                init_grid,
                                set_indicator,
                                N_rep=5000,
                                smooth=False,
                                **kwargs):
    """For a given h, return (max_{x0 in init_grid} |P_EM - P_MC|, per-x0 errors).

    Extra keyword arguments are forwarded to prob_in_set_for_init_1d.
    """
    per_init = []
    for x0 in init_grid:
        p_em, p_mc = prob_in_set_for_init_1d(h, x0, set_indicator,
                                             N_rep=N_rep, smooth=smooth, **kwargs)
        per_init.append(abs(p_em - p_mc))
    errors = np.array(per_init)
    return np.max(errors), errors


# ------------------------------------------------------------------
# Strong coupling errors (1D/2D)
# ------------------------------------------------------------------

def coupling_sup_error_1d(h,
                          N_rep=1000,
                          T_max=1.0,
                          dt=0.01,
                          tau=1.0,
                          sig=0.05,
                          L=0.3,
                          scheme='nodal',
                          seed=0,
                          x0=0.0):
    """Strong coupling error in 1D.

    For each path, take the sup over time of |MC - EM|, restricted to the times
    at which both processes are still alive. Returns (mean, 0.9-quantile,
    per-path values), skipping paths that are never jointly alive.
    """
    sim = OU1D(dt=dt, T_max=T_max, tau=tau, sig=sig, L=L, h=h, scheme=scheme, seed=seed)
    _, X_sde, X_mc = sim.simulate(N_rep=N_rep, x0=x0)

    both_alive = ~(np.isnan(X_sde) | np.isnan(X_mc))
    has_overlap = both_alive.any(axis=1)
    # -inf at dead times so they never win the max.
    gaps = np.where(both_alive, np.abs(X_sde - X_mc), -np.inf)
    sup_errors = gaps[has_overlap].max(axis=1)

    return sup_errors.mean(), np.quantile(sup_errors, 0.90), sup_errors


def coupling_sup_error_2d(h,
                          N_rep=1000,
                          T_max=1.0,
                          dt=0.01,
                          tau=1.0,
                          sig=0.05,
                          L=0.3,
                          scheme='nodal',
                          seed=0,
                          x0=0.0,
                          y0=0.0):
    """Strong coupling error in 2D.

    For each path, take the sup over time of the Euclidean distance ||MC - EM||,
    restricted to the times at which both processes are still alive (judged on
    the x coordinate). Returns (mean, 0.9-quantile, per-path values), skipping
    paths that are never jointly alive.
    """
    sim = OU2D(dt=dt, T_max=T_max, tau=tau, sig=sig, L=L, h=h, scheme=scheme, seed=seed)
    _, X_sde, X_mc = sim.simulate(N_rep=N_rep, x0=x0, y0=y0)

    both_alive = ~(np.isnan(X_sde[:, :, 0]) | np.isnan(X_mc[:, :, 0]))
    has_overlap = both_alive.any(axis=1)
    distances = np.linalg.norm(X_sde - X_mc, axis=2)
    # -inf at dead times so they never win the max.
    gaps = np.where(both_alive, distances, -np.inf)
    sup_errors = gaps[has_overlap].max(axis=1)

    return sup_errors.mean(), np.quantile(sup_errors, 0.90), sup_errors


# ------------------------------------------------------------------
# Multi-scheme plotting helpers
# ------------------------------------------------------------------

from matplotlib.lines import Line2D


def plot_paths_1d(T, X_sde_1d, mc_1d_by_scheme, x_states, scheme_styles, title=None, show_legend=False):
    """Plot one EM path (black) and each scheme's chain path against time.

    The chain states are shown as a dotted minor grid on the y axis, and the y
    range is clamped to [min(x_states), max(x_states)]. Colours and labels come
    from ``scheme_styles[scheme]``. Returns (fig, ax).
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    line_kw = dict(lw=1.5, alpha=0.9)

    ax.plot(T, X_sde_1d, color='k', label='SDE (EM)', **line_kw)
    for scheme, path in mc_1d_by_scheme.items():
        style = scheme_styles[scheme]
        ax.plot(T, path, color=style['color'], label=style['label'], **line_kw)

    # Chain states as minor y ticks, drawn as a dotted grid.
    ax.set_yticks(x_states, minor=True)
    ax.grid(True, which='both', ls=':', alpha=0.5)

    # Clamp the view to the grid (set after the ticks, which may widen limits).
    y_low, y_high = np.min(x_states), np.max(x_states)
    ax.set_ylim(y_low, y_high)

    style_axes(ax, xlabel=r'$t$', ylabel=r'$x$', title=title)

    if show_legend:
        ax.legend(fontsize=TICK_LABELSIZE, handlelength=1,
                  borderaxespad=0.25, handletextpad=0.4)

    fig.tight_layout()
    return fig, ax


def plot_paths_2d(X_sde_2d, mc_2d_by_scheme, x_states, L, scheme_styles, title=None, show_legend=True):
    """Plot one EM trajectory (black) and each scheme's chain trajectory in the plane.

    Columns 0 and 1 of each path are (x, y). The view is the box [-L, L]^2 with
    equal aspect, and the chain states appear as a dotted minor grid on both
    axes. Returns (fig, ax).
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    line_kw = dict(lw=1.5, alpha=0.9)

    ax.plot(X_sde_2d[:, 0], X_sde_2d[:, 1], color='k', label='SDE (EM)', **line_kw)
    for scheme, path in mc_2d_by_scheme.items():
        style = scheme_styles[scheme]
        ax.plot(path[:, 0], path[:, 1], color=style['color'], label=style['label'], **line_kw)

    ax.set_xlim(-L, L)
    ax.set_ylim(-L, L)
    ax.set_aspect('equal', 'box')

    # Chain states as minor ticks on both axes, drawn as a dotted grid.
    for set_minor_ticks in (ax.set_xticks, ax.set_yticks):
        set_minor_ticks(x_states, minor=True)
    ax.grid(True, which='both', ls=':', alpha=0.5)

    style_axes(ax, xlabel=r'$x$', ylabel=r'$y$', title=title)

    if show_legend:
        ax.legend(fontsize=TICK_LABELSIZE)

    fig.tight_layout()
    return fig, ax


def plot_errors(h_list, errors_by_scheme, scheme_styles, title=None, show_legend=False):
    """Plot weak (dotted) and strong (solid) errors against h on log-log axes.

    ``errors_by_scheme[scheme]['weak' | 'strong']`` holds one value per entry
    of h_list. A dash-dotted black line y = h serves as a first-order reference.
    With show_legend, two legends are drawn: schemes (upper left) and line types
    (lower right). Returns (fig, ax).
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    # First-order reference line y = h.
    ax.loglog(h_list, h_list, c='k', label='linear', ls='-.')

    for scheme, style in scheme_styles.items():
        colour = style['color']
        errs = errors_by_scheme[scheme]
        ax.loglog(h_list, errs['strong'], linestyle='-', marker='o', color=colour)
        ax.loglog(h_list, errs['weak'], linestyle=':', marker='o', color=colour)

    style_axes(ax, xlabel='h (grid size)', ylabel='Error', title=title)
    ax.grid(True, which='both', ls=':')

    if show_legend:
        legend_kw = dict(fontsize=TICK_LABELSIZE, handlelength=1,
                         borderaxespad=0.25, handletextpad=0.4)
        scheme_handles = [
            Line2D([], [], color=style['color'], lw=2, label=style['label'])
            for style in scheme_styles.values()
        ]
        type_handles = [
            Line2D([], [], color='k', lw=2, linestyle=dash, label=name)
            for dash, name in (('-', 'Strong'), (':', 'Weak'), ('-.', 'Linear'))
        ]
        # add_artist keeps the first legend alive when the second is created.
        scheme_legend = ax.legend(handles=scheme_handles, loc='upper left', **legend_kw)
        ax.add_artist(scheme_legend)
        ax.legend(handles=type_handles, loc='lower right', **legend_kw)

    fig.tight_layout()
    return fig, ax


In [None]:
# ------------------------------------------------------------------
# Coupled sample paths (export as two separate figures: 1D and 2D)
# ------------------------------------------------------------------

dt    = 0.05
T_max = 10.0

tau   = 1.0
sig   = 0.5

L     = 1.2
h     = 0.1

seed  = 2
x0_1d = 0.0
x0_2d = 0.0
y0_2d = 0.0

# 1D: simulate both schemes (shared randomness via same seed).
# simulate() reseeds the RNG and draws the same Z's whatever the scheme, so
# the EM path is identical for both schemes. T_1d / X_sde_1d are simply kept
# from the last iteration; only the chain path differs per scheme.
mc_1d_by_scheme = {}
for scheme in SCHEMES:
    sim = OU1D(dt=dt, T_max=T_max, tau=tau, sig=sig, L=L, h=h, scheme=scheme, seed=seed)
    T_1d, X_sde_1d, X_mc_1d = sim.simulate(N_rep=1, x0=x0_1d)
    mc_1d_by_scheme[scheme] = X_mc_1d[0]

# CTMC grid (same for both schemes: both build it from L and h alone)
x_states = sim.x_states

# 2D: simulate both schemes (shared randomness via same seed); as in 1D,
# T_2d / X_sde_2d come from the last iteration and are scheme-independent.
mc_2d_by_scheme = {}
for scheme in SCHEMES:
    sim2 = OU2D(dt=dt, T_max=T_max, tau=tau, sig=sig, L=L, h=h, scheme=scheme, seed=seed)
    T_2d, X_sde_2d, X_mc_2d = sim2.simulate(N_rep=1, x0=x0_2d, y0=y0_2d)
    mc_2d_by_scheme[scheme] = X_mc_2d[0]


In [None]:
# Plot: 1D coupled paths (this figure carries the legend; the 2D one omits it)

fig, ax = plot_paths_1d(
    T_1d,
    X_sde_1d[0],
    mc_1d_by_scheme,
    x_states=x_states,
    scheme_styles=SCHEME_STYLES,
    title=None,
    show_legend=True,
)
plt.show()


In [None]:
from pathlib import Path

# savefig does not create missing directories; make sure figs/ exists.
Path('figs').mkdir(parents=True, exist_ok=True)
fig.savefig('figs/paths-1d.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Plot: 2D coupled paths (no legend; the 1D figure carries it)

fig, ax = plot_paths_2d(
    X_sde_2d[0],
    mc_2d_by_scheme,
    x_states=x_states,
    L=L,
    scheme_styles=SCHEME_STYLES,
    title=None,
    show_legend=False,
)
plt.show()


In [None]:
from pathlib import Path

# savefig does not create missing directories; make sure figs/ exists.
Path('figs').mkdir(parents=True, exist_ok=True)
fig.savefig('figs/paths-2d.pdf', format='pdf', bbox_inches='tight')

In [None]:
# ------------------------------------------------------------------
# Weak + strong errors vs h (export as two separate figures: 1D and 2D)
#   - strong: solid line
#   - weak:   dotted line
#   - color:  scheme (consistent with the path plots)
# ------------------------------------------------------------------
#
# NOTE: this cell rebinds the globals T_max, dt, sig, L (and h, via the loop)
# that the sample-path cells above also use. Re-run the simulation cell before
# re-plotting the paths.

# Shared parameters
T_max = 1.0
dt    = 0.01
tau   = 1.0
sig   = 0.05
L     = 0.3

# 8 grid spacings, log-spaced from 1e-3 to 1e-1.
h_list = np.logspace(-3, -1, 8)

# Monte Carlo budgets (adjust as needed)
N_rep_weak   = 1000
N_rep_strong = 1000

seed_weak   = 1
seed_strong = 2

# 1D target set A: interval around 0
set_indicator_1d = lambda x: indicator_interval(x, a=-0.1, b=0.1)
# Smooth alternative (then pass smooth=True in the weak-error calls below):
# set_indicator_1d = lambda x: smooth_indicator_interval(x, a=-0.1, b=0.1, eps=0.02)

# Initial conditions for the sup in the 1D weak error: x0 in {-0.1, 0, 0.1}.
x0_vals = np.linspace(-0.1, 0.1, 3)
init_grid_1d = list(x0_vals)

# 2D target set A: ball around 0
set_indicator_2d = lambda x: indicator_ball(x, center=(0.0, 0.0), radius=0.1)
# Smooth alternative (then pass smooth=True in the weak-error calls below):
# set_indicator_2d = lambda x: smooth_indicator_ball(x, center=(0.0, 0.0), radius=0.1, eps=0.01)

# 3 x 3 grid of initial conditions (x0, y0) for the 2D weak error.
x0_vals = np.linspace(-0.1, 0.1, 3)
y0_vals = np.linspace(-0.1, 0.1, 3)
init_grid_2d = [(x0, y0) for x0 in x0_vals for y0 in y0_vals]

# Containers: errors_Nd[scheme]['weak' | 'strong'] gets one entry per h in h_list.
errors_1d = {scheme: {'weak': [], 'strong': []} for scheme in SCHEMES}
errors_2d = {scheme: {'weak': [], 'strong': []} for scheme in SCHEMES}

for h in h_list:
    print(f"h = {h:.4g}")
    for scheme in SCHEMES:
        # Weak errors: max over the initial-condition grid of |P_EM(A) - P_MC(A)|
        # at time T_max (absorbed paths count as outside A).
        w1d, _ = weak_error_sup_over_inits_1d(
            h,
            init_grid_1d,
            set_indicator_1d,
            N_rep=N_rep_weak,
            smooth=False,
            T_max=T_max,
            dt=dt,
            tau=tau,
            sig=sig,
            L=L,
            scheme=scheme,
            seed=seed_weak,
        )
        w2d, _ = weak_error_sup_over_inits(
            h,
            init_grid_2d,
            set_indicator_2d,
            N_rep=N_rep_weak,
            smooth=False,
            T_max=T_max,
            dt=dt,
            tau=tau,
            sig=sig,
            L=L,
            scheme=scheme,
            seed=seed_weak,
        )
        errors_1d[scheme]['weak'].append(w1d)
        errors_2d[scheme]['weak'].append(w2d)

        # Strong errors (mean sup error under common-noise coupling),
        # starting from the origin.
        s1d, _, _ = coupling_sup_error_1d(
            h,
            N_rep=N_rep_strong,
            T_max=T_max,
            dt=dt,
            tau=tau,
            sig=sig,
            L=L,
            scheme=scheme,
            seed=seed_strong,
            x0=0.0,
        )
        s2d, _, _ = coupling_sup_error_2d(
            h,
            N_rep=N_rep_strong,
            T_max=T_max,
            dt=dt,
            tau=tau,
            sig=sig,
            L=L,
            scheme=scheme,
            seed=seed_strong,
            x0=0.0,
            y0=0.0,
        )
        errors_1d[scheme]['strong'].append(s1d)
        errors_2d[scheme]['strong'].append(s2d)

# Convert to numpy arrays
for errs in (errors_1d, errors_2d):
    for scheme in SCHEMES:
        for k in ('weak', 'strong'):
            errs[scheme][k] = np.asarray(errs[scheme][k])


In [None]:
# Plot: 1D errors (this figure carries the legend; the 2D one omits it)

fig, ax = plot_errors(
    h_list,
    errors_1d,
    scheme_styles=SCHEME_STYLES,
    title=None,
    show_legend=True,
)
plt.show()


In [None]:
from pathlib import Path

# savefig does not create missing directories; make sure figs/ exists.
Path('figs').mkdir(parents=True, exist_ok=True)
fig.savefig('figs/errors-1d.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Plot: 2D errors (no legend; the 1D figure carries it)

fig, ax = plot_errors(
    h_list,
    errors_2d,
    scheme_styles=SCHEME_STYLES,
    title=None,
    show_legend=False,
)
plt.show()


In [None]:
from pathlib import Path

# savefig does not create missing directories; make sure figs/ exists.
Path('figs').mkdir(parents=True, exist_ok=True)
fig.savefig('figs/errors-2d.pdf', format='pdf', bbox_inches='tight')