In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

def compute_nontrivial_slice(W_birth, W_death, Y_birth, Y_death):
    """
    Compute the positive nontrivial equilibrium (W_eq, Y_eq).

    With Q1 = W_death / W_birth and Q2 = Y_death / Y_birth, each coordinate
    is the larger root of a quadratic x^2 − b·x + c = 0:
      W_eq = ½ [ (1 − Q1 + Q2) + sqrt((1 − Q1 + Q2)^2 − 4·Q2 ) ]
      Y_eq = ½ [ (1 − Q2 + Q1) + sqrt((1 − Q2 + Q1)^2 − 4·Q1 ) ]

    Parameters:
      W_birth, W_death : birth and death rates used to form Q1.
      Y_birth, Y_death : birth and death rates used to form Q2.

    Returns (W_eq, Y_eq) if both lie in (0,1); otherwise (None, None).
    (None, None) is also returned when either birth rate is zero (Q1 or Q2
    would be undefined) or when a discriminant is negative.
    """
    # A zero birth rate makes the ratio undefined; report "no equilibrium"
    # as documented instead of raising ZeroDivisionError.
    if W_birth == 0 or Y_birth == 0:
        return None, None

    Q1 = W_death / W_birth
    Q2 = Y_death / Y_birth

    def _upper_root_in_unit_interval(b, c):
        # Larger root of x^2 - b*x + c = 0, or None when the root is not real
        # or does not lie strictly inside (0, 1).
        disc = b**2 - 4 * c
        if disc < 0:
            return None
        root = 0.5 * (b + np.sqrt(disc))
        if not (0.0 < root < 1.0):
            return None
        return root

    W_equil = _upper_root_in_unit_interval(1 - Q1 + Q2, Q2)
    if W_equil is None:
        return None, None

    Y_equil = _upper_root_in_unit_interval(1 - Q2 + Q1, Q1)
    if Y_equil is None:
        return None, None

    return W_equil, Y_equil

