In [1]:
# func.py — Functional-source model with plotting & utilities (full code)

import os, math
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import eye, diags
from scipy.sparse.linalg import spsolve
from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator, interp1d
from scipy.optimize import root_scalar
from scipy.stats import linregress
from numba import njit

# ----------------------------
# Helpers
# ----------------------------
def _latex_sci(val, pow10_threshold=100.0):
    """Format val for LaTeX titles; switch to scientific when |val| ≥ threshold."""
    if val == 0:
        return "0"
    a = abs(val)
    sign = "-" if val < 0 else ""
    if a < pow10_threshold:
        return f"{val:g}"
    exp = int(np.floor(np.log10(a)))
    mant = a / (10**exp)
    if np.isclose(mant, 1.0, rtol=1e-10, atol=1e-12):
        return rf"{sign}10^{{{exp}}}"
    return rf"{sign}{mant:.2g}\times 10^{{{exp}}}"

def _file_sci(val, pow10_threshold=100.0):
    """Compact number for filenames: 1000 -> '1e3', 2500 -> '2.5e3', 10 -> '10'."""
    if val == 0:
        return "0"
    a = abs(val)
    if a < pow10_threshold:
        return f"{int(val)}" if float(val).is_integer() else f"{val:g}"
    return f"{val:.0e}".replace("+0","").replace("+","").replace("e0","e")

def _m0_two_digits(m0):
    """m0 in [0,1] to two digits: 0.0->'00', 0.5->'05', 1.0->'10'."""
    return f"{int(round(m0*10)):02d}"

def _nearest_indices(t_vec, t_points):
    return [int(np.argmin(np.abs(t_vec - t))) for t in t_points]

# ----------------------------
# Kernels
# ----------------------------
@njit
def f_numba(N, rho, K):
    # tumour growth: N_t = rho * N * (1 - N/K)
    return rho * N * (1 - N / K)

@njit
def build_laplacian_diagonals_avg(m, D, dx):
    """
    Variable-coefficient diffusion with edge-averaged M:
      (D * M*(1-M) * u_x)_x  with homogeneous Neumann BCs.
    Returns three diagonals (lower, center, upper) scaled by 1/dx^2.
    """
    N = len(m)
    lower = np.zeros(N)
    center = np.zeros(N)
    upper = np.zeros(N)

    for i in range(1, N - 1):
        ml = 0.5 * (m[i - 1] + m[i])
        mr = 0.5 * (m[i] + m[i + 1])
        Dl = max(1e-6, D * ml * (1 - ml))
        Dr = max(1e-6, D * mr * (1 - mr))
        lower[i] = Dl
        upper[i] = Dr
        center[i] = - (Dl + Dr)

    # Neumann at x=0
    mr = 0.5 * (m[0] + m[1])
    Dr = max(1e-6, D * mr * (1 - mr))
    center[0] = -2 * Dr
    upper[0]  =  2 * Dr

    # Neumann at x=L
    ml = 0.5 * (m[-2] + m[-1])
    Dl = max(1e-6, D * ml * (1 - ml))
    center[-1] = -2 * Dl
    lower[-1]  =  2 * Dl

    invdx2 = 1.0 / dx**2
    return invdx2 * lower, invdx2 * center, invdx2 * upper

