In [None]:
"""
Second-order AP-IMEX Finite Volume solver (1D prototype)
=======================================================

Implements a robust, structure-preserving scheme for a hybrid PDE–ODE
reaction–diffusion–chemotaxis system with non-diffusive (P,A).

Key design choices (matching the paper text):
- Finite Volume (cell averages) for transport/chemotaxis in conservative form
- MUSCL reconstruction with TVD slope limiter (minmod by default)
- Monotone upwind numerical flux for chemotaxis/advection
- IMEX Runge–Kutta ARS(2,2,2) for time stepping:
    * diffusion (and linear decay terms) treated implicitly
    * transport + nonlinear reactions treated explicitly
- Non-diffusive subsystem (P,A) updated by a positivity-preserving closed-form step
  consistent with the pointwise invariant P+A = const.

This is a 1D reference implementation intended for reproducibility and extension to 2D.
No external dependencies beyond numpy.

Model implemented (chemotaxis extension of the base hybrid model):
    ∂t S = dS ΔS - χS ∂x ( S ∂x A ) + λS S(1-(S+R)/K) - α S - δ(D) S + ξ(1-φ(D)) R
    ∂t R = dR ΔR - χR ∂x ( R ∂x A ) + λR R(1-(S+R)/K) + α S + η φ(D) A R - ξ(1-φ(D)) R
    ∂t D = dD ΔD - γd D
    (P,A) pointwise ODE (non-diffusive):
        ∂t P = -θ φ(D) P + β(1-φ(D)) A
        ∂t A =  θ φ(D) P - β(1-φ(D)) A
  with pointwise invariant M(x)=P(x,t)+A(x,t)=P0(x)+A0(x).

Boundary conditions:
- Neumann (zero-flux) for diffusive variables S,R,D (standard for Turing/chemotaxis)
- Chemotaxis flux is enforced as zero at boundaries by setting interface velocities to 0.

Notes:
- Positivity: the transport discretization is monotone under a CFL condition.
  Diffusion is implicit, and the (P,A) update is unconditionally positivity-preserving.
  Note: no clipping of small negative roundoff values is currently active in the solver;
  the reported minima are therefore the raw values produced by the scheme.
"""

from __future__ import annotations
import math
import numpy as np
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


# ----------------------------
# Utilities: slope limiter, FV reconstruction, tridiagonal solver
# ----------------------------