def simulate_segment(V0, W0, Y0, X0, Z0,
                     W_birth, Y_birth, W_death, Y_death,
                     X_in, Z_in, X_out, Z_out,
                     duration, dt,
                     use_X, use_Z,
                     tol=1e-9,
                     stop_at_eq=True):
    """
    Integrate the (V, W, Y, X, Z) system with forward Euler from t=0 to
    t=duration, starting from
      V(0)=V0, W(0)=W0, Y(0)=Y0, X(0)=X0, Z(0)=Z0.

    Every step uses exactly `dt`, except the final one, which is shortened
    when `duration` is not a multiple of `dt` so that the run ends exactly at
    t=duration and t[i] is the true time of sample i.

    If stop_at_eq=True, stops early when all |dV|,|dW|,|dY| (and |dX| if use_X,
    |dZ| if use_Z) fall below tol; the returned arrays are then truncated at
    the last state that was computed. Otherwise, always runs full duration.

    X and Z are always integrated; use_X / use_Z only control whether they
    feed back into dW / dY and whether they take part in the equilibrium test.
    After each step all five states are clamped to be non-negative.

    Raises:
      ValueError if dt <= 0 or duration < 0.

    Returns:
      t_array,
      V_array, W_array, Y_array,
      X_raw_array (unscaled), Z_raw_array (unscaled),
      X_plot = X_raw_array * X_scaler, Z_plot = Z_raw_array * Z_scaler,
    where X_scaler = X_out / X_in when use_X and X_in > 0 (else 1.0), and
    likewise for Z_scaler.
    """
    if dt <= 0:
        raise ValueError("dt must be positive.")
    if duration < 0:
        raise ValueError("duration must be non-negative.")

    X_scaler = X_out / X_in if (use_X and X_in > 0) else 1.0
    Z_scaler = Z_out / Z_in if (use_Z and Z_in > 0) else 1.0

    N = int(np.ceil(duration / dt)) + 1
    n_steps = N - 1

    # Time axis on the actual dt grid, capped at `duration`. The previous
    # linspace(0, duration, N) labelling disagreed with the dt-sized Euler
    # steps whenever duration / dt was not an integer.
    t = np.minimum(dt * np.arange(N), duration)
    t[-1] = duration

    # Size of the final step: whatever remains to reach `duration`, kept
    # within [0, dt] to absorb floating-point round-off in ceil().
    last_dt = min(max(duration - (n_steps - 1) * dt, 0.0), dt)

    V = np.zeros(N)
    W = np.zeros(N)
    Y = np.zeros(N)
    X_raw = np.zeros(N)
    Z_raw = np.zeros(N)

    V[0] = V0
    W[0] = W0
    Y[0] = Y0
    X_raw[0] = X0
    Z_raw[0] = Z0

    final_index = N - 1
    for i in range(1, N):
        h = dt if i < n_steps else last_dt

        Vi = V[i-1]
        Wi = W[i-1]
        Yi = Y[i-1]
        Xi = X_raw[i-1]
        Zi = Z_raw[i-1]

        # dV/dt, dW/dt: growth limited by (1 - W - V), minus death
        dV = W_birth * (1 - Wi - Vi) * Vi * Yi - W_death * Vi
        dW = W_birth * (1 - Wi - Vi) * Wi * Yi - W_death * Wi

        # dY/dt
        dY = Y_birth * (1 - Yi) * Yi * (Vi + Wi) - Y_death * Yi

        # X-coupling: exchange between W and X
        if use_X:
            dW += X_out * Xi - X_in * Wi
        # Z-coupling: exchange between Y and Z
        if use_Z:
            dY += Z_out * Zi - Z_in * Yi

        # dX/dt, dZ/dt (computed even when the coupling flag is off)
        dX = - X_out * Xi + X_in * Wi
        dZ = - Z_out * Zi + Z_in * Yi

        # Check for equilibrium if requested. The test uses the derivatives,
        # so it does not depend on the step size.
        if stop_at_eq:
            cond = (abs(dV) < tol and abs(dW) < tol and abs(dY) < tol)
            if use_X:
                cond = cond and abs(dX) < tol
            if use_Z:
                cond = cond and abs(dZ) < tol
            if cond:
                # State i was never written; keep samples 0 .. i-1.
                final_index = i - 1
                break

        # Euler update, clamped to enforce nonnegativity
        V[i] = max(Vi + h * dV, 0.0)
        W[i] = max(Wi + h * dW, 0.0)
        Y[i] = max(Yi + h * dY, 0.0)
        X_raw[i] = max(Xi + h * dX, 0.0)
        Z_raw[i] = max(Zi + h * dZ, 0.0)

    # Truncate arrays to the samples that were actually computed
    keep = slice(0, final_index + 1)
    t_trunc     = t[keep]
    V_trunc     = V[keep]
    W_trunc     = W[keep]
    Y_trunc     = Y[keep]
    X_raw_trunc = X_raw[keep]
    Z_raw_trunc = Z_raw[keep]

    X_plot = X_raw_trunc * X_scaler
    Z_plot = Z_raw_trunc * Z_scaler

    return t_trunc, V_trunc, W_trunc, Y_trunc, X_raw_trunc, Z_raw_trunc, X_plot, Z_plot