# -------------------------------------------------
# Main functional-source class
# -------------------------------------------------
class Dissertation_Func_1D:
    """
    Functional source ECM dynamics:
        m_t = alpha * (1 - m) - k * u * m
    Tumour:
        u_t = (D * m (1-m) u_x)_x + rho * u * (1 - u/K)
    """
    def __init__(self, D=1.0, rho=1.0, K=1.0, k=1.0,
                 alpha=1.0, n0=1.0, m0=0.5, Mmax=1.0, perc=0.2,
                 L=1000.0, N=5001, T=1000.0, dt=0.1,
                 scheme="AB2AM2", init_type="step", steepness=0.1,
                 t_start=50.0, t_end=500.0, num_points=200):
        # PDE/ODE params
        self.D = D; self.rho = rho; self.K = K
        self.k = k; self.alpha = alpha
        self.n0 = n0; self.m0 = m0
        self.Mmax = Mmax; self.perc = perc
        self.steepness = steepness

        # grid/time
        self.L = L; self.N = N; self.dx = L / (N - 1)
        self.x = np.linspace(0, L, N)
        self.T = T; self.dt = dt; self.Nt = int(T / dt)
        self.scheme = scheme.upper()
        self.init_type = init_type

        # storage
        self.times = np.linspace(0, T, self.Nt)
        self.N_arr = np.zeros((self.Nt, self.N))
        self.M_arr = np.zeros((self.Nt, self.N))
        self.wave_speed = None  # filled later

        # front-tracking window
        self.t_start = t_start
        self.t_end = t_end
        self.num_points = num_points

    # ---------------------------------------
    # Solver
    # ---------------------------------------
    def initial_condition(self):
        if self.init_type == "step":
            N0 = self.n0 * np.where(self.x < self.perc * self.L, 0.7, 0.0)
        elif self.init_type == "tanh":
            N0 = self.n0 * 0.5 * (1 - np.tanh(self.steepness * (self.x - self.perc * self.L)))
        else:
            raise ValueError("Unknown initial condition.")
        M0 = self.m0 * self.Mmax * np.ones_like(self.x)
        return N0, M0

    def update_laplacian(self, M):
        lower, center, upper = build_laplacian_diagonals_avg(M, self.D, self.dx)
        return diags([lower[1:], center, upper[:-1]], [-1, 0, 1], format="csr")

    def solve(self):
        # initial data
        N_prev, M_prev = self.initial_condition()
        f_prev = f_numba(N_prev, self.rho, self.K)
        L_prev = self.update_laplacian(M_prev)

        # first step for u (implicit Euler in diffusion)
        A0 = (eye(self.N) - self.dt * L_prev)
        N_curr = spsolve(A0.tocsc(), N_prev + self.dt * f_prev)

        # first step for m (implicit Euler with u^{1})
        denom = 1.0 + self.dt * (self.alpha + self.k * np.maximum(N_curr, 0.0))
        M_curr = (M_prev + self.alpha * self.dt) / denom
        np.clip(M_curr, 0.0, self.Mmax, out=M_curr)

        # store first two frames
        self.N_arr[0], self.M_arr[0] = N_prev, M_prev
        self.N_arr[1], self.M_arr[1] = N_curr, M_curr

        # main loop
        for i in range(2, self.Nt):
            # operator with current m
            L_curr = self.update_laplacian(M_curr)
            f_curr = f_numba(N_curr, self.rho, self.K)

            # AB2–AM2 for u
            rhs = (eye(self.N) + 0.5 * self.dt * L_prev) @ N_curr \
                  + self.dt * (1.5 * f_curr - 0.5 * f_prev)
            A = (eye(self.N) - 0.5 * self.dt * L_curr)
            N_next = spsolve(A.tocsc(), rhs)

            # Neumann ends for u by copying neighbors
            N_next[0], N_next[-1] = N_next[1], N_next[-2]

            # implicit Euler for m using u^{n+1}
            denom = 1.0 + self.dt * (self.alpha + self.k * np.maximum(N_next, 0.0))
            M_next = (M_curr + self.alpha * self.dt) / denom
            np.clip(M_next, 0.0, self.Mmax, out=M_next)

            # store & roll
            self.N_arr[i] = N_next
            self.M_arr[i] = M_next
            N_prev, N_curr = N_curr, N_next
            M_prev, M_curr = M_curr, M_next
            f_prev = f_curr
            L_prev = L_curr

    # ---------------------------------------
    # Front tracking & speed estimation
    # ---------------------------------------
    def _get_spline(self, method, x, y):
        m = method.lower()
        if m == 'cubic':  return CubicSpline(x, y)
        if m == 'pchip':  return PchipInterpolator(x, y)
        if m == 'akima':  return Akima1DInterpolator(x, y)
        if m == 'linear': return interp1d(x, y, kind='linear', fill_value="extrapolate")
        raise ValueError(f"Unsupported spline_type: {method}")

    def track_wavefront_local_interpolation(self, threshold=0.5, band=(0.1, 0.9),
                                            spline_type='cubic', target='N'):
        x = self.x
        t_vec = self.times
        u_arr = self.N_arr if target.lower() == 'n' else self.M_arr
        t_list = np.linspace(self.t_start, self.t_end, self.num_points)
        x_fronts, t_fronts = [], []

        for t_target in t_list:
            idx = int(np.argmin(np.abs(t_vec - t_target)))
            u = u_arr[idx]
            mask = (u > band[0]) & (u < band[1])
            if np.sum(mask) < 5:
                continue
            x_local, u_local = x[mask], u[mask]
            sidx = np.argsort(x_local)
            x_local, u_local = x_local[sidx], u_local[sidx]
            spline = self._get_spline(spline_type, x_local, u_local)

            # find first threshold crossing in the band
            sign_change = np.where(
                np.sign(u_local[:-1] - threshold) != np.sign(u_local[1:] - threshold)
            )[0]
            if len(sign_change) == 0:
                continue
            i = int(sign_change[0])
            xl, xr = x_local[i], x_local[i + 1]

            try:
                sol = root_scalar(lambda xv: spline(xv) - threshold, bracket=[xl, xr])
                if sol.converged:
                    x_fronts.append(sol.root)
                    t_fronts.append(t_target)
            except Exception:
                pass

        return np.array(t_fronts), np.array(x_fronts)

    def estimate_wave_speed(self, threshold=0.5, band=(0.1, 0.9),
                            spline_type='cubic', plot=True, target='N'):
        t_fronts, x_fronts = self.track_wavefront_local_interpolation(
            threshold=threshold, band=band, spline_type=spline_type, target=target
        )
        if len(t_fronts) < 2:
            print("❌ Not enough valid front points.")
            return None, None, None

        slope, intercept, r_value, _, _ = linregress(t_fronts, x_fronts)

        if plot:
            plt.figure(figsize=(8, 4))
            plt.plot(t_fronts, x_fronts, 'o', label='Front')
            plt.plot(t_fronts, slope * t_fronts + intercept, 'k--',
                     label=f'Slope = {slope:.3f},  $R^2$ = {r_value**2:.4f}')
            plt.xlabel("Time t")
            plt.ylabel("Wavefront x(t)")
            plt.title("Wave speed via linear fit")
            plt.legend(); plt.grid(True); plt.tight_layout()
            plt.show()

        return slope, intercept, r_value**2

    def plot_speed_curve(self, threshold=0.5, band=(0.1, 0.9), spline_type='cubic', target='N'):
        """
        Show x_front vs t with best-fit line and annotate R^2.
        Returns (speed, intercept, R2).
        """
        t_fronts, x_fronts = self.track_wavefront_local_interpolation(
            threshold=threshold, band=band, spline_type=spline_type, target=target
        )
        if len(t_fronts) < 2:
            print("❌ Not enough valid front points.")
            return None, None, None

        slope, intercept, r_value, _, _ = linregress(t_fronts, x_fronts)
        r2 = r_value**2

        plt.figure(figsize=(8, 4.6))
        plt.plot(t_fronts, x_fronts, 'o', label='Front samples')
        plt.plot(t_fronts, slope * t_fronts + intercept, 'k--',
                 label=f'$c$={slope:.4f}, $R^2$={r2:.4f}')
        plt.xlabel("Time $t$")
        plt.ylabel("Front position $x(t)$")
        plt.title("Wave speed estimation")
        plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
        plt.show()

        self.wave_speed = slope
        return slope, intercept, r2

    # ---------------------------------------
    # Publication-ready snapshot plot(s)
    # ---------------------------------------
    def plot_u_m_with_custom_style(self, 
                                   t_points=[0, 100, 200, 300],
                                   target="both",          # "u", "m", or "both"
                                   yticks_mode="basic", 
                                   show_arrows=True,
                                   show_speed_text=True,
                                   print_speed=False,
                                   ceil_speed=False,     
                                   arrow_len=None, arrow_lw=2.5,
                                   arrow_x_frac=0.7,     
                                   y_red=0.8,            
                                   y_blue=0.25,          
                                   head_length=1.5, head_width=0.65,
                                   save=False, folder="Plots_Func", filename=None):
        """
        Plot snapshots of u, m, or both (depending on target).
        """
        x, N_arr, M_arr, t_vec = self.x, self.N_arr, self.M_arr, self.times

        # Map requested times -> nearest indices
        t_indices = [int(np.argmin(np.abs(t_vec - t))) for t in t_points]

        # Optionally compute wave speed (for annotation)
        if (show_speed_text or print_speed) and getattr(self, "wave_speed", None) is None:
            self.wave_speed, _, _ = self.estimate_wave_speed(
                plot=False, target='N', threshold=0.5, band=(0.1, 0.9), spline_type='cubic'
            )
        if print_speed and (self.wave_speed is not None):
            print(f"[func plot] Estimated wave speed c = {self.wave_speed:.6g}")

        # Speed string
        c_str = "—"
        if self.wave_speed is not None:
            if ceil_speed == "down":
                c_str = f"{math.floor(self.wave_speed * 100) / 100:.2f}"
            elif ceil_speed == "up":
                c_str = f"{math.ceil(self.wave_speed * 100) / 100:.2f}"
            else:
                c_str = f"{self.wave_speed:.3g}"

        # Arrow geometry
        if arrow_len is None:
            arrow_len = 0.15 * self.L
        arrow_x_start = np.clip(arrow_x_frac * self.L, 0.0, self.L)
        arrow_x_end   = np.clip(arrow_x_start + arrow_len, 0.0, self.L)
        if arrow_x_end <= arrow_x_start:
            arrow_x_start = np.clip(self.L - arrow_len, 0.0, self.L)
            arrow_x_end   = self.L

        # Figure
        fig, ax = plt.subplots(figsize=(8, 6))

        # Plot series
        targ = target.lower()
        for t, tidx in zip(t_points, t_indices):
            ls = '--' if np.isclose(t, 0.0) else '-'
            if targ in ("u", "both"):
                ax.plot(x, N_arr[tidx], color='red',  linestyle=ls, label=rf"$u(x,{int(t)})$")
            if targ in ("m", "both"):
                ax.plot(x, M_arr[tidx], color='blue', linestyle=ls, label=rf"$m(x,{int(t)})$")

        # Direction arrows (only if both)
        if show_arrows and targ == "both":
            arrow_style_red  = dict(arrowstyle=f'->,head_length={head_length},head_width={head_width}',
                                    color='red',  lw=arrow_lw)
            arrow_style_blue = dict(arrowstyle=f'->,head_length={head_length},head_width={head_width}',
                                    color='blue', lw=arrow_lw)
            ax.annotate('', xy=(arrow_x_end, y_red),  xytext=(arrow_x_start, y_red),  arrowprops=arrow_style_red)
            ax.annotate('', xy=(arrow_x_end, y_blue), xytext=(arrow_x_start, y_blue), arrowprops=arrow_style_blue)

        # Corner text
        x_text = x[0] + 0.02 * self.L
        ax.text(x_text, 0.92, rf"$\overline{{m}} = {self.m0}$", fontsize=18, ha='left')
        if show_speed_text and (self.wave_speed is not None):
            ax.text(x_text, 0.82, rf"$c = {c_str}$", fontsize=18, ha='left')

        # Axes + title
        ax.set_xlabel(r"$x$", fontsize=18)
        if targ == "u":
            ax.set_ylabel(r"$u(x,t)$", fontsize=18)
        elif targ == "m":
            ax.set_ylabel(r"$m(x,t)$", fontsize=18)
        else:
            ax.set_ylabel(r"$u(x,t),\, m(x,t)$", fontsize=18)
        ax.set_xlim([0, self.L])

        mode = str(yticks_mode).lower()
        if mode == "basic":
            ax.set_ylim([0, 1.05]); ax.set_yticks([0.0, 0.5, 1.0])
        elif mode == "split":
            ax.set_ylim([0, 1.05]); ax.set_yticks(np.arange(0.0, 1.01, 0.2))
        elif mode == "splitplus":
            ax.set_ylim([0, 1.25]); ax.set_yticks(np.arange(0.0, 1.21, 0.2))
        else:
            ax.set_ylim([0, 1.05]); ax.set_yticks([0.0, 0.5, 1.0])

        ax.tick_params(axis='y', labelsize=16)
        ax.tick_params(axis='x', labelsize=16)
        lam_str = _latex_sci(self.k, pow10_threshold=100.0)
        ax.set_title(rf"$\lambda = {lam_str}$", fontsize=20)
        ax.grid(False)
        fig.tight_layout()

        # Save or show
        if save:
            os.makedirs(folder, exist_ok=True)
            m0_str = _m0_two_digits(self.m0)
            lam_file = _file_sci(self.k, 100.0)
            fname = filename or f"func_{m0_str}_lam{lam_file}_{targ}.png"
            outpath = os.path.join(folder, fname)
            fig.savefig(outpath, dpi=300, bbox_inches="tight")
            plt.close(fig)
            print(f"[func plot] Figure saved to {outpath}")
        else:
            plt.show()

In [2]:
# ==== Imports ====
import os, json
from pathlib import Path
import numpy as np
from joblib import Parallel, delayed

# -------------------------
# Save / path helpers
# -------------------------
def _ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def _atomic_write_json(path: Path, obj):
    tmp = path.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)

def _save_summary(run_dir: Path, meta: dict):
    with open(run_dir / "summary.json", "w") as f:
        json.dump(meta, f, indent=2)

def _save_fronts(run_dir: Path, t_fronts, x_fronts, name=None):
    fname = "fronts.npz" if not name else f"fronts_{name}.npz"
    np.savez_compressed(run_dir / fname,
                        t_fronts=np.asarray(t_fronts),
                        x_fronts=np.asarray(x_fronts))

def _save_snapshots_every_stride(run_dir: Path, model, stride=150):
    """
    Save EVERY `stride`-th snapshot (plus the last one), with U and M kept separate.
    """
    idx = np.unique(np.concatenate([
        np.arange(0, model.Nt, stride),
        np.array([model.Nt - 1])
    ]))
    np.savez_compressed(
        run_dir / "snapshots.npz",
        x=model.x,
        times=model.times[idx],
        N_arr=model.N_arr[idx, :],   # tumour u(x,t)
        M_arr=model.M_arr[idx, :]    # ECM   m(x,t)
    )

def _fmt_val(v):
    # compact label in folder names: keeps ints clean (e.g., 10 not 10.0)
    if isinstance(v, (int, np.integer)) or (isinstance(v, float) and v.is_integer()):
        return f"{int(v)}"
    s = f"{v}"
    return s.rstrip('0').rstrip('.') if '.' in s else s