def minmod(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """
    Elementwise minmod limiter.

    Where the product a*b is strictly positive the result is the entry of
    smaller magnitude, carrying the sign of ``a``; everywhere else (opposite
    signs, or either argument zero) the result is 0. The output has the
    dtype of ``a``.
    """
    agree = (a * b) > 0
    smallest = np.minimum(np.abs(a), np.abs(b))
    limited = np.where(agree, np.sign(a) * smallest, 0.0)
    return limited.astype(a.dtype, copy=False)


def muscl_reconstruct(W: np.ndarray, limiter: Callable[[np.ndarray, np.ndarray], np.ndarray] = minmod
                      ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Limited piecewise-linear (MUSCL) reconstruction of cell averages W[0..N-1].

    Returns ``(W_L, W_R)``, two arrays of length N+1 indexed by interface
    (entry k is the interface between cells k-1 and k):
        W_L[k] : value extrapolated from the cell on the left  (cell k-1)
        W_R[k] : value extrapolated from the cell on the right (cell k)
    The one-sided differences passed to ``limiter`` use zero-gradient ghost
    values (W[-1]=W[0], W[N]=W[N-1]). At the two boundary interfaces both
    states are simply the adjacent cell value.
    """
    n_cells = W.size

    # Neighbour arrays with the boundary cell repeated (zero-gradient ghosts).
    left_nbr = np.concatenate((W[:1], W[:-1])).astype(np.float64)
    right_nbr = np.concatenate((W[1:], W[-1:])).astype(np.float64)

    # One limited slope per cell, built from backward and forward differences.
    sigma = limiter(W - left_nbr, right_nbr - W)

    W_L = np.empty(n_cells + 1)
    W_R = np.empty(n_cells + 1)

    # Interior interfaces k = 1..N-1: half a slope forward / backward.
    W_L[1:n_cells] = W[:-1] + 0.5 * sigma[:-1]
    W_R[1:n_cells] = W[1:] - 0.5 * sigma[1:]

    # Boundary interfaces: both sides take the adjacent cell value.
    W_L[0] = W_R[0] = W[0]
    W_L[n_cells] = W_R[n_cells] = W[-1]

    return W_L, W_R


def thomas_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
    """
    Thomas algorithm for the tridiagonal system
        a[i]*x[i-1] + b[i]*x[i] + c[i]*x[i+1] = d[i],   i = 0..N-1,
    with the convention a[0] = 0 and c[N-1] = 0.

    a, b, c are the sub-, main- and super-diagonal; d is the right-hand side.
    Returns the solution as a new float array of length N = d.size.

    Raises ZeroDivisionError when a pivot has magnitude below 1e-15.
    No pivoting is performed.
    """
    n = d.size
    tiny = 1e-15
    upper = np.empty(n)   # modified super-diagonal
    rhs = np.empty(n)     # modified right-hand side

    # Elimination, first row.
    pivot = b[0]
    if abs(pivot) < tiny:
        raise ZeroDivisionError("Thomas solver: near-zero pivot at i=0.")
    upper[0] = c[0] / pivot
    rhs[0] = d[0] / pivot

    # Elimination, remaining rows.
    last = n - 1
    for k in range(1, n):
        pivot = b[k] - a[k] * upper[k - 1]
        if abs(pivot) < tiny:
            raise ZeroDivisionError(f"Thomas solver: near-zero pivot at i={k}.")
        # The last row has no super-diagonal entry.
        upper[k] = 0.0 if k == last else c[k] / pivot
        rhs[k] = (d[k] - a[k] * rhs[k - 1]) / pivot

    # Back substitution from the last unknown upwards.
    sol = np.empty(n)
    sol[last] = rhs[last]
    for k in reversed(range(last)):
        sol[k] = rhs[k] - upper[k] * sol[k + 1]
    return sol


# ----------------------------
# Model parameters and nonlinearities
# ----------------------------

@dataclass
class Params:
    """
    Coefficients of the hybrid (S, R, D, P, A) model plus one numerical knob.

    Field names follow the symbols used in the model equations of the module
    docstring. Field order and defaults define the generated ``__init__``.
    """
    # Diffusion coefficients of S, R and D
    dS: float = 1e-3
    dR: float = 1e-3
    dD: float = 1e-2

    # Chemotaxis sensitivities of S and R towards the gradient of A (can be 0,
    # in which case the transport term is skipped)
    chiS: float = 0.1
    chiR: float = 0.1

    # Logistic growth rates of S and R, and the shared carrying capacity K
    lamS: float = 0.5
    lamR: float = 0.3
    K: float = 1.0

    # Switching / transitions: alpha moves S -> R, xi*(1-phi(D)) moves R -> S
    alpha: float = 0.01
    xi: float = 0.01

    # D inhibition of S, delta(D) = delta0 * D / (D + KD), and linear decay rate of D
    delta0: float = 0.3
    KD: float = 0.2
    gamma_d: float = 0.1

    # A promotes R through the term eta * phi(D) * A * R
    eta: float = 0.2

    # P <-> A transitions driven by phi(D): theta*phi(D) for P -> A,
    # beta*(1-phi(D)) for A -> P
    theta: float = 0.3
    beta: float = 0.1

    # phi(D)=tanh(sigma D)
    sigma_phi: float = 5.0

    # Numerics safety
    # NOTE(review): within this file the value is only referenced from
    # commented-out clipping lines, so it currently has no effect.
    positivity_clip: float = 0.0  # set to small negative e.g. -1e-14 to allow tiny undershoot


def phi(D: np.ndarray, p: Params) -> np.ndarray:
    """Switching function phi(D) = tanh(sigma_phi * D), evaluated elementwise."""
    scaled = p.sigma_phi * D
    return np.tanh(scaled)


def delta_inhib(D: np.ndarray, p: Params) -> np.ndarray:
    """Saturating inhibition rate delta(D) = delta0 * D / (D + KD), elementwise."""
    numerator = p.delta0 * D
    denominator = D + p.KD
    return numerator / denominator


# ----------------------------
# Discrete operators: Neumann Laplacian and chemotaxis flux divergence
# ----------------------------

def neumann_laplacian_tridiag(N: int, h: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Return the tridiagonal arrays (a, b, c) of the 1D Laplacian with homogeneous
    Neumann (zero-flux) boundaries on a uniform *cell-centred* grid of N cells
    of width h:
      (Δu)_0     = (u_1 - u_0) / h^2
      (Δu)_i     = (u_{i-1} - 2 u_i + u_{i+1}) / h^2,   0 < i < N-1
      (Δu)_{N-1} = (u_{N-2} - u_{N-1}) / h^2

    a is the sub-diagonal (a[0] = 0), b the main diagonal, c the
    super-diagonal (c[N-1] = 0).

    The boundary rows are the flux-difference form with zero wall flux, which
    is what the ghost-cell convention u_{-1} = u_0, u_N = u_{N-1} gives for
    cell averages. With these rows every column of the matrix sums to zero,
    so the discrete diffusion conserves sum(u) * h. The mirrored-node stencil
    (-2 u_0 + 2 u_1) / h^2 belongs to grids whose first node sits on the wall;
    on a cell-centred grid it doubles the diffusion in the boundary cells and
    breaks conservation.

    Raises ValueError if N < 1.
    """
    if N < 1:
        raise ValueError("neumann_laplacian_tridiag: N must be at least 1.")

    invh2 = 1.0 / (h * h)

    # Every existing neighbour contributes +1/h^2 off the diagonal ...
    a = np.full(N, invh2)
    c = np.full(N, invh2)
    a[0] = 0.0      # no left neighbour for the first cell
    c[-1] = 0.0     # no right neighbour for the last cell

    # ... and the diagonal balances them, giving zero row and column sums.
    b = 0.0 - (a + c)

    return a, b, c


def chemotaxis_divergence(W: np.ndarray, A: np.ndarray, chi: float, h: float,
                          limiter: Callable[[np.ndarray, np.ndarray], np.ndarray] = minmod
                          ) -> np.ndarray:
    """
    Conservative finite-volume approximation of
        -∂x( chi * W * ∂x A ) = -∂x( W * v ),   v = chi * ∂x A,
    with MUSCL-reconstructed interface states and an upwind flux.

    W and A are cell-centred arrays on a uniform grid of spacing h.
    The velocity at the two boundary interfaces is zero, so the boundary
    fluxes vanish (no-flux). Returns a cell-centred array the size of W;
    for chi == 0.0 it is identically zero.
    """
    if chi == 0.0:
        return np.zeros_like(W)

    # Left/right states at all N+1 interfaces.
    state_left, state_right = muscl_reconstruct(W, limiter=limiter)

    # Interior interface velocities chi*(A_{i+1}-A_i)/h, zero-padded at both walls.
    velocity = np.pad(chi * np.diff(A) / h, 1)

    # Upwinding: rightward motion carries the left state, leftward the right state.
    rightward = np.maximum(velocity, 0.0)
    leftward = np.minimum(velocity, 0.0)
    flux = rightward * state_left + leftward * state_right

    # Negative flux difference across each cell.
    return -np.diff(flux) / h


# ----------------------------
# (P,A) structure-preserving update (closed form with frozen coefficient)
# ----------------------------

def update_PA_step(Pn: np.ndarray, An: np.ndarray, D_for_phi: np.ndarray, dt: float, p: Params
                   ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Pointwise closed-form update of the non-diffusive pair (P, A) over a time
    interval dt, with phi held fixed at phi(D_for_phi) during the interval.

    With M = P^n + A^n and kappa = theta*phi + beta*(1-phi):
        P^{n+1} = exp(-kappa dt) * P^n + (1 - exp(-kappa dt)) * [beta*(1-phi)/kappa] * M
        A^{n+1} = M - P^{n+1}
    so P^{n+1} + A^{n+1} equals M up to rounding. The division by kappa is
    regularised by adding 1e-15 to the denominator.

    Returns the tuple (P^{n+1}, A^{n+1}); no clipping is applied.
    """
    total = Pn + An
    phi_frozen = phi(D_for_phi, p)

    to_passive = p.beta * (1.0 - phi_frozen)          # A -> P rate
    kappa = p.theta * phi_frozen + to_passive         # total relaxation rate

    # Guard the quotient against kappa == 0 (degenerate parameter choices).
    guard = 1e-15
    decay = np.exp(-kappa * dt)
    passive_share = to_passive / (kappa + guard)

    P_new = decay * Pn + (1.0 - decay) * passive_share * total
    A_new = total - P_new
    return P_new, A_new


# ----------------------------
# IMEX RK (ARS(2,2,2)) coefficients
# ----------------------------

@dataclass(frozen=True)
class ARS222:
    """
    Two-stage IMEX Runge-Kutta coefficient set used by the solver.

    ``gamma`` defaults to 1 - 1/sqrt(2). Each property builds and returns a
    fresh float array: ``aI``/``aE`` are the 2x2 implicit/explicit stage
    matrices, ``bI``/``bE`` the weights, ``cI``/``cE`` the abscissae.

    NOTE(review): implicit rows [g, 0], [1-2g, g] with weights [1/2, 1/2] and
    explicit rows [0, 0], [1, 0] look like the IMEX-SSP2(2,2,2) tableau rather
    than ARS(2,2,2) (whose implicit second row would be [1-g, g]) - confirm
    which scheme is intended before changing the name or the numbers.
    """
    gamma: float = 1.0 - 1.0 / math.sqrt(2.0)

    @property
    def aI(self) -> np.ndarray:
        """Implicit stage matrix (lower triangular, gamma on the diagonal)."""
        diag = self.gamma
        table = np.zeros((2, 2), dtype=float)
        table[0, 0] = diag
        table[1, 0] = 1.0 - 2 * diag
        table[1, 1] = diag
        return table

    @property
    def aE(self) -> np.ndarray:
        """Explicit stage matrix (strictly lower triangular)."""
        table = np.zeros((2, 2), dtype=float)
        table[1, 0] = 1.0
        return table

    @property
    def bI(self) -> np.ndarray:
        """Implicit weights."""
        return np.full(2, 0.5, dtype=float)

    @property
    def bE(self) -> np.ndarray:
        """Explicit weights."""
        return np.full(2, 0.5, dtype=float)

    @property
    def cI(self) -> np.ndarray:
        """Implicit abscissae [gamma, 1 - gamma]."""
        diag = self.gamma
        return np.array([diag, 1.0 - diag], dtype=float)

    @property
    def cE(self) -> np.ndarray:
        """Explicit abscissae [0, 1]."""
        return np.array([0.0, 1.0], dtype=float)


# ----------------------------
# Solver class
# ----------------------------

class APIMEX1DSolver:
    """
    Two-stage IMEX finite-volume solver for (S, R, D, P, A) on the unit interval.

    The domain [0, 1] is split into N uniform cells of width h with centres
    ``xc``. Diffusion of S, R, D (and the linear decay of D) is treated
    implicitly through tridiagonal solves; reactions and chemotactic
    transport are evaluated explicitly; (P, A) is advanced by the closed-form
    pointwise update ``update_PA_step``.

    No positivity clipping is applied to stage or final values.
    """

    def __init__(self, N: int, params: Params, limiter=minmod, bc: str = "neumann"):
        """
        N       : number of cells (must be >= 1)
        params  : model coefficients
        limiter : slope limiter handed to the MUSCL reconstruction
        bc      : boundary condition name; only "neumann" (case-insensitive) is supported

        Raises ValueError for an unsupported ``bc`` or for N < 1.
        """
        # Explicit exceptions rather than assert, so validation survives `python -O`.
        if bc.lower() != "neumann":
            raise ValueError("This reference solver implements Neumann BC for diffusion.")
        if N < 1:
            raise ValueError("APIMEX1DSolver: N must be a positive number of cells.")

        self.N = N
        self.p = params
        self.limiter = limiter

        self.x0 = 0.0
        self.x1 = 1.0
        self.h = (self.x1 - self.x0) / N
        self.xc = np.linspace(self.x0 + 0.5 * self.h, self.x1 - 0.5 * self.h, N)

        # Tridiagonal Laplacian (Neumann)
        self.La, self.Lb, self.Lc = neumann_laplacian_tridiag(N, self.h)

        # IMEX tableau
        self.ars = ARS222()

    # ------------------------------------------------------------------
    # Operators
    # ------------------------------------------------------------------

    def apply_laplacian(self, W: np.ndarray) -> np.ndarray:
        """Apply the stored tridiagonal Neumann Laplacian to a vector W."""
        a, b, c = self.La, self.Lb, self.Lc
        out = b * W
        out[1:] += a[1:] * W[:-1]
        out[:-1] += c[:-1] * W[1:]
        return out

    def _implicit_solve_diffusion(self, W_rhs: np.ndarray, d: float, dt_scale: float,
                                  extra_diag: np.ndarray | float = 0.0) -> np.ndarray:
        """
        Solve (I - dt_scale * d * Lap + dt_scale * extra_diag) W = W_rhs.

        ``extra_diag`` is an additional diagonal (linear decay) term; it may
        be a scalar or an array of length N and is expected to be nonnegative.
        """
        # The products allocate new arrays, so self.La/Lb/Lc are never modified.
        a = -dt_scale * d * self.La
        b = 1.0 - dt_scale * d * self.Lb
        c = -dt_scale * d * self.Lc
        b += dt_scale * np.asarray(extra_diag, dtype=float)

        # Enforce tridiagonal boundary conventions expected by thomas_solve
        a[0] = 0.0
        c[self.N - 1] = 0.0
        return thomas_solve(a, b, c, W_rhs)

    def _explicit_reactions(self, S: np.ndarray, R: np.ndarray, D: np.ndarray, A: np.ndarray
                            ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Explicit reaction terms (no diffusion, no chemotaxis).
        Returns (RS, RR, RD) for S, R, D respectively; RD is identically zero
        because the decay of D is handled in the implicit part.
        """
        p = self.p
        ph = phi(D, p)
        delt = delta_inhib(D, p)

        # Logistic competition through the shared occupancy (S+R)/K
        total = (S + R) / p.K
        growthS = p.lamS * S * (1.0 - total)
        growthR = p.lamR * R * (1.0 - total)

        # Switching and modulation by phi(D)
        RS = growthS - p.alpha * S - delt * S + p.xi * (1.0 - ph) * R
        RR = growthR + p.alpha * S + p.eta * ph * A * R - p.xi * (1.0 - ph) * R

        RD = np.zeros_like(D)
        return RS, RR, RD

    def _explicit_transport(self, S: np.ndarray, R: np.ndarray, A: np.ndarray
                            ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Explicit chemotaxis/transport terms using FV upwind/MUSCL.
        Returns (TS, TR) for S and R respectively.
        """
        p = self.p
        TS = chemotaxis_divergence(S, A, chi=p.chiS, h=self.h, limiter=self.limiter)
        TR = chemotaxis_divergence(R, A, chi=p.chiR, h=self.h, limiter=self.limiter)
        return TS, TR

    def _explicit_rhs(self, S: np.ndarray, R: np.ndarray, D: np.ndarray, A: np.ndarray
                      ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Full explicit right-hand side E(U) = reactions + transport for (S, R, D)."""
        RS, RR, RD = self._explicit_reactions(S, R, D, A)
        TS, TR = self._explicit_transport(S, R, A)
        return RS + TS, RR + TR, RD

    def _implicit_rhs(self, S: np.ndarray, R: np.ndarray, D: np.ndarray
                      ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Implicit operator I(U): diffusion for S and R, diffusion minus decay for D."""
        p = self.p
        IS = p.dS * self.apply_laplacian(S)
        IR = p.dR * self.apply_laplacian(R)
        ID = p.dD * self.apply_laplacian(D) - p.gamma_d * D
        return IS, IR, ID

    # ------------------------------------------------------------------
    # Time stepping
    # ------------------------------------------------------------------

    def step(self, state: Dict[str, np.ndarray], dt: float) -> Dict[str, np.ndarray]:
        """
        Advance the state by one time step dt.

        ``state`` is a dict with keys "S", "R", "D", "P", "A" holding arrays of
        length N; it is not modified. A new dict with the same keys is returned.

        (S, R, D) go through two implicit stages followed by a weighted
        combination of explicit and implicit stage terms. (P, A) is advanced
        by two closed-form sub-steps of lengths gamma*dt and (1-gamma)*dt
        (together one full dt), each with phi frozen at the D value of the
        corresponding stage.
        """
        p = self.p
        tab = self.ars
        g = tab.gamma
        aI, aE, bI, bE = tab.aI, tab.aE, tab.bI, tab.bE
        aE21, aE22 = aE[1, 0], aE[1, 1]
        aI21, aI22 = aI[1, 0], aI[1, 1]

        S0 = np.array(state["S"], dtype=float)
        R0 = np.array(state["R"], dtype=float)
        D0 = np.array(state["D"], dtype=float)
        P0 = np.array(state["P"], dtype=float)
        A0 = np.array(state["A"], dtype=float)

        # --- Stage 1 ---
        # Explicit terms at the initial state; they enter the stage-2 right-hand side.
        ES0, ER0, ED0 = self._explicit_rhs(S0, R0, D0, A0)

        # (I - g dt d Lap) W1 = W0; for D the decay gamma_d is part of the implicit operator.
        S1 = self._implicit_solve_diffusion(S0, d=p.dS, dt_scale=g * dt)
        R1 = self._implicit_solve_diffusion(R0, d=p.dR, dt_scale=g * dt)
        D1 = self._implicit_solve_diffusion(D0, d=p.dD, dt_scale=g * dt, extra_diag=p.gamma_d)

        # (P,A) over the first sub-interval, giving A1 for the stage-1 explicit terms.
        P1, A1 = update_PA_step(P0, A0, D_for_phi=D1, dt=g * dt, p=p)

        # Stage-1 explicit and implicit terms; computed once and reused below.
        ES1, ER1, ED1 = self._explicit_rhs(S1, R1, D1, A1)
        IS1, IR1, ID1 = self._implicit_rhs(S1, R1, D1)

        # --- Stage 2 ---
        # (I - aI22 dt I) W2 = W0 + dt*(aE21*E(U0) + aE22*E(U1)) + dt*aI21*I(U1)
        # NOTE(review): here aE[1,0] weights E(U0), whereas the final combination
        # weights E(U1) with bE[0]; confirm this stage-index convention is intended.
        rhsS2 = S0 + dt * (aE21 * ES0 + aE22 * ES1) + dt * aI21 * IS1
        rhsR2 = R0 + dt * (aE21 * ER0 + aE22 * ER1) + dt * aI21 * IR1
        rhsD2 = D0 + dt * (aE21 * ED0 + aE22 * ED1) + dt * aI21 * ID1  # ED* are zeros

        S2 = self._implicit_solve_diffusion(rhsS2, d=p.dS, dt_scale=aI22 * dt)
        R2 = self._implicit_solve_diffusion(rhsR2, d=p.dR, dt_scale=aI22 * dt)
        D2 = self._implicit_solve_diffusion(rhsD2, d=p.dD, dt_scale=aI22 * dt, extra_diag=p.gamma_d)

        # (P,A) over the remaining sub-interval, with phi frozen at phi(D2).
        P2, A2 = update_PA_step(P1, A1, D_for_phi=D2, dt=(1.0 - g) * dt, p=p)

        # --- Final combination ---
        # U^{n+1} = U0 + dt * sum_i bE_i E(U_i) + dt * sum_i bI_i I(U_i),  i = 1, 2
        ES2, ER2, ED2 = self._explicit_rhs(S2, R2, D2, A2)
        IS2, IR2, ID2 = self._implicit_rhs(S2, R2, D2)

        Sn1 = S0 + dt * (bE[0] * ES1 + bE[1] * ES2) + dt * (bI[0] * IS1 + bI[1] * IS2)
        Rn1 = R0 + dt * (bE[0] * ER1 + bE[1] * ER2) + dt * (bI[0] * IR1 + bI[1] * IR2)
        Dn1 = D0 + dt * (bE[0] * ED1 + bE[1] * ED2) + dt * (bI[0] * ID1 + bI[1] * ID2)

        # (P,A) after the two sub-steps is the new-time value.
        return {"S": Sn1, "R": Rn1, "D": Dn1, "P": P2, "A": A2}

    def suggest_dt_cfl(self, A: np.ndarray, chi: float, safety: float = 0.45) -> float:
        """
        Suggest a CFL-limited dt for transport: safety * h / max|v| with
        v = chi * (A_{i+1} - A_i) / h at interior interfaces.

        Returns +inf when chi == 0.0, when there are no interior interfaces,
        or when max|v| is below 1e-14.
        """
        if chi == 0.0:
            return float("inf")
        v_int = chi * np.abs((A[1:] - A[:-1]) / self.h)
        vmax = float(np.max(v_int)) if v_int.size > 0 else 0.0
        if vmax < 1e-14:
            return float("inf")
        return safety * self.h / vmax


# ----------------------------
# Demonstration / driver
# ----------------------------

def initialize_state(xc: np.ndarray, p: Params) -> Dict[str, np.ndarray]:
    """
    Build the demonstration initial condition on the cell centres ``xc``.

    - S and R: the constants xi*K/(alpha+xi) and alpha*K/(alpha+xi), each
      perturbed by uniform noise of amplitude 1e-2 and floored at 0
    - D: Gaussian bump centred at x = 0.5 with width 0.05
    - P and A: 0.5 plus the same kind of noise, floored at 0

    The random generator is seeded with 1, so the result is reproducible.
    Returns a dict with keys "S", "R", "D", "P", "A".
    """
    n_cells = xc.size
    rng = np.random.default_rng(1)
    amplitude = 1e-2

    def _noisy(base: float) -> np.ndarray:
        # One uniform draw in [-1, 1) per cell, scaled, added to base, floored at zero.
        return np.maximum(base + amplitude * (2 * rng.random(n_cells) - 1), 0.0)

    # Draw order (S, R, P, A) fixes which random numbers each field receives.
    S_init = _noisy(p.xi * p.K / (p.alpha + p.xi))
    R_init = _noisy(p.alpha * p.K / (p.alpha + p.xi))

    centre = 0.5
    width = 0.05
    D_init = np.exp(-((xc - centre) ** 2) / (2 * width**2))

    P_init = _noisy(0.5)
    A_init = _noisy(0.5)

    return {"S": S_init, "R": R_init, "D": D_init, "P": P_init, "A": A_init}


def run_simulation(N: int = 200, Tfinal: float = 10.0, params: Params | None = None,
                   report_every: int = 200) -> Tuple[APIMEX1DSolver, Dict[str, np.ndarray]]:
    """
    Run the demonstration: build a solver, integrate from t = 0 to ``Tfinal``
    and print diagnostics along the way.

    N            : number of grid cells
    Tfinal       : final time
    params       : model parameters; ``Params()`` is used when None
    report_every : print diagnostics every this many steps (and at the final
                   time); a value <= 0 prints only at the final time

    Called without arguments it uses N = 200, Tfinal = 10.0, default
    parameters and a report every 200 steps.

    Returns ``(solver, state)`` with the final state dict.
    """
    # Parameters and solver
    p = Params() if params is None else params
    solver = APIMEX1DSolver(N=N, params=p, limiter=minmod, bc="neumann")

    # Initial state and the pointwise invariant P+A it defines
    state = initialize_state(solver.xc, p)
    state0_PA = state["P"] + state["A"]

    dt_cap = 1e-2     # upper bound on dt
    dt_floor = 1e-5   # lower bound on dt (takes precedence over the CFL suggestion)

    # Diffusion scale dt ~ O(h^2): chosen for accuracy, not stability (diffusion is implicit)
    d_max = max(p.dS, p.dR, p.dD)
    dt_diff = 0.25 * solver.h * solver.h / d_max if d_max > 0 else float("inf")

    def advective_dt(A: np.ndarray) -> float:
        # Most restrictive transport CFL suggestion over S and R for the given A.
        return min(solver.suggest_dt_cfl(A, chi=p.chiS, safety=0.45),
                   solver.suggest_dt_cfl(A, chi=p.chiR, safety=0.45))

    dt_adv = advective_dt(state["A"])
    dt = max(min(dt_diff, dt_adv, dt_cap), dt_floor)

    print(f"h={solver.h:.4e}, dt={dt:.4e}, dt_diff={dt_diff:.4e}, dt_adv(min)={dt_adv:.4e}")

    # Simple diagnostics: minima, discrete masses and drift of the P+A invariant
    def diagnostics(s: Dict[str, np.ndarray]) -> Dict[str, float]:
        return {
            "minS": float(np.min(s["S"])),
            "minR": float(np.min(s["R"])),
            "minD": float(np.min(s["D"])),
            "massS": float(np.sum(s["S"]) * solver.h),
            "massR": float(np.sum(s["R"]) * solver.h),
            "massD": float(np.sum(s["D"]) * solver.h),
            "PA_err": float(np.max(np.abs((s["P"] + s["A"]) - (state0_PA)))),
        }

    # Time loop
    t = 0.0
    step = 0
    while t < Tfinal - 1e-14:
        # Shorten the last step so the run ends at Tfinal.
        if t + dt > Tfinal:
            dt = Tfinal - t

        state = solver.step(state, dt)
        t += dt
        step += 1

        # Re-evaluate dt occasionally from the evolving A gradients.
        # The current dt is part of the min, so dt never grows here.
        if step % 50 == 0:
            dt = max(min(dt, advective_dt(state["A"]), dt_diff, dt_cap), dt_floor)

        periodic_report = report_every > 0 and step % report_every == 0
        if periodic_report or abs(t - Tfinal) < 1e-12:
            diag = diagnostics(state)
            print(f"t={t:8.4f} step={step:6d} "
                  f"min(S,R,D)=({diag['minS']:.2e},{diag['minR']:.2e},{diag['minD']:.2e}) "
                  f"mass(S,R,D)=({diag['massS']:.4f},{diag['massR']:.4f},{diag['massD']:.4f}) "
                  f"max|P+A-M|={diag['PA_err']:.2e}")

    return solver, state


# Script entry point: run the demonstration only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run_simulation()

h=5.0000e-03, dt=6.2500e-04, dt_diff=6.2500e-04, dt_adv(min)=6.3818e-03
t=  0.1250 step=   200 min(S,R,D)=(2.17e-01,2.17e-01,3.40e-11) mass(S,R,D)=(0.4937,0.4996,0.1238) max|P+A-M|=1.11e-16
t=  0.2500 step=   400 min(S,R,D)=(2.01e-01,2.03e-01,8.63e-08) mass(S,R,D)=(0.4871,0.5010,0.1222) max|P+A-M|=1.11e-16
t=  0.3750 step=   600 min(S,R,D)=(1.86e-01,1.87e-01,4.28e-06) mass(S,R,D)=(0.4806,0.5030,0.1207) max|P+A-M|=1.11e-16
t=  0.5000 step=   800 min(S,R,D)=(1.74e-01,1.74e-01,4.38e-05) mass(S,R,D)=(0.4740,0.5055,0.1192) max|P+A-M|=1.11e-16
t=  0.6250 step=  1000 min(S,R,D)=(1.65e-01,1.65e-01,2.03e-04) mass(S,R,D)=(0.4675,0.5084,0.1177) max|P+A-M|=1.11e-16
t=  0.7500 step=  1200 min(S,R,D)=(1.58e-01,1.57e-01,6.00e-04) mass(S,R,D)=(0.4608,0.5115,0.1163) max|P+A-M|=1.11e-16
t=  0.8750 step=  1400 min(S,R,D)=(1.52e-01,1.51e-01,1.34e-03) mass(S,R,D)=(0.4541,0.5148,0.1148) max|P+A-M|=1.11e-16
t=  1.0000 step=  1600 min(S,R,D)=(1.48e-01,1.46e-01,2.47e-03) mass(S,R,D)=(0.4474,0.5182,0.1134) max|