In [1]:
# ============================================
# TW ODE (U',P',M') via RK45 + overlay plots
# ============================================
import json
import warnings
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.interpolate import PchipInterpolator
from scipy.optimize import root_scalar

# ---------- tolerant file-structure helpers ----------
def _token_variants(v: float):
    """Return the set of spellings a run directory name might use for ``v``.

    Covers:
      * fixed-point with trailing zeros stripped ("0.1", "10"),
      * ``%g`` ("0.1", "1e-05"),
      * ``repr`` of the float ("10.0", i.e. what ``f"{float_value}"`` gives),
      * one-digit scientific forms ("1e-01", "1e-1"),
    plus each of those with "." replaced by "p" ("0p1", "10p0").

    The one-significant-digit scientific forms are included only when they
    denote exactly ``v``. Otherwise they would name a different value, e.g.
    0.15 -> "1e-01", and could match another parameter's directory.
    """
    v = float(v)
    dec = f"{v:.12f}".rstrip("0").rstrip(".")  # fixed-point, trailing zeros dropped
    toks = {dec, f"{v:g}", repr(v)}
    sci = f"{v:.0e}"
    if float(sci) == v:  # mantissa rounding must not change the value
        mantissa, exponent = sci.split("e")
        toks |= {sci, f"{mantissa}e{int(exponent)}"}
    # "p" spellings avoid dots in directory names.
    toks |= {t.replace(".", "p") for t in list(toks)}
    return toks

def _exactish_dir(parent: Path, prefix: str, value: float):
    """Return ``parent / f"{prefix}_{token}"`` for a spelling of ``value`` that exists.

    Every spelling from ``_token_variants(value)`` is tried. Returns None if
    ``parent`` does not exist or no candidate is a directory.

    Tokens are tried in sorted order so the result is reproducible. Iterating
    the raw set would depend on per-process string-hash randomisation whenever
    several spellings of the same value exist on disk.
    """
    parent = Path(parent)
    if not parent.exists():
        return None
    for tok in sorted(_token_variants(value)):
        candidate = parent / f"{prefix}_{tok}"
        if candidate.is_dir():
            return candidate
    return None

def find_run_dir(base_roots, lam, alpha, m0):
    """Find the run directory ``<root>/lambda_*/alpha_*/m0_*`` for the given parameters.

    ``base_roots`` is a single path (str or Path) or an iterable of paths;
    roots are searched in order and the first complete match is returned.
    Returns None when no root contains a matching directory chain.
    """
    roots = (base_roots,) if isinstance(base_roots, (str, Path)) else base_roots
    for candidate_root in map(Path, roots):
        lam_dir = _exactish_dir(candidate_root, "lambda", lam)
        alpha_dir = None if lam_dir is None else _exactish_dir(lam_dir, "alpha", alpha)
        if alpha_dir is None:
            continue
        run_dir = _exactish_dir(alpha_dir, "m0", m0)
        if run_dir is not None:
            return run_dir
    return None

# ---------- load run (c and snapshots) ----------
def load_run(run_dir: Path):
    """Load one stored run: metadata from ``summary.json`` and profiles from ``snapshots.npz``.

    Parameters
    ----------
    run_dir : Path or str
        Directory containing ``summary.json`` and ``snapshots.npz``.

    Returns
    -------
    dict
        ``c`` (wave speed), ``lambda``, ``alpha``, ``m0`` (NaN if absent from
        the metadata), ``x``, ``times``, ``U`` (tumour, from ``N_arr`` or else
        ``u_arr``) and ``M`` (from ``M_arr``), all as float arrays/values.

    Raises
    ------
    FileNotFoundError
        If either file is missing.
    KeyError
        If ``wave_speed`` is missing from the metadata, or the archive lacks
        ``x``, ``times``, ``M_arr``, or both ``N_arr`` and ``u_arr``.
    """
    run_dir = Path(run_dir)
    summary_path = run_dir / "summary.json"
    meta = json.loads(summary_path.read_text())
    if meta.get("wave_speed") is None:
        raise KeyError(f"'wave_speed' missing from {summary_path}")

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code -- only load run directories you trust.
    # The context manager closes the archive's file handle once arrays are read.
    with np.load(run_dir / "snapshots.npz", allow_pickle=True) as z:
        u_key = next((k for k in ("N_arr", "u_arr") if k in z.files), None)
        if u_key is None:
            raise KeyError(
                f"neither 'N_arr' nor 'u_arr' found in {run_dir / 'snapshots.npz'}"
            )
        data = {
            "c": float(meta["wave_speed"]),
            "lambda": float(meta.get("lambda_val", meta.get("lambda", np.nan))),
            "alpha": float(meta.get("alpha", np.nan)),
            "m0": float(meta.get("m0", np.nan)),
            "x": np.asarray(z["x"], float),
            "times": np.asarray(z["times"], float),
            "U": np.asarray(z[u_key], float),  # tumour
            "M": np.asarray(z["M_arr"], float),
        }
    return data