# -------------------------
# Single-run worker
# -------------------------
def run_one(lam, alpha, m0,
            base_dir="speeds_func",
            model_kwargs=None,
            overwrite=False,
            snapshot_stride=150):
    """
    Builds, solves, measures, and saves one (λ, α, m0) run.

    Skips a run if base_dir/lambda_*/alpha_*/m0_*/summary.json exists and overwrite=False.
    Saves:
      - summary.json (metadata + c, R^2)
      - fronts.npz    (t_fronts, x_fronts for N at threshold 0.5)
      - snapshots.npz (x, times[idx], N_arr[idx,:], M_arr[idx,:]) with idx every `snapshot_stride`
    """
    if model_kwargs is None:
        model_kwargs = {}

    try:
        # Shallow copy of shared kwargs
        local_kwargs = dict(model_kwargs)

        # Optional per-λ overrides (example: finer time-step for very large λ, if you want)
        # if float(lam) >= 1e3:
        #     local_kwargs.update(dict(dt=0.01))

        base = Path(base_dir)
        lam_dir = base / f"lambda_{_fmt_val(lam)}"
        alpha_dir = lam_dir / f"alpha_{_fmt_val(alpha)}"
        run_dir = alpha_dir / f"m0_{_fmt_val(m0)}"
        _ensure_dir(run_dir)

        # Skip if already done (unless overwrite=True)
        if not overwrite and (run_dir / "summary.json").exists():
            return ("skipped", lam, alpha, m0)

        # Avoid thread over-subscription (OpenMP/BLAS)
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
        os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

        # Build and solve
        model = Dissertation_Func_1D(k=lam, alpha=alpha, m0=m0, **local_kwargs)
        model.solve()

        # Speed on u (N)
        c, b, r2 = model.estimate_wave_speed(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic',
            plot=False, target='N'
        )
        if (c is None) or (isinstance(c, float) and np.isnan(c)):
            raise ValueError("Wave speed could not be calculated.")
        model.wave_speed = c

        # Front points (N)
        t_fronts, x_fronts = model.track_wavefront_local_interpolation(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic', target='N'
        )

        # Save artifacts
        _save_summary(run_dir, dict(
            lambda_val=float(lam),
            alpha=float(alpha),
            m0=float(m0),
            wave_speed=float(c),
            r2=(float(r2) if r2 is not None else None),
            # useful context
            dt=model.dt, T=model.T, L=model.L, N=model.N,
            init_type=model.init_type,
            steepness=getattr(model, "steepness", None),
            perc=getattr(model, "perc", None),
            t_start=model.t_start, t_end=model.t_end,
            num_points=getattr(model, "num_points", None),
            saved_stride=int(snapshot_stride)
        ))
        _save_fronts(run_dir, t_fronts, x_fronts, name="N")
        _save_snapshots_every_stride(run_dir, model, stride=snapshot_stride)

        return ("done", lam, alpha, m0, float(c), (float(r2) if r2 is not None else None))

    except Exception as e:
        return ("failed", lam, alpha, m0, str(e))