def compute_deltaW_curve(W_birth, Y_birth, W_death, Y_death,
                         X_in, Z_in, X_out, Z_out,
                         Time, dt, use_X, use_Z,
                         num_points, severity,
                         perturb_V, perturb_W, perturb_Y,
                         tol):
    """
    Compute W0_values and corresponding ΔW for a given 'severity'.

    W0 is swept over `num_points` values from 0 to W_eq, with V0 = W_eq − W0
    and Y0 = Y_eq, where (W_eq, Y_eq) comes from compute_nontrivial_slice.
    When use_X / use_Z is set, X0 / Z0 start at (X_in/X_out)·W0 and
    (Z_in/Z_out)·Y_eq; otherwise they start at 0.

    Severity is interpreted so that perturbation multiplier = (1 - severity);
    it is applied at t=0 to whichever of V, W, Y is selected by perturb_V,
    perturb_W, perturb_Y. X0 and Z0 are not perturbed.

    ΔW is the final W of a simulate_segment run to `Time` (with early-stop at
    equilibrium, tolerance `tol`) minus the unperturbed W0.

    Raises:
      RuntimeError if no positive, nontrivial equilibrium exists.

    Returns W0_values, DeltaW_vals, and the integral of |DeltaW| over W0
    (trapezoidal rule).
    """
    W_eq, Y_eq = compute_nontrivial_slice(W_birth, W_death, Y_birth, Y_death)
    if (W_eq is None) or (Y_eq is None):
        raise RuntimeError("No positive, nontrivial equilibrium exists.")
    W0_values = np.linspace(0.0, W_eq, num_points)
    DeltaW = np.zeros_like(W0_values)

    # Apply perturbation multiplier = (1 - severity)
    multiplier = 1 - severity

    # Y0 and Z0 do not depend on W0, so compute them once.
    Z0 = (Z_in / Z_out) * Y_eq if use_Z else 0.0
    Y0p = (multiplier * Y_eq) if perturb_Y else Y_eq

    for idx, W0 in enumerate(W0_values):
        V0 = W_eq - W0
        X0 = (X_in / X_out) * W0 if use_X else 0.0

        V0p = (multiplier * V0) if perturb_V else V0
        W0p = (multiplier * W0) if perturb_W else W0

        # Simulate [0 → Time] with early-stop at equilibrium; only the W
        # trajectory (third return value) is needed here.
        segment = simulate_segment(
            V0=V0p, W0=W0p, Y0=Y0p, X0=X0, Z0=Z0,
            W_birth=W_birth, Y_birth=Y_birth,
            W_death=W_death, Y_death=Y_death,
            X_in=X_in, Z_in=Z_in,
            X_out=X_out, Z_out=Z_out,
            duration=Time, dt=dt,
            use_X=use_X, use_Z=use_Z,
            tol=tol,
            stop_at_eq=True
        )
        W_final = segment[2][-1]
        DeltaW[idx] = W_final - W0

    # Compute integral of absolute ΔW over W0. np.trapz is deprecated since
    # NumPy 2.0 in favour of np.trapezoid; fall back for older NumPy.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    integral = trapezoid(np.abs(DeltaW), W0_values)
    return W0_values, DeltaW, integral