# ---------- find U=0.5 front by monotone spline ----------
def front_x_at_time(x, Urow, threshold=0.5, band=(0.1,0.9)):
    """Locate the first position (in increasing ``x``) where a profile crosses ``threshold``.

    Only samples with ``band[0] < Urow < band[1]`` are used. They are sorted by
    ``x``, and a monotone PCHIP interpolant is root-solved with Brent's method
    inside the first bracketing interval. A sample lying exactly on
    ``threshold`` counts as a crossing and is returned directly; the sign-change
    test alone would miss it because its sign is 0.

    Parameters
    ----------
    x : array_like
        Sample positions (must be distinct within the band, as PCHIP requires).
    Urow : array_like
        Profile values at ``x``.
    threshold : float
        Level whose crossing is sought.
    band : (float, float)
        Open interval of U values kept for the fit.

    Returns
    -------
    float or None
        The crossing position. None if fewer than 5 samples fall in the band,
        no crossing is found, or the root solve does not converge.
    """
    x = np.asarray(x, float)
    Urow = np.asarray(Urow, float)
    mask = (Urow > band[0]) & (Urow < band[1])
    if np.count_nonzero(mask) < 5:
        return None
    order = np.argsort(x[mask])
    xloc = x[mask][order]
    uloc = Urow[mask][order]
    spl = PchipInterpolator(xloc, uloc, extrapolate=True)

    resid = uloc - threshold
    exact = np.flatnonzero(resid == 0.0)
    brackets = np.flatnonzero(resid[:-1] * resid[1:] < 0.0)
    # Whichever comes first along x wins. An exact hit at sample i precedes a
    # bracket [j, j+1] whenever i <= j. If j < i, then j+1 < i, because a zero
    # at j+1 would make the product 0 rather than negative.
    if exact.size and (brackets.size == 0 or exact[0] <= brackets[0]):
        return float(xloc[exact[0]])
    if brackets.size == 0:
        return None
    i = brackets[0]
    sol = root_scalar(lambda xv: spl(xv) - threshold,
                      bracket=[xloc[i], xloc[i + 1]], method="brentq")
    return sol.root if sol.converged else None

# ---------- the TW ODE system ----------
def tw_rhs(xi, y, c, lam, alpha):
    """Right-hand side of the first-order travelling-wave system in the moving coordinate xi.

        U' = P
        P' = -c P - U (1 - U) M (1 - M)
        M' = (1/c) M (1 - M) (lam U M - alpha (1 - M))

    ``xi`` is unused (the system is autonomous) but is required by the
    ``fun(t, y)`` signature of ``scipy.integrate.solve_ivp``. ``c`` must be
    non-zero because M' divides by it.

    Returns
    -------
    list
        ``[U', P', M']``.
    """
    U, P, M = y
    dU = P
    dP = -c*P - U*(1.0 - U)*M*(1.0 - M)
    dM = (1.0/c)*M*(1.0 - M)*(lam*U*M - alpha*(1.0 - M))
    return [dU, dP, dM]