# -------------------------
# Parallel grid runner
# -------------------------
def run_grid(lambda_vals, alpha_vals, m0_vals,
             base_dir="speeds_func",
             model_kwargs=None,
             overwrite=False,
             snapshot_stride=150,
             n_jobs=-1, verbose=10):
    """
    Launch all (λ, α, m0) runs in parallel and log failures & low-R² cases.
    """
    if model_kwargs is None:
        model_kwargs = {}

    tasks = [(lam, alpha, m0) for lam in lambda_vals for alpha in alpha_vals for m0 in m0_vals]
    results = Parallel(n_jobs=n_jobs, verbose=verbose, backend="loky")(
        delayed(run_one)(
            lam, alpha, m0,
            base_dir=base_dir,
            model_kwargs=model_kwargs,
            overwrite=overwrite,
            snapshot_stride=snapshot_stride
        ) for lam, alpha, m0 in tasks
    )

    done, skipped, failed, low_r2 = [], [], [], []
    for r in results:
        tag = r[0]
        if tag == "done":
            _, lam, alpha_eff, m0_eff, c, r2 = r
            done.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "c": c, "r2": r2})
            if (r2 is None) or (isinstance(r2, float) and (np.isnan(r2) or r2 < .999)):
                low_r2.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "c": c, "r2": r2})
        elif tag == "skipped":
            _, lam, alpha_eff, m0_eff = r
            skipped.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff})
        elif tag == "failed":
            _, lam, alpha_orig, m0_orig, msg = r
            failed.append({"lambda": lam, "alpha": alpha_orig, "m0": m0_orig, "error": msg})

    base = Path(base_dir)
    _ensure_dir(base)
    _atomic_write_json(base / "failed_runs.json", failed)
    _atomic_write_json(base / "low_r2_runs.json", low_r2)

    print(f"✅ Done: {len(done)}, Skipped: {len(skipped)}, Failed: {len(failed)}, Low-R²: {len(low_r2)}")
    if failed:
        print("❌ Failed runs (sample):")
        for item in failed[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | error: {item['error']}")
    if low_r2:
        print("⚠️  Low-R² runs (R² < 0.999):")
        for item in low_r2[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | c={item['c']} | r2={item['r2']}")

    return {"done": done, "skipped": skipped, "failed": failed, "low_r2": low_r2}

In [2]:
# ==== Imports ====
import os, json
from pathlib import Path
import numpy as np
from joblib import Parallel, delayed

# -------------------------
# Save / path helpers
# -------------------------
def _ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def _atomic_write_json(path: Path, obj):
    tmp = path.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)

def _save_summary(run_dir: Path, meta: dict):
    with open(run_dir / "summary.json", "w") as f:
        json.dump(meta, f, indent=2)

def _save_fronts(run_dir: Path, t_fronts, x_fronts, name=None):
    fname = "fronts.npz" if not name else f"fronts_{name}.npz"
    np.savez_compressed(run_dir / fname,
                        t_fronts=np.asarray(t_fronts),
                        x_fronts=np.asarray(x_fronts))

def _save_snapshots_every_stride(run_dir: Path, model, stride=150):
    """
    Save EVERY `stride`-th snapshot (plus the last one), with U and M kept separate.
    """
    idx = np.unique(np.concatenate([
        np.arange(0, model.Nt, stride),
        np.array([model.Nt - 1])
    ]))
    np.savez_compressed(
        run_dir / "snapshots.npz",
        x=model.x,
        times=model.times[idx],
        N_arr=model.N_arr[idx, :],   # tumour u(x,t)
        M_arr=model.M_arr[idx, :]    # ECM   m(x,t)
    )

def _fmt_val(v):
    # compact label in folder names: keeps ints clean (e.g., 10 not 10.0)
    if isinstance(v, (int, np.integer)) or (isinstance(v, float) and v.is_integer()):
        return f"{int(v)}"
    s = f"{v}"
    return s.rstrip('0').rstrip('.') if '.' in s else s

# -------------------------
# Single-run worker
# -------------------------
def run_one(lam, alpha, m0,
            base_dir="speeds_func",
            model_kwargs=None,
            overwrite=False,
            snapshot_stride=150):
    """
    Builds, solves, measures, and saves one (λ, α, m0) run.

    Skips a run if base_dir/lambda_*/alpha_*/m0_*/summary.json exists and overwrite=False.
    Saves:
      - summary.json (metadata + c, R^2)
      - fronts.npz    (t_fronts, x_fronts for N at threshold 0.5)
      - snapshots.npz (x, times[idx], N_arr[idx,:], M_arr[idx,:]) with idx every `snapshot_stride`
    """
    if model_kwargs is None:
        model_kwargs = {}

    try:
        # Shallow copy of shared kwargs
        local_kwargs = dict(model_kwargs)

        # Optional per-λ overrides (example: finer time-step for very large λ, if you want)
        # if float(lam) >= 1e3:
        #     local_kwargs.update(dict(dt=0.01))

        base = Path(base_dir)
        lam_dir = base / f"lambda_{_fmt_val(lam)}"
        alpha_dir = lam_dir / f"alpha_{_fmt_val(alpha)}"
        run_dir = alpha_dir / f"m0_{_fmt_val(m0)}"
        _ensure_dir(run_dir)

        # Skip if already done (unless overwrite=True)
        if not overwrite and (run_dir / "summary.json").exists():
            return ("skipped", lam, alpha, m0)

        # Avoid thread over-subscription (OpenMP/BLAS)
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
        os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

        # Build and solve
        model = Dissertation_Func_1D(k=lam, alpha=alpha, m0=m0, **local_kwargs)
        model.solve()

        # Speed on u (N)
        c, b, r2 = model.estimate_wave_speed(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic',
            plot=False, target='N'
        )
        if (c is None) or (isinstance(c, float) and np.isnan(c)):
            raise ValueError("Wave speed could not be calculated.")
        model.wave_speed = c

        # Front points (N)
        t_fronts, x_fronts = model.track_wavefront_local_interpolation(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic', target='N'
        )

        # Save artifacts
        _save_summary(run_dir, dict(
            lambda_val=float(lam),
            alpha=float(alpha),
            m0=float(m0),
            wave_speed=float(c),
            r2=(float(r2) if r2 is not None else None),
            # useful context
            dt=model.dt, T=model.T, L=model.L, N=model.N,
            init_type=model.init_type,
            steepness=getattr(model, "steepness", None),
            perc=getattr(model, "perc", None),
            t_start=model.t_start, t_end=model.t_end,
            num_points=getattr(model, "num_points", None),
            saved_stride=int(snapshot_stride)
        ))
        _save_fronts(run_dir, t_fronts, x_fronts, name="N")
        _save_snapshots_every_stride(run_dir, model, stride=snapshot_stride)

        return ("done", lam, alpha, m0, float(c), (float(r2) if r2 is not None else None))

    except Exception as e:
        return ("failed", lam, alpha, m0, str(e))

# -------------------------
# Parallel grid runner
# -------------------------
def run_grid(lambda_vals, alpha_vals, m0_vals,
             base_dir="speeds_func",
             model_kwargs=None,
             overwrite=False,
             snapshot_stride=150,
             n_jobs=-1, verbose=10):
    """
    Launch all (λ, α, m0) runs in parallel and log failures & low-R² cases.
    """
    if model_kwargs is None:
        model_kwargs = {}

    tasks = [(lam, alpha, m0) for lam in lambda_vals for alpha in alpha_vals for m0 in m0_vals]
    results = Parallel(n_jobs=n_jobs, verbose=verbose, backend="loky")(
        delayed(run_one)(
            lam, alpha, m0,
            base_dir=base_dir,
            model_kwargs=model_kwargs,
            overwrite=overwrite,
            snapshot_stride=snapshot_stride
        ) for lam, alpha, m0 in tasks
    )

    done, skipped, failed, low_r2 = [], [], [], []
    for r in results:
        tag = r[0]
        if tag == "done":
            _, lam, alpha_eff, m0_eff, c, r2 = r
            done.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "c": c, "r2": r2})
            if (r2 is None) or (isinstance(r2, float) and (np.isnan(r2) or r2 < .999)):
                low_r2.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "c": c, "r2": r2})
        elif tag == "skipped":
            _, lam, alpha_eff, m0_eff = r
            skipped.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff})
        elif tag == "failed":
            _, lam, alpha_orig, m0_orig, msg = r
            failed.append({"lambda": lam, "alpha": alpha_orig, "m0": m0_orig, "error": msg})

    base = Path(base_dir)
    _ensure_dir(base)
    _atomic_write_json(base / "failed_runs.json", failed)
    _atomic_write_json(base / "low_r2_runs.json", low_r2)

    print(f"✅ Done: {len(done)}, Skipped: {len(skipped)}, Failed: {len(failed)}, Low-R²: {len(low_r2)}")
    if failed:
        print("❌ Failed runs (sample):")
        for item in failed[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | error: {item['error']}")
    if low_r2:
        print("⚠️  Low-R² runs (R² < 0.999):")
        for item in low_r2[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | c={item['c']} | r2={item['r2']}")

    return {"done": done, "skipped": skipped, "failed": failed, "low_r2": low_r2}

# -------------------------
# Example usage in the notebook
# -------------------------
# 1) Make sure you've already done:
#    from func import Dissertation_Func_1D
#
# 2) Define your grids and shared kwargs:

lambda_vals = [0.001, 0.01, 0.1, 0.5, 1, 5, 10, 100, 1000, 100000, 1000000, 10000000, 100000000]
alpha_vals  = [0.001, 0.01, 0.1, 0.5, 1, 5, 10, 100, 1000, 100000, 1000000, 10000000, 100000000]
m0_vals     = [0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, ]

shared_kwargs = dict(
    # Your requested defaults
    L=200, N=20001, T=400, dt=0.1,
    init_type="tanh", steepness=0.85, perc=0.4,
    t_start=100, t_end=350, num_points=250,
    n0=1.0, K=1.0, rho=1.0, D=1.0, Mmax=1.0
)

# 3) Run the grid (tweak n_jobs as your machine allows)
results = run_grid(lambda_vals, alpha_vals, m0_vals,                    
                   base_dir="speeds_func",
                    model_kwargs=shared_kwargs,
                    snapshot_stride=150,   # <- only every 150th snapshot (plus last)
                    overwrite=False, n_jobs=8, verbose=10)

# 4) (Optional) Quick flagged summary (after results) — example:
print("\n=== Summary of Problematic Runs ===")
for item in results["failed"]:
    print(f"FAIL -> λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | {item['error']}")
for item in results["low_r2"][:20]:
    print(f"LOW R² -> λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | r2={item['r2']}")

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 tasks      | elapsed:  3.4min
[Parallel(n_jobs=8)]: Done   9 tasks      | elapsed:  7.6min
[Parallel(n_jobs=8)]: Done  16 tasks      | elapsed:  7.7min
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed: 26.5min
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed: 28.9min
[Parallel(n_jobs=8)]: Done  45 tasks      | elapsed: 31.4min
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed: 33.8min


❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done  69 tasks      | elapsed: 38.2min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done  82 tasks      | elapsed: 42.6min


❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done  97 tasks      | elapsed: 47.2min
[Parallel(n_jobs=8)]: Done 112 tasks      | elapsed: 49.9min
[Parallel(n_jobs=8)]: Done 129 tasks      | elapsed: 56.3min
[Parallel(n_jobs=8)]: Done 146 tasks      | elapsed: 60.9min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 165 tasks      | elapsed: 65.7min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed: 70.7min
[Parallel(n_jobs=8)]: Done 205 tasks      | elapsed: 77.2min
[Parallel(n_jobs=8)]: Done 226 tasks      | elapsed: 83.5min


❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 249 tasks      | elapsed: 90.2min


❌ Not enough valid front points.
❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 272 tasks      | elapsed: 96.0min
[Parallel(n_jobs=8)]: Done 297 tasks      | elapsed: 103.9min
[Parallel(n_jobs=8)]: Done 322 tasks      | elapsed: 110.6min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 349 tasks      | elapsed: 117.7min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 376 tasks      | elapsed: 125.5min
[Parallel(n_jobs=8)]: Done 405 tasks      | elapsed: 133.4min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed: 142.2min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 465 tasks      | elapsed: 150.8min
[Parallel(n_jobs=8)]: Done 496 tasks      | elapsed: 159.6min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 529 tasks      | elapsed: 168.5min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 562 tasks      | elapsed: 178.0min
[Parallel(n_jobs=8)]: Done 597 tasks      | elapsed: 187.2min


❌ Not enough valid front points.
❌ Not enough valid front points.
❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 632 tasks      | elapsed: 197.9min
[Parallel(n_jobs=8)]: Done 669 tasks      | elapsed: 207.6min
[Parallel(n_jobs=8)]: Done 706 tasks      | elapsed: 218.6min


❌ Not enough valid front points.
❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 745 tasks      | elapsed: 230.0min
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed: 241.1min


❌ Not enough valid front points.
❌ Not enough valid front points.


[Parallel(n_jobs=8)]: Done 825 tasks      | elapsed: 252.5min
[Parallel(n_jobs=8)]: Done 866 tasks      | elapsed: 263.7min
[Parallel(n_jobs=8)]: Done 909 tasks      | elapsed: 275.6min
[Parallel(n_jobs=8)]: Done 952 tasks      | elapsed: 287.9min
[Parallel(n_jobs=8)]: Done 997 tasks      | elapsed: 300.1min
[Parallel(n_jobs=8)]: Done 1042 tasks      | elapsed: 312.3min
[Parallel(n_jobs=8)]: Done 1089 tasks      | elapsed: 325.7min
[Parallel(n_jobs=8)]: Done 1136 tasks      | elapsed: 339.0min


✅ Done: 1150, Skipped: 0, Failed: 33, Low-R²: 97
❌ Failed runs (sample):
  λ=0.001, α=100000, m0=0 | error: Wave speed could not be calculated.
  λ=0.001, α=1000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.001, α=10000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.001, α=100000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.01, α=100000, m0=0 | error: Wave speed could not be calculated.
  λ=0.01, α=1000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.01, α=10000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.01, α=100000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.1, α=100000, m0=0 | error: Wave speed could not be calculated.
  λ=0.1, α=1000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.1, α=10000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.1, α=100000000, m0=0 | error: Wave speed could not be calculated.
  λ=0.5, α=100000, m0=0 | error: Wave speed could not be calculated.


[Parallel(n_jobs=8)]: Done 1183 out of 1183 | elapsed: 351.1min finished


In [3]:
# ============================================
# 1x3 row: (alpha, lambda) = (10^{-1},10^{-1}), (10^{-2},10^{-2}), (10^{-3},10^{-3})
# Simulated U (orange) vs KPP U(ξ) (black dashed), using c from summary.json
# ============================================
import json
from pathlib import Path
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.interpolate import PchipInterpolator
from scipy.optimize import root_scalar

# ---------- LaTeX styling (safe: no unicode literals) ----------
mpl.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
    "axes.unicode_minus": False,
})

# ---------- tolerant run-dir helpers ----------
def _token_variants(v: float):
    v = float(v)
    dec   = f"{v:.12f}".rstrip("0").rstrip(".")
    plain = f"{v:g}"
    sci   = f"{v:.0e}"; s,e = sci.split("e"); sci_neat = f"{s}e{int(e)}"
    toks = {dec, plain, sci, sci_neat}
    toks |= {t.replace(".", "p") for t in list(toks)}
    return toks

def _exactish_dir(parent: Path, prefix: str, value: float):
    if not parent.exists(): return None
    for tok in _token_variants(value):
        p = parent / f"{prefix}_{tok}"
        if p.is_dir(): return p
    return None

def find_run_dir(base_roots, lam, alpha, m0):
    if isinstance(base_roots, (str, Path)): base_roots = (base_roots,)
    for root in base_roots:
        root = Path(root)
        d1 = _exactish_dir(root, "lambda", lam)
        if d1 is None: continue
        d2 = _exactish_dir(d1, "alpha", alpha)
        if d2 is None: continue
        d3 = _exactish_dir(d2, "m0", m0)
        if d3 is not None: return d3
    return None