def assess_numerical_stability(W_birth, Y_birth, W_death, Y_death,
                               X_in, Z_in, X_out, Z_out,
                               Time=200.0,
                               use_X=True, use_Z=True,
                               num_points=100,
                               perturb_V=False, perturb_W=False, perturb_Y=True,
                               severity=0.3,  # example severity
                               tol=1e-6,
                               dt_ref=1e-6,
                               dt_values=None):
    """
    Assess numerical stability by comparing ΔW vs W0 computed with different dt.
    - Reference solution: computed with dt_ref (default 1e-6).
    - Other dt values: `dt_values` if given, otherwise one value per decade
      from 1e-1 down to 1e-6 (inclusive).
    - For each dt, compute ΔW curve and measure error = max|ΔW_dt - ΔW_ref|.
      A dt equal to dt_ref reuses the reference curve (error 0) instead of
      repeating the most expensive run.
    - Plot error vs dt on a log-log scale (zero errors cannot be drawn on a
      log axis and are left out of the plot, but still printed).
    - Print a table of dt and error.

    The remaining parameters are forwarded unchanged to compute_deltaW_curve.
    Returns None.
    """
    # Arguments shared by every compute_deltaW_curve call; only dt varies.
    curve_kwargs = dict(
        W_birth=W_birth, Y_birth=Y_birth,
        W_death=W_death, Y_death=Y_death,
        X_in=X_in, Z_in=Z_in, X_out=X_out, Z_out=Z_out,
        Time=Time, use_X=use_X, use_Z=use_Z,
        num_points=num_points, severity=severity,
        perturb_V=perturb_V, perturb_W=perturb_W, perturb_Y=perturb_Y,
        tol=tol,
    )

    # (1) Compute reference
    print(f"Computing reference solution with dt = {dt_ref:e} ...")
    _, DeltaW_ref, _ = compute_deltaW_curve(dt=dt_ref, **curve_kwargs)

    # (2) Define the set of dt values. Six points between the exponents -1
    # and -6 give exactly [1e-1, 1e-2, ..., 1e-6]; num=7 produced
    # non-decade values such as 1.47e-2.
    if dt_values is None:
        dt_values = np.logspace(-1, -6, num=6)
    else:
        dt_values = np.asarray(dt_values, dtype=float)

    errors = np.zeros_like(dt_values)

    for i, dt in enumerate(dt_values):
        if np.isclose(dt, dt_ref, rtol=1e-12, atol=0.0):
            # Same timestep as the reference: the curve would be identical,
            # so skip the recomputation and record zero error.
            errors[i] = 0.0
            continue

        print(f"Computing ΔW curve with dt = {dt:e} ...")
        _, DeltaW_vals, _ = compute_deltaW_curve(dt=dt, **curve_kwargs)
        # The W0 grid is the same as the reference grid by construction
        # (same num_points and same W_eq), so the curves compare directly.
        errors[i] = np.max(np.abs(DeltaW_vals - DeltaW_ref))

    # (3) Plot error vs dt on log-log
    drawable = errors > 0
    plt.figure(figsize=(7, 5))
    plt.loglog(dt_values[drawable], errors[drawable], marker='o', linestyle='-',
               color='darkblue', label=r'$\max|\Delta W(dt)-\Delta W_{\rm ref}|$')
    plt.xlabel(r'$\Delta t$', fontsize=12)
    plt.ylabel(r'Error', fontsize=12)
    plt.title('Numerical Stability: Error vs Timestep', fontsize=14)
    plt.grid(True, which='both', ls='--', lw=0.5)
    plt.legend()  # the curve label was previously set but never displayed
    plt.tight_layout()
    plt.show()

    # Print out dt and corresponding error
    print("\nΔt          Error")
    for dt, err in zip(dt_values, errors):
        print(f"{dt:.1e}    {err:.4e}")

# =========================
# Example usage:
if __name__ == "__main__":
    # Demo driver: set the model parameters below, then run the timestep
    # convergence study.

    # Birth / death rates
    W_birth, W_death = 0.4, 0.1
    Y_birth, Y_death = 0.9, 0.15

    # Exchange rates for the X and Z compartments
    X_in, X_out = 0.2, 0.1
    Z_in, Z_out = 0.5, 0.25

    # Simulation horizon and which couplings are active
    Time = 200.0
    use_X, use_Z = True, False

    # Number of W0 samples along the sweep
    num_points = 100

    # Only Y is perturbed at t=0
    perturb_V, perturb_W, perturb_Y = False, False, True

    # severity 0.3 -> Y is multiplied by (1 - 0.3) = 0.7 at t=0
    severity = 0.3

    assess_numerical_stability(
        W_birth, Y_birth, W_death, Y_death,
        X_in, Z_in, X_out, Z_out,
        Time=Time,
        use_X=use_X,
        use_Z=use_Z,
        num_points=num_points,
        perturb_V=perturb_V,
        perturb_W=perturb_W,
        perturb_Y=perturb_Y,
        severity=severity,
        tol=1e-6,
    )

Computing reference solution with dt = 1.000000e-06 ...
Computing ΔW curve with dt = 1.000000e-01 ...
Computing ΔW curve with dt = 1.467799e-02 ...
Computing ΔW curve with dt = 2.154435e-03 ...
Computing ΔW curve with dt = 3.162278e-04 ...
Computing ΔW curve with dt = 4.641589e-05 ...
Computing ΔW curve with dt = 6.812921e-06 ...
Computing ΔW curve with dt = 1.000000e-06 ...