# ---------- integrate from +L to -L using linearised tail seed ----------
def solve_tw_ivp(c, lam, alpha, m0, L=25.0, eps=1e-6,
                 rtol=1e-6, atol=1e-9, max_step=1.0):
    """Integrate the TW ODE (``tw_rhs``) with RK45 from xi = +L backwards to xi = -L.

    The start state is a linearised-tail seed ``U = eps, P = r*eps, M = m0``, where

        r = (-c + sqrt(c^2 - 4 m0 (1 - m0))) / 2

    is the larger (for c > 0, slower-decaying) root of the linearisation
    U'' + c U' + m0 (1 - m0) U = 0 about U = 0 with M frozen at m0. If the
    discriminant is negative, the roots are complex and r falls back to their
    common real part -c/2.

    Returns
    -------
    xi : ndarray
        Integration points, descending from +L.
    y : ndarray, shape (3, len(xi))
        Rows U, P, M.

    Warns
    -----
    RuntimeWarning
        If the solver stops before reaching -L. The returned arrays are then
        the partial solution.
    """
    disc = c*c - 4.0*m0*(1.0 - m0)
    # Negative discriminant -> complex eigenvalues; keep only the real part.
    r = (-c + np.sqrt(max(disc, 0.0))) / 2.0
    y0 = [eps, r*eps, m0]
    sol = solve_ivp(
        fun=lambda xi, y: tw_rhs(xi, y, c, lam, alpha),
        t_span=(+L, -L),
        y0=y0,
        method="RK45",
        rtol=rtol, atol=atol,
        max_step=max_step,
        dense_output=False,
    )
    if not sol.success:
        # Surface failures instead of silently returning a truncated profile.
        warnings.warn(
            f"TW ODE integration stopped at xi={sol.t[-1]:.4g} before reaching "
            f"xi={-L:g}: {sol.message}",
            RuntimeWarning,
            stacklevel=2,
        )
    return sol.t, sol.y  # xi_vec (descending), [U,P,M] arrays