# ---------- load c, m0 and snapshots ----------
def load_run(run_dir: Path):
    meta = json.loads((run_dir / "summary.json").read_text())
    z = np.load(run_dir / "snapshots.npz", allow_pickle=True)
    return dict(
        c=float(meta["wave_speed"]),
        lam=float(meta.get("lambda_val", meta.get("lambda", np.nan))),
        alpha=float(meta.get("alpha", np.nan)),
        m0=float(meta.get("m0", np.nan)),
        x=np.asarray(z["x"], float),
        times=np.asarray(z["times"], float),
        U=np.asarray(z.get("N_arr", z.get("u_arr")), float),
    )

# ---------- U=0.5 front (monotone spline) ----------
def front_x_at_time(x, Urow, threshold=0.5, band=(0.1,0.9)):
    mask = (Urow > band[0]) & (Urow < band[1])
    if mask.sum() < 5: return None
    xloc = x[mask]; uloc = Urow[mask]
    order = np.argsort(xloc); xloc = xloc[order]; uloc = uloc[order]
    spl = PchipInterpolator(xloc, uloc, extrapolate=True)
    s = np.sign(uloc - threshold)
    idx = np.where(s[:-1]*s[1:] < 0)[0]
    if len(idx)==0: return None
    i = idx[0]; xl, xr = xloc[i], xloc[i+1]
    sol = root_scalar(lambda xv: spl(xv)-threshold, bracket=[xl,xr], method="brentq")
    return sol.root if sol.converged else None

# ---------- robust centering around t_ref ----------
def best_center_x(x, Uall, times, t_ref,
                  thresholds=(0.5, 0.45, 0.55),
                  bands=((0.1,0.9),(0.05,0.95),(0.0,1.0)),
                  time_window=120.0):
    dt = np.abs(times - t_ref)
    cand = np.argsort(dt)
    cand = [k for k in cand if dt[k] <= time_window]
    for k in cand:
        Urow = Uall[k]
        for thr in thresholds:
            for band in bands:
                x0 = front_x_at_time(x, Urow, threshold=thr, band=band)
                if x0 is not None:
                    return k, x0
    # fallback: center at domain midpoint if no clean crossing found
    k = int(np.argmin(np.abs(times - t_ref)))
    return k, 0.5*(x[0]+x[-1])

# ---------- 10^k neat formatter ----------
def fmt_pow10(v: float):
    v = float(v)
    if np.isclose(v, 1.0):  return "1"
    if np.isclose(v, 10.0): return "10"
    if v == 0: return "0"
    k = int(np.round(np.log10(abs(v))))
    if np.isclose(v, 10.0**k):
        return rf"10^{{{k}}}"
    mant = v/(10.0**k)
    return rf"{mant:.2g}\times 10^{{{k}}}"

# ---------- KPP (constant-D) solver using c and m0 ----------
def kpp_rhs(xi, y, c, D):
    U, V = y
    return [V, -(c/D)*V - (1.0/D)*U*(1.0-U)]

def solve_kpp_ivp(c, m0, L=40.0, epsL=1e-6, epsR=1e-8,
                  solver="Radau", rtol=1e-8, atol=1e-11, max_step=0.05):
    D = float(m0*(1.0-m0))
    # back-tail decay rate (linearized near U≈1)
    r = (-c + np.sqrt(c*c + 4.0*D)) / (2.0*D)
    U0, V0 = 1.0-epsL, -r*epsL

    def hit_tip(xi, y): return y[0] - epsR
    hit_tip.terminal = True; hit_tip.direction = -1

    sol = solve_ivp(lambda x,y: kpp_rhs(x, y, c, D),
                    t_span=(-L, +L), y0=[U0, V0],
                    method=solver, rtol=rtol, atol=atol,
                    max_step=max_step, events=hit_tip)
    return sol.t, sol.y[0]

# ---------- 1×3 row plot ----------
def row_sim_vs_kpp_same_exponents(base_roots, exponents=(-1, -2, -3),
                                  m0=0.5, t_ref=100,
                                  time_window=120.0,
                                  xi_span_factor=12.0,  # ±(xi_span_factor*c)
                                  xi_min_span=15.0,     # at least ±this
                                  tumor_color="#ff8c00",
                                  kpp_style={"color":"k","ls":"--","lw":2.2},
                                  figsize=(12, 4.5),
                                  tick_fs=12, label_fs=14, title_fs=15):
    ncols = len(exponents)
    fig, axes = plt.subplots(1, ncols, figsize=figsize, sharex=True, sharey=True)
    if ncols == 1:
        axes = [axes]

    legend_handles = None

    for j, e in enumerate(exponents):
        ax = axes[j]
        lam = 10.0**e
        alpha = 10.0**e

        run_dir = find_run_dir(base_roots, lam, alpha, m0)
        if run_dir is None:
            ax.text(0.5, 0.5, "(run not found)", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        data = load_run(run_dir)
        c, x, T, Uall = data["c"], data["x"], data["times"], data["U"]
        k, x0 = best_center_x(x, Uall, T, t_ref, time_window=time_window)
        xi_num = x - x0

        # choose ξ half-span
        xi_half = max(xi_span_factor*abs(c), xi_min_span)

        # KPP line (integrate a bit beyond displayed window)
        xi_kpp, U_kpp = solve_kpp_ivp(c, m0, L=xi_half*1.15)

        l1, = ax.plot(xi_num, Uall[k], color=tumor_color, lw=2.4, alpha=0.9, label=r"$U$ (simulated)")
        l2, = ax.plot(xi_kpp, U_kpp, label=r"$U(\xi)$ (KPP)", **kpp_style)

        if legend_handles is None:
            legend_handles = (l1, l2)

        ax.set_xlim(-xi_half, +xi_half)
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True, ls="--", alpha=0.25)
        ax.tick_params(labelsize=tick_fs)

        ax.set_title(rf"$\alpha={fmt_pow10(alpha)}$, $\lambda={fmt_pow10(lam)}$", fontsize=title_fs)

        if j == 0:
            ax.set_ylabel(r"$U(\xi)$", fontsize=label_fs)
        ax.set_xlabel(r"$\xi$", fontsize=label_fs)

    if legend_handles is not None:
        fig.legend(legend_handles, [h.get_label() for h in legend_handles],
                   loc="center left", bbox_to_anchor=(1.01, 0.5),
                   frameon=False, fontsize=12)

    fig.tight_layout(rect=[0.05, 0.05, 0.88, 0.95])
    plt.show()
    return fig, axes

# ------------------------------
# Example call
# ------------------------------
if __name__ == "__main__":
    BASES = ("speeds_func_4", "speeds_func_l", "speeds_func_u")
    row_sim_vs_kpp_same_exponents(
        base_roots=BASES,
        exponents=(-1, -2, -3),  # columns: 10^{-1}, 10^{-2}, 10^{-3}
        m0=0.5,
        t_ref=100,               # target time; code searches ±120 for a clean crossing
        time_window=120.0,
        xi_span_factor=12.0,
        xi_min_span=15.0
    )# ============================================
# 1x3 row: (alpha, lambda) = (10^{-1},10^{-1}), (10^{-2},10^{-2}), (10^{-3},10^{-3})
# Simulated U (orange) vs KPP U(ξ) (black dashed), using c from summary.json
# ============================================
import json
from pathlib import Path
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.interpolate import PchipInterpolator
from scipy.optimize import root_scalar

# ---------- LaTeX styling (safe: no unicode literals) ----------
mpl.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
    "axes.unicode_minus": False,
})

# ---------- tolerant run-dir helpers ----------
def _token_variants(v: float):
    v = float(v)
    dec   = f"{v:.12f}".rstrip("0").rstrip(".")
    plain = f"{v:g}"
    sci   = f"{v:.0e}"; s,e = sci.split("e"); sci_neat = f"{s}e{int(e)}"
    toks = {dec, plain, sci, sci_neat}
    toks |= {t.replace(".", "p") for t in list(toks)}
    return toks

def _exactish_dir(parent: Path, prefix: str, value: float):
    if not parent.exists(): return None
    for tok in _token_variants(value):
        p = parent / f"{prefix}_{tok}"
        if p.is_dir(): return p
    return None

def find_run_dir(base_roots, lam, alpha, m0):
    if isinstance(base_roots, (str, Path)): base_roots = (base_roots,)
    for root in base_roots:
        root = Path(root)
        d1 = _exactish_dir(root, "lambda", lam)
        if d1 is None: continue
        d2 = _exactish_dir(d1, "alpha", alpha)
        if d2 is None: continue
        d3 = _exactish_dir(d2, "m0", m0)
        if d3 is not None: return d3
    return None

# ---------- load c, m0 and snapshots ----------
def load_run(run_dir: Path):
    meta = json.loads((run_dir / "summary.json").read_text())
    z = np.load(run_dir / "snapshots.npz", allow_pickle=True)
    return dict(
        c=float(meta["wave_speed"]),
        lam=float(meta.get("lambda_val", meta.get("lambda", np.nan))),
        alpha=float(meta.get("alpha", np.nan)),
        m0=float(meta.get("m0", np.nan)),
        x=np.asarray(z["x"], float),
        times=np.asarray(z["times"], float),
        U=np.asarray(z.get("N_arr", z.get("u_arr")), float),
    )

# ---------- U=0.5 front (monotone spline) ----------
def front_x_at_time(x, Urow, threshold=0.5, band=(0.1,0.9)):
    mask = (Urow > band[0]) & (Urow < band[1])
    if mask.sum() < 5: return None
    xloc = x[mask]; uloc = Urow[mask]
    order = np.argsort(xloc); xloc = xloc[order]; uloc = uloc[order]
    spl = PchipInterpolator(xloc, uloc, extrapolate=True)
    s = np.sign(uloc - threshold)
    idx = np.where(s[:-1]*s[1:] < 0)[0]
    if len(idx)==0: return None
    i = idx[0]; xl, xr = xloc[i], xloc[i+1]
    sol = root_scalar(lambda xv: spl(xv)-threshold, bracket=[xl,xr], method="brentq")
    return sol.root if sol.converged else None

# ---------- robust centering around t_ref ----------
def best_center_x(x, Uall, times, t_ref,
                  thresholds=(0.5, 0.45, 0.55),
                  bands=((0.1,0.9),(0.05,0.95),(0.0,1.0)),
                  time_window=120.0):
    dt = np.abs(times - t_ref)
    cand = np.argsort(dt)
    cand = [k for k in cand if dt[k] <= time_window]
    for k in cand:
        Urow = Uall[k]
        for thr in thresholds:
            for band in bands:
                x0 = front_x_at_time(x, Urow, threshold=thr, band=band)
                if x0 is not None:
                    return k, x0
    # fallback: center at domain midpoint if no clean crossing found
    k = int(np.argmin(np.abs(times - t_ref)))
    return k, 0.5*(x[0]+x[-1])

# ---------- 10^k neat formatter ----------
def fmt_pow10(v: float):
    v = float(v)
    if np.isclose(v, 1.0):  return "1"
    if np.isclose(v, 10.0): return "10"
    if v == 0: return "0"
    k = int(np.round(np.log10(abs(v))))
    if np.isclose(v, 10.0**k):
        return rf"10^{{{k}}}"
    mant = v/(10.0**k)
    return rf"{mant:.2g}\times 10^{{{k}}}"

# ---------- KPP (constant-D) solver using c and m0 ----------
def kpp_rhs(xi, y, c, D):
    U, V = y
    return [V, -(c/D)*V - (1.0/D)*U*(1.0-U)]

def solve_kpp_ivp(c, m0, L=40.0, epsL=1e-6, epsR=1e-8,
                  solver="Radau", rtol=1e-8, atol=1e-11, max_step=0.05):
    D = float(m0*(1.0-m0))
    # back-tail decay rate (linearized near U≈1)
    r = (-c + np.sqrt(c*c + 4.0*D)) / (2.0*D)
    U0, V0 = 1.0-epsL, -r*epsL

    def hit_tip(xi, y): return y[0] - epsR
    hit_tip.terminal = True; hit_tip.direction = -1

    sol = solve_ivp(lambda x,y: kpp_rhs(x, y, c, D),
                    t_span=(-L, +L), y0=[U0, V0],
                    method=solver, rtol=rtol, atol=atol,
                    max_step=max_step, events=hit_tip)
    return sol.t, sol.y[0]