# ---------- overlay plot ----------
def plot_overlay(base_roots, lam, alpha, m0,
                 times_to_plot=(50,100,150),
                 L=25.0, eps=1e-6, rtol=1e-6, atol=1e-9, max_step=1.0,
                 cmap_name="plasma", window_factor=8.0,
                 title=None):
    """Overlay numerical tumour profiles, the TW ODE profile and a logistic comparator.

    Loads the stored run for (lam, alpha, m0) from the first matching root in
    ``base_roots`` and integrates the TW ODE at the run's wave speed ``c``.
    Everything is plotted against the moving coordinate xi on
    [-window_factor*c, window_factor*c]. Each numerical snapshot and the ODE
    profile are shifted so that U = 0.5 sits at xi = 0, matching the logistic
    1/(1 + exp(xi/c)), which is 0.5 at xi = 0 by construction.

    Parameters
    ----------
    base_roots : str, Path or iterable of them
        Roots searched for ``lambda_*/alpha_*/m0_*`` run directories.
    lam, alpha, m0 : float
        Run parameters, also passed to the ODE.
    times_to_plot : sequence of float
        Requested snapshot times; the nearest stored time is used for each.
    L, eps, rtol, atol, max_step :
        Forwarded to ``solve_tw_ivp``.
    cmap_name : str
        Matplotlib colormap for the snapshot curves.
    window_factor : float
        Half-width of the plotted xi window, in units of ``c``.
    title : str or None
        Plot title; a parameter summary is used when None.

    Returns
    -------
    fig, ax, (xi_num, U_tw, M_tw)
        ``xi_num`` is the raw (unshifted, descending) ODE grid from
        ``solve_tw_ivp``.

    Raises
    ------
    FileNotFoundError
        If no run directory matches the parameters.
    """
    run_dir = find_run_dir(base_roots, lam, alpha, m0)
    if run_dir is None:
        raise FileNotFoundError(
            f"Run directory not found for lambda={lam:g}, alpha={alpha:g}, "
            f"m0={m0:g} under roots {base_roots!r}."
        )
    data = load_run(run_dir)
    c = data["c"]
    x, T, Uall = data["x"], data["times"], data["U"]

    # ---- integrate ODE
    xi_num, Y = solve_tw_ivp(c, lam, alpha, m0, L=L, eps=eps,
                             rtol=rtol, atol=atol, max_step=max_step)
    U_tw, M_tw = Y[0], Y[2]

    # ---- logistic comparator (U = 0.5 at xi = 0)
    xi_log = np.linspace(-window_factor*c, window_factor*c, 800)
    U_log = 1.0/(1.0 + np.exp(xi_log/c))

    cmap = plt.get_cmap(cmap_name)
    n_times = len(times_to_plot)
    colors = [cmap(i/max(1, n_times - 1)) for i in range(n_times)]

    fig, ax = plt.subplots(figsize=(9.6, 6.0))

    # ---- numerical overlays centred at U=0.5
    for color, t_ref in zip(colors, times_to_plot):
        k = int(np.argmin(np.abs(T - t_ref)))
        x0 = front_x_at_time(x, Uall[k], threshold=0.5, band=(0.1,0.9))
        if x0 is None:
            warnings.warn(f"No U=0.5 front found in snapshot t={T[k]:g}; skipped.",
                          stacklevel=2)
            continue
        ax.plot(x - x0, Uall[k], color=color, alpha=0.45, lw=2,
                label=fr"num $t\approx{T[k]:g}$")

    # ---- TW ODE profile. xi_num is descending, so sort it increasing first.
    # tw_rhs does not depend on xi, so the ODE's xi origin is arbitrary (set by
    # L and eps). Shift it so U_tw = 0.5 at xi = 0, like the other curves.
    order = np.argsort(xi_num)
    xi_sorted = np.asarray(xi_num)[order]
    U_sorted = np.asarray(U_tw)[order]
    xi_half = front_x_at_time(xi_sorted, U_sorted, threshold=0.5, band=(0.1, 0.9))
    if xi_half is None:
        warnings.warn("TW ODE profile has no resolvable U=0.5 crossing; plotted uncentred.",
                      stacklevel=2)
        xi_half = 0.0
    ax.plot(xi_sorted - xi_half, U_sorted, color="k", lw=2.5, label="TW ODE (U)")

    # logistic comparator
    ax.plot(xi_log, U_log, ls="--", color="k", lw=2, alpha=0.8, label=r"logistic $1/(1+e^{\xi/c})$")

    ax.set_xlim(-window_factor*c, window_factor*c)
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel(r"$\xi$", fontsize=16)
    ax.set_ylabel(r"$U(\xi)$", fontsize=16)
    if title is None:
        title = fr"TW ODE vs numerics ( $\lambda={lam:g}$, $\alpha={alpha:g}$, $m_0={m0:g}$, $c={c:.3g}$ )"
    ax.set_title(title, fontsize=18, weight="bold")
    ax.grid(True, ls="--", alpha=0.25)
    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False, fontsize=12)
    fig.tight_layout()
    plt.show()

    return fig, ax, (xi_num, U_tw, M_tw)

# ======================
# Example usage
# ======================
BASES = ("speeds_func_4", "speeds_func_l", "speeds_func_u")  # roots searched in this order
lam     = 10               # λ for the run you want to compare
alpha   = 0.1              # α matching the stored run
m0      = 0.5              # initial ECM at +∞
times   = (25, 50, 75, 100)  # requested snapshot times (nearest stored time is used)

# Integrate with RK45 (step capped at max_step=1.0, ~ the τ=1 of the numerics),
# then overlay the numerical profiles and the logistic comparator.
# The result is bound to names so the notebook does not echo the whole
# (fig, ax, (xi, U, M)) tuple as the cell's output.
fig, ax, tw_solution = plot_overlay(
    base_roots=BASES,
    lam=lam, alpha=alpha, m0=m0,
    times_to_plot=times,
    L=25.0, eps=1e-6,
    rtol=1e-6, atol=1e-9, max_step=1.0,
    cmap_name="plasma",
    window_factor=8.0,
)

FileNotFoundError: Run directory not found.