# ---------- 1×3 row plot ----------
def row_sim_vs_kpp_same_exponents(base_roots, exponents=(-1, -2, -3),
                                  m0=0.5, t_ref=100,
                                  time_window=120.0,
                                  xi_span_factor=12.0,  # ±(xi_span_factor*c)
                                  xi_min_span=15.0,     # at least ±this
                                  tumor_color="#ff8c00",
                                  kpp_style={"color":"k","ls":"--","lw":2.2},
                                  figsize=(12, 4.5),
                                  tick_fs=12, label_fs=14, title_fs=15):
    ncols = len(exponents)
    fig, axes = plt.subplots(1, ncols, figsize=figsize, sharex=True, sharey=True)
    if ncols == 1:
        axes = [axes]

    legend_handles = None

    for j, e in enumerate(exponents):
        ax = axes[j]
        lam = 10.0**e
        alpha = 10.0**e

        run_dir = find_run_dir(base_roots, lam, alpha, m0)
        if run_dir is None:
            ax.text(0.5, 0.5, "(run not found)", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        data = load_run(run_dir)
        c, x, T, Uall = data["c"], data["x"], data["times"], data["U"]
        k, x0 = best_center_x(x, Uall, T, t_ref, time_window=time_window)
        xi_num = x - x0

        # choose ξ half-span
        xi_half = max(xi_span_factor*abs(c), xi_min_span)

        # KPP line (integrate a bit beyond displayed window)
        xi_kpp, U_kpp = solve_kpp_ivp(c, m0, L=xi_half*1.15)

        l1, = ax.plot(xi_num, Uall[k], color=tumor_color, lw=2.4, alpha=0.9, label=r"$U$ (simulated)")
        l2, = ax.plot(xi_kpp, U_kpp, label=r"$U(\xi)$ (KPP)", **kpp_style)

        if legend_handles is None:
            legend_handles = (l1, l2)

        ax.set_xlim(-xi_half, +xi_half)
        ax.set_ylim(-0.05, 1.05)
        ax.grid(True, ls="--", alpha=0.25)
        ax.tick_params(labelsize=tick_fs)

        ax.set_title(rf"$\alpha={fmt_pow10(alpha)}$, $\lambda={fmt_pow10(lam)}$", fontsize=title_fs)

        if j == 0:
            ax.set_ylabel(r"$U(\xi)$", fontsize=label_fs)
        ax.set_xlabel(r"$\xi$", fontsize=label_fs)

    if legend_handles is not None:
        fig.legend(legend_handles, [h.get_label() for h in legend_handles],
                   loc="center left", bbox_to_anchor=(1.01, 0.5),
                   frameon=False, fontsize=12)

    fig.tight_layout(rect=[0.05, 0.05, 0.88, 0.95])
    plt.show()
    return fig, axes

# ------------------------------
# Example call
# ------------------------------
if __name__ == "__main__":
    BASES = ("speeds_func_4", "speeds_func_l", "speeds_func_u")
    row_sim_vs_kpp_same_exponents(
        base_roots=BASES,
        exponents=(-1, -2, -3),  # columns: 10^{-1}, 10^{-2}, 10^{-3}
        m0=0.5,
        t_ref=100,               # target time; code searches ±120 for a clean crossing
        time_window=120.0,
        xi_span_factor=12.0,
        xi_min_span=15.0
    )

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 tasks      | elapsed:  3.3min
[Parallel(n_jobs=8)]: Done   9 tasks      | elapsed:  6.2min
[Parallel(n_jobs=8)]: Done  16 tasks      | elapsed:  6.3min
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed: 11.7min
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed: 14.2min
[Parallel(n_jobs=8)]: Done  45 tasks      | elapsed: 16.6min
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed: 19.0min
[Parallel(n_jobs=8)]: Done  69 tasks      | elapsed: 23.5min
[Parallel(n_jobs=8)]: Done  82 tasks      | elapsed: 27.9min
[Parallel(n_jobs=8)]: Done  97 tasks      | elapsed: 32.4min
[Parallel(n_jobs=8)]: Done 112 tasks      | elapsed: 35.2min
[Parallel(n_jobs=8)]: Done 129 tasks      | elapsed: 41.7min
[Parallel(n_jobs=8)]: Done 146 tasks      | elapsed: 46.3min
[Parallel(n_jobs=8)]: Done 165 tasks      | elapsed: 51.3min
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed: 56.4min
[Parallel(

✅ Done: 975, Skipped: 0, Failed: 0, Low-R²: 147
⚠️  Low-R² runs (R² < 0.999):
  λ=0.001, α=0.1, m0=0 | c=0.1825422738218574 | r2=0.9939823199012505
  λ=0.001, α=0.1, m0=0.01 | c=0.18254901365038861 | r2=0.9939817531622139
  λ=0.001, α=0.1, m0=0.05 | c=0.18245838266681294 | r2=0.9939920615285572
  λ=0.001, α=0.1, m0=0.1 | c=0.1820808563510841 | r2=0.9940340471360012
  λ=0.001, α=0.1, m0=0.2 | c=0.180424092727537 | r2=0.9942266700510298
  λ=0.001, α=0.1, m0=0.3 | c=0.17762218665941246 | r2=0.9946744215865934
  λ=0.001, α=0.1, m0=0.4 | c=0.17429039383805775 | r2=0.9957977145424453
  λ=0.001, α=0.1, m0=0.5 | c=0.16994869857466 | r2=0.9972079210629353
  λ=0.001, α=0.1, m0=0.6 | c=0.16240009330638527 | r2=0.9982217199467363
  λ=0.001, α=0.1, m0=0.7 | c=0.149694602887888 | r2=0.9988480618761971
  λ=0.001, α=0.2, m0=0 | c=0.14554441406227098 | r2=0.9970117699364457
  λ=0.001, α=0.2, m0=0.01 | c=0.14555348055396236 | r2=0.9970102882911462
  λ=0.001, α=0.2, m0=0.05 | c=0.14554556741467964 | r2=0

[Parallel(n_jobs=8)]: Done 975 out of 975 | elapsed: 312.7min finished


In [5]:
# ==== Imports ====
import os, json
from pathlib import Path
import numpy as np
from joblib import Parallel, delayed

# -------------------------
# Save / path helpers
# -------------------------
def _ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def _atomic_write_json(path: Path, obj):
    tmp = path.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)

def _save_summary(run_dir: Path, meta: dict):
    with open(run_dir / "summary.json", "w") as f:
        json.dump(meta, f, indent=2)

def _save_fronts(run_dir: Path, t_fronts, x_fronts, name=None):
    fname = "fronts.npz" if not name else f"fronts_{name}.npz"
    np.savez_compressed(run_dir / fname,
                        t_fronts=np.asarray(t_fronts),
                        x_fronts=np.asarray(x_fronts))

def _save_snapshots_every_stride(run_dir: Path, model, stride=150):
    """
    Save EVERY `stride`-th snapshot (plus the last one), with U and M kept separate.
    """
    idx = np.unique(np.concatenate([
        np.arange(0, model.Nt, stride),
        np.array([model.Nt - 1])
    ]))
    np.savez_compressed(
        run_dir / "snapshots.npz",
        x=model.x,
        times=model.times[idx],
        N_arr=model.N_arr[idx, :],   # tumour u(x,t)
        M_arr=model.M_arr[idx, :]    # ECM   m(x,t)
    )

def _fmt_val(v):
    # compact label in folder names: keeps ints clean (e.g., 10 not 10.0)
    if isinstance(v, (int, np.integer)) or (isinstance(v, float) and v.is_integer()):
        return f"{int(v)}"
    s = f"{v}"
    return s.rstrip('0').rstrip('.') if '.' in s else s

# -------------------------
# Single-run worker
# -------------------------
def run_one(lam, alpha, m0,
            base_dir="speeds_func",
            model_kwargs=None,
            overwrite=False,
            snapshot_stride=150):
    """
    Builds, solves, measures, and saves one (λ, α, m0) run.

    Skips a run if base_dir/lambda_*/alpha_*/m0_*/summary.json exists and overwrite=False.
    Saves:
      - summary.json (metadata + c, R^2)
      - fronts.npz    (t_fronts, x_fronts for N at threshold 0.5)
      - snapshots.npz (x, times[idx], N_arr[idx,:], M_arr[idx,:]) with idx every `snapshot_stride`
    """
    if model_kwargs is None:
        model_kwargs = {}

    try:
        # Shallow copy of shared kwargs
        local_kwargs = dict(model_kwargs)

        # Optional per-λ overrides (example: finer time-step for very large λ, if you want)
        # if float(lam) >= 1e3:
        #     local_kwargs.update(dict(dt=0.01))

        base = Path(base_dir)
        lam_dir = base / f"lambda_{_fmt_val(lam)}"
        alpha_dir = lam_dir / f"alpha_{_fmt_val(alpha)}"
        run_dir = alpha_dir / f"m0_{_fmt_val(m0)}"
        _ensure_dir(run_dir)

        # Skip if already done (unless overwrite=True)
        if not overwrite and (run_dir / "summary.json").exists():
            return ("skipped", lam, alpha, m0)

        # Avoid thread over-subscription (OpenMP/BLAS)
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
        os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

        # Build and solve
        model = Dissertation_Func_1D(k=lam, alpha=alpha, m0=m0, **local_kwargs)
        model.solve()

        # Speed on u (N)
        c, b, r2 = model.estimate_wave_speed(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic',
            plot=False, target='N'
        )
        if (c is None) or (isinstance(c, float) and np.isnan(c)):
            raise ValueError("Wave speed could not be calculated.")
        model.wave_speed = c

        # Front points (N)
        t_fronts, x_fronts = model.track_wavefront_local_interpolation(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic', target='N'
        )

        # Save artifacts
        _save_summary(run_dir, dict(
            lambda_val=float(lam),
            alpha=float(alpha),
            m0=float(m0),
            wave_speed=float(c),
            r2=(float(r2) if r2 is not None else None),
            # useful context
            dt=model.dt, T=model.T, L=model.L, N=model.N,
            init_type=model.init_type,
            steepness=getattr(model, "steepness", None),
            perc=getattr(model, "perc", None),
            t_start=model.t_start, t_end=model.t_end,
            num_points=getattr(model, "num_points", None),
            saved_stride=int(snapshot_stride)
        ))
        _save_fronts(run_dir, t_fronts, x_fronts, name="N")
        _save_snapshots_every_stride(run_dir, model, stride=snapshot_stride)

        return ("done", lam, alpha, m0, float(c), (float(r2) if r2 is not None else None))

    except Exception as e:
        return ("failed", lam, alpha, m0, str(e))

# -------------------------
# Parallel grid runner
# -------------------------
def run_grid(lambda_vals, alpha_vals, m0_vals,
             base_dir="speeds_func",
             model_kwargs=None,
             overwrite=False,
             snapshot_stride=150,
             n_jobs=-1, verbose=10):
    """
    Launch all (λ, α, m0) runs in parallel and log failures & low-R² cases.
    """
    if model_kwargs is None:
        model_kwargs = {}

    tasks = [(lam, alpha, m0) for lam in lambda_vals for alpha in alpha_vals for m0 in m0_vals]
    results = Parallel(n_jobs=n_jobs, verbose=verbose, backend="loky")(
        delayed(run_one)(
            lam, alpha, m0,
            base_dir=base_dir,
            model_kwargs=model_kwargs,
            overwrite=overwrite,
            snapshot_stride=snapshot_stride
        ) for lam, alpha, m0 in tasks
    )

    done, skipped, failed, low_r2 = [], [], [], []
    for r in results:
        tag = r[0]
        if tag == "done":
            _, lam, alpha_eff, m0_eff, c, r2 = r
            done.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "c": c, "r2": r2})
            if (r2 is None) or (isinstance(r2, float) and (np.isnan(r2) or r2 < .999)):
                low_r2.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "c": c, "r2": r2})
        elif tag == "skipped":
            _, lam, alpha_eff, m0_eff = r
            skipped.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff})
        elif tag == "failed":
            _, lam, alpha_orig, m0_orig, msg = r
            failed.append({"lambda": lam, "alpha": alpha_orig, "m0": m0_orig, "error": msg})

    base = Path(base_dir)
    _ensure_dir(base)
    _atomic_write_json(base / "failed_runs.json", failed)
    _atomic_write_json(base / "low_r2_runs.json", low_r2)

    print(f"✅ Done: {len(done)}, Skipped: {len(skipped)}, Failed: {len(failed)}, Low-R²: {len(low_r2)}")
    if failed:
        print("❌ Failed runs (sample):")
        for item in failed[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | error: {item['error']}")
    if low_r2:
        print("⚠️  Low-R² runs (R² < 0.999):")
        for item in low_r2[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | c={item['c']} | r2={item['r2']}")

    return {"done": done, "skipped": skipped, "failed": failed, "low_r2": low_r2}

# -------------------------
# Example usage in the notebook
# -------------------------
# 1) Make sure you've already done:
#    from func import Dissertation_Func_1D
#
# 2) Define your grids and shared kwargs:

lambda_vals = [1e-4,1e-5]
alpha_vals  = [1e-3,1e-2,1e-1]
m0_vals     = [0.5]

shared_kwargs = dict(
    # Your requested defaults
    L=200, N=20001, T=400, dt=0.1,
    init_type="tanh", steepness=0.85, perc=0.4,
    t_start=100, t_end=350, num_points=250,
    n0=1.0, K=1.0, rho=1.0, D=1.0, Mmax=1.0
)

# 3) Run the grid (tweak n_jobs as your machine allows)
results = run_grid(lambda_vals, alpha_vals, m0_vals,                    
                   base_dir="speeds_func_tiny2",
                    model_kwargs=shared_kwargs,
                    snapshot_stride=150,   # <- only every 150th snapshot (plus last)
                    overwrite=False, n_jobs=8, verbose=10)

# 4) (Optional) Quick flagged summary (after results) — example:
print("\n=== Summary of Problematic Runs ===")
for item in results["failed"]:
    print(f"FAIL -> λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | {item['error']}")
for item in results["low_r2"][:20]:
    print(f"LOW R² -> λ={item['lambda']}, α={item['alpha']}, m0={item['m0']} | r2={item['r2']}")

[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.


[Parallel(n_jobs=8)]: Done   2 out of   6 | elapsed:  2.7min remaining:  5.5min
[Parallel(n_jobs=8)]: Done   3 out of   6 | elapsed:  2.7min remaining:  2.7min
[Parallel(n_jobs=8)]: Done   4 out of   6 | elapsed:  2.7min remaining:  1.4min


✅ Done: 6, Skipped: 0, Failed: 0, Low-R²: 2
⚠️  Low-R² runs (R² < 0.999):
  λ=0.0001, α=0.1, m0=0.5 | c=0.16994178908803234 | r2=0.9972097066131171
  λ=1e-05, α=0.1, m0=0.5 | c=0.16994115453915076 | r2=0.9972098645032523

=== Summary of Problematic Runs ===
LOW R² -> λ=0.0001, α=0.1, m0=0.5 | r2=0.9972097066131171
LOW R² -> λ=1e-05, α=0.1, m0=0.5 | r2=0.9972098645032523


[Parallel(n_jobs=8)]: Done   6 out of   6 | elapsed:  2.7min finished


In [8]:
# ==== Imports ====
import os, json
from pathlib import Path
import numpy as np
from joblib import Parallel, delayed

# -------------------------
# Save / path helpers
# -------------------------
def _ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def _atomic_write_json(path: Path, obj):
    tmp = path.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)

def _save_summary(run_dir: Path, meta: dict):
    with open(run_dir / "summary.json", "w") as f:
        json.dump(meta, f, indent=2)

def _save_fronts(run_dir: Path, t_fronts, x_fronts, name=None):
    fname = "fronts.npz" if not name else f"fronts_{name}.npz"
    np.savez_compressed(run_dir / fname,
                        t_fronts=np.asarray(t_fronts),
                        x_fronts=np.asarray(x_fronts))

def _save_snapshots_every_stride(run_dir: Path, model, stride=150):
    """
    Save EVERY `stride`-th snapshot (plus the last one), with U and M kept separate.
    """
    idx = np.unique(np.concatenate([
        np.arange(0, model.Nt, stride),
        np.array([model.Nt - 1])
    ]))
    np.savez_compressed(
        run_dir / "snapshots.npz",
        x=model.x,
        times=model.times[idx],
        N_arr=model.N_arr[idx, :],   # tumour u(x,t)
        M_arr=model.M_arr[idx, :]    # ECM   m(x,t)
    )

def _fmt_val(v):
    # compact label in folder names: keeps ints clean (e.g., 10 not 10.0)
    if isinstance(v, (int, np.integer)) or (isinstance(v, float) and float(v).is_integer()):
        return f"{int(v)}"
    s = f"{v}"
    return s.rstrip('0').rstrip('.') if '.' in s else s

# -------------------------
# Single-run worker
# -------------------------
def run_one(lam, alpha, m0, n0,
            base_dir="speeds_func",
            model_kwargs=None,
            overwrite=False,
            snapshot_stride=150):
    """
    Builds, solves, measures, and saves one (λ, α, m0, n0) run.

    Skips a run if base_dir/lambda_*/alpha_*/m0_*/n0_*/summary.json exists and overwrite=False.
    Saves:
      - summary.json (metadata + c, R^2)
      - fronts_N.npz  (t_fronts, x_fronts for N at threshold 0.5)
      - snapshots.npz (x, times[idx], N_arr[idx,:], M_arr[idx,:]) with idx every `snapshot_stride`
    """
    if model_kwargs is None:
        model_kwargs = {}

    try:
        # Shallow copy of shared kwargs
        local_kwargs = dict(model_kwargs)
        # Ensure per-run n0 overrides the shared one
        local_kwargs["n0"] = float(n0)

        base = Path(base_dir)
        lam_dir   = base / f"lambda_{_fmt_val(lam)}"
        alpha_dir = lam_dir / f"alpha_{_fmt_val(alpha)}"
        m0_dir    = alpha_dir / f"m0_{_fmt_val(m0)}"
        run_dir   = m0_dir / f"n0_{_fmt_val(n0)}"
        _ensure_dir(run_dir)

        # Skip if already done (unless overwrite=True)
        if not overwrite and (run_dir / "summary.json").exists():
            return ("skipped", lam, alpha, m0, n0)

        # Avoid thread over-subscription (OpenMP/BLAS)
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
        os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

        # Build and solve
        # NOTE: import your model in the caller's environment:
        # from func import Dissertation_Func_1D
        model = Dissertation_Func_1D(k=lam, alpha=alpha, m0=m0, **local_kwargs)
        model.solve()

        # Speed on u (N)
        c, b, r2 = model.estimate_wave_speed(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic',
            plot=False, target='N'
        )
        if (c is None) or (isinstance(c, float) and np.isnan(c)):
            raise ValueError("Wave speed could not be calculated.")
        model.wave_speed = c

        # Front points (N)
        t_fronts, x_fronts = model.track_wavefront_local_interpolation(
            threshold=0.5, band=(0.1, 0.9), spline_type='cubic', target='N'
        )

        # Save artifacts
        _save_summary(run_dir, dict(
            lambda_val=float(lam),
            alpha=float(alpha),
            m0=float(m0),
            n0=float(n0),
            wave_speed=float(c),
            r2=(float(r2) if r2 is not None else None),
            # useful context
            dt=model.dt, T=model.T, L=model.L, N=model.N,
            init_type=model.init_type,
            steepness=getattr(model, "steepness", None),
            perc=getattr(model, "perc", None),
            t_start=model.t_start, t_end=model.t_end,
            num_points=getattr(model, "num_points", None),
            saved_stride=int(snapshot_stride)
        ))
        _save_fronts(run_dir, t_fronts, x_fronts, name="N")
        _save_snapshots_every_stride(run_dir, model, stride=snapshot_stride)

        return ("done", lam, alpha, m0, n0, float(c), (float(r2) if r2 is not None else None))

    except Exception as e:
        return ("failed", lam, alpha, m0, n0, str(e))

# -------------------------
# Parallel grid runner
# -------------------------
def run_grid(lambda_vals, alpha_vals, m0_vals, n0_vals,
             base_dir="speeds_func",
             model_kwargs=None,
             overwrite=False,
             snapshot_stride=150,
             n_jobs=-1, verbose=10):
    """
    Launch all (λ, α, m0, n0) runs in parallel and log failures & low-R² cases.
    """
    if model_kwargs is None:
        model_kwargs = {}

    tasks = [(lam, alpha, m0, n0)
             for lam in lambda_vals
             for alpha in alpha_vals
             for m0 in m0_vals
             for n0 in n0_vals]

    results = Parallel(n_jobs=n_jobs, verbose=verbose, backend="loky")(
        delayed(run_one)(
            lam, alpha, m0, n0,
            base_dir=base_dir,
            model_kwargs=model_kwargs,
            overwrite=overwrite,
            snapshot_stride=snapshot_stride
        ) for lam, alpha, m0, n0 in tasks
    )

    done, skipped, failed, low_r2 = [], [], [], []
    for r in results:
        tag = r[0]
        if tag == "done":
            _, lam, alpha_eff, m0_eff, n0_eff, c, r2 = r
            done.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "n0": n0_eff, "c": c, "r2": r2})
            if (r2 is None) or (isinstance(r2, float) and (np.isnan(r2) or r2 < .999)):
                low_r2.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "n0": n0_eff, "c": c, "r2": r2})
        elif tag == "skipped":
            _, lam, alpha_eff, m0_eff, n0_eff = r
            skipped.append({"lambda": lam, "alpha": alpha_eff, "m0": m0_eff, "n0": n0_eff})
        elif tag == "failed":
            _, lam, alpha_orig, m0_orig, n0_orig, msg = r
            failed.append({"lambda": lam, "alpha": alpha_orig, "m0": m0_orig, "n0": n0_orig, "error": msg})

    base = Path(base_dir)
    _ensure_dir(base)
    _atomic_write_json(base / "failed_runs.json", failed)
    _atomic_write_json(base / "low_r2_runs.json", low_r2)

    print(f"✅ Done: {len(done)}, Skipped: {len(skipped)}, Failed: {len(failed)}, Low-R²: {len(low_r2)}")
    if failed:
        print("❌ Failed runs (sample):")
        for item in failed[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']}, n0={item['n0']} | error: {item['error']}")
    if low_r2:
        print("⚠️  Low-R² runs (R² < 0.999):")
        for item in low_r2[:20]:
            print(f"  λ={item['lambda']}, α={item['alpha']}, m0={item['m0']}, n0={item['n0']} | c={item['c']} | r2={item['r2']}")

    return {"done": done, "skipped": skipped, "failed": failed, "low_r2": low_r2}

# -------------------------
# Example usage in the notebook
# -------------------------
# 1) Make sure you've already done:
#    from func import Dissertation_Func_1D
#
# 2) Define your grids and shared kwargs:

lambda_vals = [0.001, 0.01, 0.1,  1,10, 100, 100000, 1000000, 100000000]
alpha_vals  = [0.001, 0.01, 0.1,  1,  10, 100, 100000, 1000000, 100000000]
m0_vals     = [0.1, 0.5, 0.9]
n0_vals     = [0.1, 0.3, 0.5, 0.7, 0.9]   # <-- NEW

shared_kwargs = dict(
    # Your requested defaults
    L=200, N=20001, T=400, dt=0.1,
    init_type="tanh", steepness=0.85, perc=0.4,
    t_start=100, t_end=350, num_points=250,
    n0=1.0,   # will be overridden per-run by run_one(...)
    K=1.0, rho=1.0, D=1.0, Mmax=1.0
)

# 3) Run the grid (tweak n_jobs as your machine allows)
results = run_grid(lambda_vals, alpha_vals, m0_vals, n0_vals,
                   base_dir="speeds_func_u0_mini",
                   model_kwargs=shared_kwargs,
                   snapshot_stride=150,   # <- only every 150th snapshot (plus last)
                   overwrite=False, n_jobs=-2, verbose=10)

# 4) (Optional) Quick flagged summary (after results) — example:
print("\n=== Summary of Problematic Runs ===")
for item in results["failed"]:
    print(f"FAIL -> λ={item['lambda']}, α={item['alpha']}, m0={item['m0']}, n0={item['n0']} | {item['error']}")
for item in results["low_r2"][:20]:
    print(f"LOW R² -> λ={item['lambda']}, α={item['alpha']}, m0={item['m0']}, n0={item['n0']} | r2={item['r2']}")

[Parallel(n_jobs=-2)]: Using backend LokyBackend with 7 concurrent workers.
[Parallel(n_jobs=-2)]: Done   4 tasks      | elapsed:  3.8min
[Parallel(n_jobs=-2)]: Done  11 tasks      | elapsed:  6.6min
[Parallel(n_jobs=-2)]: Done  18 tasks      | elapsed:  9.9min
[Parallel(n_jobs=-2)]: Done  27 tasks      | elapsed: 13.6min
[Parallel(n_jobs=-2)]: Done  36 tasks      | elapsed: 19.7min
[Parallel(n_jobs=-2)]: Done  47 tasks      | elapsed: 23.1min
[Parallel(n_jobs=-2)]: Done  58 tasks      | elapsed: 29.3min
[Parallel(n_jobs=-2)]: Done  71 tasks      | elapsed: 35.6min
[Parallel(n_jobs=-2)]: Done  84 tasks      | elapsed: 40.5min
[Parallel(n_jobs=-2)]: Done  99 tasks      | elapsed: 48.8min
[Parallel(n_jobs=-2)]: Done 114 tasks      | elapsed: 54.5min
[Parallel(n_jobs=-2)]: Done 131 tasks      | elapsed: 60.8min
[Parallel(n_jobs=-2)]: Done 148 tasks      | elapsed: 70.3min
[Parallel(n_jobs=-2)]: Done 167 tasks      | elapsed: 77.1min
[Parallel(n_jobs=-2)]: Done 186 tasks      | elapsed: 87

KeyboardInterrupt: 