<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/phi4_all_in_one_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# phi4_all_in_one.py
# Single-file 2D φ^4 field simulator with optional coupled χ field, Langevin noise,
# energy diagnostics, structure factor, entropy/complexity maps, live viewer, and CSV/PNG outputs.

import os
import json
import math
import time
import argparse
from dataclasses import dataclass, asdict
from typing import Tuple, Literal, Optional

import numpy as np
import matplotlib.pyplot as plt

# =========================
# Configuration (defaults)
# =========================
@dataclass
class Config:
    """All tunable parameters of a φ⁴ (optionally φ–χ coupled) simulation run.

    Field names double as the CLI option names in ``build_parser`` and as the keys
    of the ``config.json`` written by ``run`` (``asdict`` keeps declaration order),
    so renaming or reordering fields changes both interfaces.
    """

    # Grid / time
    nx: int = 256             # grid points along axis 0 (plotted horizontally as x)
    ny: int = 256             # grid points along axis 1 (plotted vertically as y)
    dx: float = 1.0           # lattice spacing, shared by both axes
    dt: float = 0.1           # time step; c*dt/dx should stay <= 1/sqrt(2) (checked at startup)
    steps: int = 5000         # number of time steps to advance

    # Physics (φ): potential V(φ) = (λ/4)(φ² − v²)²
    c: float = 1.0            # wave speed for φ
    lam: float = 1.0          # self-interaction λ
    v: float = 1.0            # vacuum expectation value (minima of V at φ = ±v)
    eta: float = 0.0          # bulk damping for φ (may be raised at runtime by plateau handling)

    # Optional second field χ (environment / adaptive coupling)
    use_chi: bool = False     # evolve χ alongside φ
    c_chi: float = 1.0        # wave speed for χ
    mu_chi: float = 1.0       # mass term of the χ potential 0.5*mu^2*χ^2
    eta_chi: float = 0.0      # damping for χ
    g_couple: float = 0.0     # coupling energy E_couple = g * φ * χ (summed over cells)

    # Absorbing boundary: extra damping that falls smoothly from absorb_eta at the
    # outermost cells to 0 at absorb_width cells inward. Spatial derivatives wrap
    # periodically (np.roll), so this layer is what removes waves reaching the edges.
    absorb_width: int = 0     # boundary ramp width (cells); 0 disables
    absorb_eta: float = 0.1   # damping at the outermost cells

    # Stochastic/Langevin noise on φ
    langevin_sigma: float = 0.0   # noise strength; 0 disables
    # When > 0, every step adds an independent Gaussian term with std σ per cell to
    # φ's acceleration; no dt-dependent rescaling is applied.

    # Initialization of φ (χ always starts at zero and at rest)
    init: Literal["random", "blob", "bias"] = "random"
    random_amp: float = 0.01      # "random": uniform values in [-amp/2, amp/2)
    blob_sigma: float = 20.0      # "blob": Gaussian width (cells) of the centred bump
    bias_value: float = 0.9       # "bias": uniform initial value

    # Output
    out_dir: str = "out_phi4_all"   # root directory for every output file
    save_every: int = 50            # steps between energy.csv rows and field frames
    cmap: str = "seismic"           # colormap for φ images
    seed: int = 42                  # seeds the initial condition and the noise RNG
    dtype: Literal["float32", "float64"] = "float64"

    # Analytics saving cadence (values <= 0 disable the corresponding output)
    struct_every: int = 200         # save power spectrum and radial profile
    entropy_every: int = 200        # save sign-entropy map
    complexity_every: int = 200     # save gradient-magnitude map

    # Live viewer
    live: bool = False              # interactive live plotting
    live_every: int = 10            # steps between live updates

    # Adaptive plateau handling (energy is sampled every save_every steps)
    plateau_window: int = 200       # span in steps used for the energy-slope estimate
    plateau_tol: float = 1e-4       # trigger when |dE/dstep| / |mean E| drops below this
    plateau_action: Literal["none", "add_damping"] = "none"
    plateau_eta_target: float = 0.02  # bulk damping to set when plateau detected

# =========================
# Utilities
# =========================
def ensure_dir(p: str) -> None:
    """Create directory *p*, including missing parents; an existing directory is fine."""
    os.makedirs(name=p, exist_ok=True)

def save_json(path: str, obj) -> None:
    """Write *obj* to *path* as UTF-8 JSON with 2-space indentation, replacing any old file."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2)

def arr_dtype(cfg: Config):
    """Map ``cfg.dtype`` to a NumPy float type.

    "float64" (any letter case) selects ``np.float64``; every other value
    falls back to ``np.float32``.
    """
    wants_double = str(cfg.dtype).lower() == "float64"
    if wants_double:
        return np.float64
    return np.float32

def info(msg: str):
    """Print *msg* to stdout and flush right away so progress appears without buffering delay."""
    print(msg, end="\n", flush=True)

# =========================
# Numerics / physics helpers
# =========================
def five_point_laplacian(phi: np.ndarray, dx: float) -> np.ndarray:
    """Discrete Laplacian of *phi* using the 5-point stencil with periodic wrap-around.

    Neighbours come from ``np.roll`` along both axes, so the grid is treated as a
    torus. The result has the same shape as *phi*.
    """
    prev0, next0 = np.roll(phi, +1, axis=0), np.roll(phi, -1, axis=0)
    prev1, next1 = np.roll(phi, +1, axis=1), np.roll(phi, -1, axis=1)
    stencil = prev0 + next0 + prev1 + next1 - 4.0 * phi
    return stencil / (dx * dx)

def central_gradients(phi: np.ndarray, dx: float) -> Tuple[np.ndarray, np.ndarray]:
    """Second-order central-difference gradient of *phi* along axis 0 and axis 1.

    Differences wrap periodically (``np.roll``) and both axes use the same
    spacing *dx*. Returns ``(d/d axis0, d/d axis1)``, each shaped like *phi*.
    """
    span = 2.0 * dx
    return tuple(
        (np.roll(phi, -1, axis=ax) - np.roll(phi, +1, axis=ax)) / span
        for ax in (0, 1)
    )

def potential_phi(phi: np.ndarray, lam: float, v: float) -> np.ndarray:
    """Double-well potential density V(φ) = (λ/4)(φ² − v²)², which vanishes at φ = ±v."""
    offset = phi * phi - v * v
    return 0.25 * lam * offset ** 2

def dV_dphi(phi: np.ndarray, lam: float, v: float) -> np.ndarray:
    """Derivative of the double-well potential: dV/dφ = λ φ (φ² − v²)."""
    shifted_sq = phi * phi - v * v
    return lam * phi * shifted_sq

def potential_chi(chi: np.ndarray, mu: float) -> np.ndarray:
    """Quadratic (mass) potential density for χ: V(χ) = ½ μ² χ²."""
    mass_sq = mu * mu
    return 0.5 * mass_sq * (chi * chi)

def dV_dchi(chi: np.ndarray, mu: float) -> np.ndarray:
    """Derivative of the χ mass potential: dV/dχ = μ² χ."""
    mass_sq = mu * mu
    return mass_sq * chi

def build_absorb_mask(nx: int, ny: int, width: int, max_eta: float, dtype) -> np.ndarray:
    """Per-cell damping map for an absorbing edge layer.

    Each cell's distance (in cells) to the nearest grid edge is turned into a ramp
    that is 1 on the edge and 0 at *width* cells or deeper. The ramp is smoothed
    with a raised-cosine profile and scaled to *max_eta*. Returns an ``(nx, ny)``
    array of *dtype*; all zeros when *width* or *max_eta* is not positive.
    """
    if width <= 0 or max_eta <= 0:
        return np.zeros((nx, ny), dtype=dtype)

    rows = np.arange(nx)[:, None]
    cols = np.arange(ny)[None, :]
    # Broadcasting (nx, 1) against (1, ny) gives the full (nx, ny) distance map.
    edge_dist = np.minimum(
        np.minimum(rows, nx - 1 - rows),
        np.minimum(cols, ny - 1 - cols),
    ).astype(dtype)

    ramp = np.clip(1.0 - edge_dist / width, 0.0, 1.0)
    taper = 0.5 * (1.0 - np.cos(np.pi * ramp))
    return (max_eta * taper).astype(dtype)

# =========================
# Analytics
# =========================
def domain_wall_length(phi: np.ndarray, dx: float) -> float:
    """Estimate the total domain-wall length from sign changes of *phi*.

    Cells with φ == 0 count as positive. Each neighbouring pair along axis 0 and
    axis 1 (periodic wrap-around) is checked once. The number of sign-changing
    pairs from both axes is summed, halved, and multiplied by *dx*.

    NOTE(review): every crossed bond is already counted exactly once, so the 0.5
    factor reports half the geometric length of a straight axis-aligned wall.
    Confirm the intended normalisation before comparing with analytic lengths.
    """
    signs = np.sign(phi)
    signs = np.where(signs == 0, 1, signs)
    flips_axis0 = (signs != np.roll(signs, -1, axis=0)).sum()
    flips_axis1 = (signs != np.roll(signs, -1, axis=1)).sum()
    return float(0.5 * (flips_axis0 + flips_axis1) * dx)

def radial_profile(power2d: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Azimuthally average a centred 2D array in integer-radius bins.

    The centre is index ``(nx // 2, ny // 2)``, which is where ``np.fft.fftshift``
    places the zero-frequency term. Radii are measured in index (pixel) units and
    truncated to integers to form bins ``0..max_r``.

    Returns:
        k: the bin radii ``0..max_r``.
        prof: mean of ``power2d`` over the pixels in each bin (0 for empty bins).
    """
    nx, ny = power2d.shape
    cx, cy = nx // 2, ny // 2
    # The index grids must follow the array's own axis order (axis 0 has nx entries).
    # Building them as (ny, nx) mis-pairs radii with values whenever nx != ny.
    ii, jj = np.ogrid[:nx, :ny]
    r_int = np.hypot(ii - cx, jj - cy).astype(np.int32)
    max_r = r_int.max()
    radial_sum = np.bincount(r_int.ravel(), weights=power2d.ravel(), minlength=max_r + 1)
    radial_count = np.bincount(r_int.ravel(), minlength=max_r + 1)
    # Avoid 0/0 for empty bins; their weighted sum is 0, so they report 0.
    radial_count[radial_count == 0] = 1
    k = np.arange(max_r + 1)
    return k, radial_sum / radial_count

def power_spectrum(phi: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Centred 2D power spectrum of *phi* and its radial average.

    Returns ``(P, k, prof)``:
    - ``P`` is |FFT(φ)|² divided by the number of grid cells, with the zero
      frequency shifted to the centre.
    - ``k`` and ``prof`` are the radial bins and per-bin means from ``radial_profile``.
    """
    centred = np.fft.fftshift(np.fft.fft2(phi))
    power = np.real(centred * np.conj(centred))
    power_norm = power / power.size
    k, prof = radial_profile(power_norm)
    return power_norm, k, prof

def box_sum_2d(arr: np.ndarray, w: int) -> np.ndarray:
    """Sum of *arr* over the w×w window centred on every cell, with periodic wrap.

    Uses a summed-area table (integral image), so the cost does not depend on *w*.

    Args:
        arr: 2D array.
        w: window side length; must be a positive odd integer.

    Returns:
        Array with the same shape as *arr*. ``out[i, j]`` is the sum of *arr* over
        rows ``i - w//2 .. i + w//2`` and columns ``j - w//2 .. j + w//2``, with
        indices taken modulo the array shape.

    Raises:
        ValueError: if *w* is not a positive odd integer.
    """
    # An explicit check, because `assert` is stripped when Python runs with -O.
    if w < 1 or w % 2 != 1:
        raise ValueError(f"window size must be a positive odd integer, got {w}")
    pad = w // 2
    padded = np.pad(arr, ((pad, pad), (pad, pad)), mode="wrap")
    # Prepend a zero row and column so that c[i, j] == padded[:i, :j].sum().
    # Without it, the inclusion–exclusion below returns an array one row and one
    # column short, shifted by one cell.
    c = np.pad(padded.cumsum(axis=0).cumsum(axis=1), ((1, 0), (1, 0)))
    return c[w:, w:] - c[:-w, w:] - c[w:, :-w] + c[:-w, :-w]

def sign_entropy_map(phi: np.ndarray, w: int) -> np.ndarray:
    """Local binary (Shannon) entropy, in bits, of the sign pattern of *phi*.

    For each cell, p is the fraction of cells with φ > 0 inside the surrounding
    periodic w×w window, and the map holds H(p) = −p log2 p − (1−p) log2(1−p).
    p is clipped to [1e-8, 1 − 1e-8] so uniform windows give ≈0 instead of NaN.
    """
    positive = np.where(phi > 0, 1.0, 0.0)
    frac = box_sum_2d(positive, w) / float(w * w)
    p = np.clip(frac, 1e-8, 1 - 1e-8)
    q = 1 - p
    return -(p * np.log2(p) + q * np.log2(q))

def grad_complexity_map(phi: np.ndarray, dx: float, w: int) -> np.ndarray:
    """Mean central-difference gradient magnitude |∇φ| over each periodic w×w window."""
    grad_a0, grad_a1 = central_gradients(phi, dx)
    magnitude = np.sqrt(grad_a0 * grad_a0 + grad_a1 * grad_a1)
    window_area = w * w
    return box_sum_2d(magnitude, w) / window_area

# =========================
# Simulator
# =========================
class FieldSimulator:
    """Time-stepper for a 2D φ⁴ field with an optional linearly coupled χ field.

    Fields are stored as ``(nx, ny)`` arrays together with their previous time
    slice, and are advanced by a three-level (leapfrog-style) update. Spatial
    derivatives wrap periodically (``np.roll``). Each field is damped by a bulk
    coefficient plus the per-cell edge ramp from ``build_absorb_mask``.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.nx, self.ny = cfg.nx, cfg.ny
        self.dx, self.dt = cfg.dx, cfg.dt
        self.c, self.lam, self.v = cfg.c, cfg.lam, cfg.v
        # Bulk φ damping; run() may raise it at runtime when an energy plateau is detected.
        self.eta = cfg.eta
        self.dtype = arr_dtype(cfg)

        # Extra per-cell damping near the grid edges (all zeros when disabled).
        self.absorb = build_absorb_mask(self.nx, self.ny, cfg.absorb_width, cfg.absorb_eta, self.dtype)

        # φ field: current and previous time slice.
        self.phi = np.zeros((self.nx, self.ny), dtype=self.dtype)
        self.phi_prev = np.zeros_like(self.phi)

        # χ field (optional). The χ attributes below only exist when use_chi is set.
        self.use_chi = bool(cfg.use_chi)
        if self.use_chi:
            self.c_chi, self.mu_chi, self.eta_chi = cfg.c_chi, cfg.mu_chi, cfg.eta_chi
            self.g = cfg.g_couple
            self.chi = np.zeros((self.nx, self.ny), dtype=self.dtype)
            self.chi_prev = np.zeros_like(self.chi)

        # Std of the Gaussian acceleration noise added to φ every step (0 disables).
        self.noise_sigma = float(cfg.langevin_sigma)

        self._initialize()
        self._check_cfl()

    def _check_cfl(self):
        """Log whether the Courant numbers c·dt/dx respect the 2D 5-point limit 1/√2.

        This only prints a warning; the run is not stopped.
        """
        s_phi = self.c * self.dt / self.dx
        s_max = 1.0 / math.sqrt(2.0)  # 2D stencil guideline
        if s_phi > s_max:
            info(f"[WARN] CFL(φ) likely violated: c*dt/dx = {s_phi:.3f} > {s_max:.3f}")
        else:
            info(f"[INFO] CFL(φ) OK: c*dt/dx = {s_phi:.3f} <= {s_max:.3f}")
        if self.use_chi:
            s_chi = self.c_chi * self.dt / self.dx
            if s_chi > s_max:
                info(f"[WARN] CFL(χ) likely violated: c_chi*dt/dx = {s_chi:.3f} > {s_max:.3f}")
            else:
                info(f"[INFO] CFL(χ) OK: c_chi*dt/dx = {s_chi:.3f} <= {s_max:.3f}")

    def _initialize(self):
        """Set the initial φ from ``cfg.init`` and start both fields at rest.

        Copying the initial state into the previous slice gives zero initial
        velocity. Raises ValueError for an unknown ``cfg.init``.
        """
        rng = np.random.default_rng(self.cfg.seed)
        if self.cfg.init == "random":
            self.phi[:] = self.cfg.random_amp * (rng.random((self.nx, self.ny)) - 0.5)
        elif self.cfg.init == "blob":
            i = np.arange(self.nx)
            j = np.arange(self.ny)
            ii, jj = np.meshgrid(i, j, indexing="ij")
            r2 = (ii - (self.nx - 1) / 2.0) ** 2 + (jj - (self.ny - 1) / 2.0) ** 2
            self.phi[:] = np.exp(-r2 / (2.0 * self.cfg.blob_sigma * self.cfg.blob_sigma))
        elif self.cfg.init == "bias":
            self.phi[:] = self.cfg.bias_value
        else:
            raise ValueError(f"Unknown init: {self.cfg.init}")
        self.phi_prev[:] = self.phi

        if self.use_chi:
            # Start χ at rest near zero
            self.chi[:] = 0.0
            self.chi_prev[:] = self.chi

    def step(self):
        """Advance φ (and χ, if enabled) by one time step ``dt``.

        Equations of motion:
            φ_tt = c² ∇²φ − dV/dφ − g χ − η φ_t + ξ
            χ_tt = c_χ² ∇²χ − μ² χ − g φ − η_χ χ_t
        Here η also includes the edge ramp. Damping uses the backward difference
        φ_t ≈ (φ − φ_prev)/dt, which gives
            φ_new = (2 − η dt) φ − (1 − η dt) φ_prev + dt² F.
        Both fields are updated from the old φ and χ, so the coupling is explicit.

        NOTE(review): ξ has std ``noise_sigma`` per cell per step and is multiplied
        by dt² like the other forces, so the injected noise power depends on dt.
        A dt-independent Langevin bath would scale σ by 1/√dt. Confirm the intent.
        """
        lap_phi = five_point_laplacian(self.phi, self.dx)
        force_phi = (self.c * self.c) * lap_phi - dV_dphi(self.phi, self.lam, self.v)

        if self.use_chi:
            force_phi -= self.g * self.chi

        # Langevin noise on acceleration, drawn from NumPy's global RNG (seeded in run()).
        if self.noise_sigma > 0.0:
            noise = self.noise_sigma * np.random.standard_normal(self.phi.shape).astype(self.phi.dtype)
            force_phi += noise

        eta_eff_phi = self.eta + self.absorb
        damp_phi = eta_eff_phi * self.dt
        phi_new = ((2.0 - damp_phi) * self.phi) - ((1.0 - damp_phi) * self.phi_prev) + (self.dt * self.dt) * force_phi

        if self.use_chi:
            lap_chi = five_point_laplacian(self.chi, self.dx)
            force_chi = (self.c_chi * self.c_chi) * lap_chi - dV_dchi(self.chi, self.mu_chi)
            force_chi -= self.g * self.phi
            eta_eff_chi = self.eta_chi + self.absorb
            damp_chi = eta_eff_chi * self.dt
            chi_new = ((2.0 - damp_chi) * self.chi) - ((1.0 - damp_chi) * self.chi_prev) + (self.dt * self.dt) * force_chi
            self.chi_prev, self.chi = self.chi, chi_new

        self.phi_prev, self.phi = self.phi, phi_new

    @staticmethod
    def _gradient_energy_density(field: np.ndarray, speed: float, dx: float) -> np.ndarray:
        """Gradient energy density ½ c² |∇f|² from periodic forward differences.

        On the periodic grid, Σ |D⁺f|² = −Σ f ∇²f when ∇² is the 5-point Laplacian.
        So this is the gradient energy consistent with the force used in ``step``.
        Central differences would give the grid-scale checkerboard mode zero energy.
        """
        fx = (np.roll(field, -1, axis=0) - field) / dx
        fy = (np.roll(field, -1, axis=1) - field) / dx
        return 0.5 * (speed * speed) * (fx * fx + fy * fy)

    def energy_components_phi(self) -> Tuple[float, float, float]:
        """Return φ's (kinetic, gradient, potential) energy, each summed over cells × dx².

        Kinetic energy uses the backward-difference velocity (φ − φ_prev)/dt.
        """
        vel = (self.phi - self.phi_prev) / self.dt
        Ek = 0.5 * (vel * vel)
        Eg = self._gradient_energy_density(self.phi, self.c, self.dx)
        Ep = potential_phi(self.phi, self.lam, self.v)

        cell = self.dx * self.dx
        return float(Ek.sum() * cell), float(Eg.sum() * cell), float(Ep.sum() * cell)

    def energy_components_chi(self) -> Tuple[float, float, float]:
        """Return χ's (kinetic, gradient, potential) energy, each summed over cells × dx².

        Only valid when ``use_chi`` is set (the χ attributes exist only then).
        """
        vel = (self.chi - self.chi_prev) / self.dt
        Ek = 0.5 * (vel * vel)
        Eg = self._gradient_energy_density(self.chi, self.c_chi, self.dx)
        Ep = potential_chi(self.chi, self.mu_chi)

        cell = self.dx * self.dx
        return float(Ek.sum() * cell), float(Eg.sum() * cell), float(Ep.sum() * cell)

    def energy_coupling(self) -> float:
        """Interaction energy E_couple = g Σ φ χ · dx² (0 when χ is off or g == 0)."""
        if not self.use_chi or self.g == 0.0:
            return 0.0
        cell = self.dx * self.dx
        return float(self.g * (self.phi * self.chi).sum() * cell)

    def total_energy(self) -> float:
        """Sum of all φ energy terms, plus the χ terms and the coupling when χ is enabled."""
        Ekφ, Egφ, Epφ = self.energy_components_phi()
        E = Ekφ + Egφ + Epφ
        if self.use_chi:
            Ekχ, Egχ, Epχ = self.energy_components_chi()
            E += (Ekχ + Egχ + Epχ) + self.energy_coupling()
        return E

# =========================
# Visualization
# =========================
def plot_field(phi: np.ndarray, step: int, out_dir: str, cmap: str = "seismic", name: str = "phi"):
    """Save a heat map of the 2D field *phi* as ``<out_dir>/<name>_<step:06d>.png``.

    The array is transposed so that axis 0 runs horizontally (x) and axis 1
    vertically (y), with the origin at the lower left. The colour limits are
    autoscaled to the data.
    """
    ensure_dir(out_dir)
    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(phi.T, cmap=cmap, origin="lower")
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04, label=name)
    ax.set_title(f"{name} at step {step}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"{name}_{step:06d}.png"), dpi=120)
    plt.close(fig)

def plot_spectrum(P: np.ndarray, k: np.ndarray, prof: np.ndarray, step: int, out_dir: str):
    """Write the spectrum outputs for one step into *out_dir*.

    Produces three files:
    - ``spectrum2d_<step>.png``: log10 of the 2D power (offset by 1e-12).
    - ``spectrum1d_<step>.png``: the radial profile on a log y-axis.
    - ``spectrum1d_<step>.csv``: the radial profile as ``k,Pk`` rows.
    """
    ensure_dir(out_dir)

    fig2d, ax2d = plt.subplots(figsize=(6, 6))
    image = ax2d.imshow(np.log10(P.T + 1e-12), origin="lower", cmap="magma")
    fig2d.colorbar(image, ax=ax2d, fraction=0.046, pad=0.04, label="log10 Power")
    ax2d.set_title(f"Power spectrum |FFT(φ)|^2 at step {step}")
    fig2d.tight_layout()
    fig2d.savefig(os.path.join(out_dir, f"spectrum2d_{step:06d}.png"), dpi=120)
    plt.close(fig2d)

    fig1d, ax1d = plt.subplots(figsize=(6, 4))
    ax1d.plot(k, prof + 1e-16)  # offset keeps empty bins plottable on the log axis
    ax1d.set_yscale("log")
    ax1d.set_xlabel("k (pixels)")
    ax1d.set_ylabel("Radial power")
    ax1d.set_title(f"Radial spectrum at step {step}")
    fig1d.tight_layout()
    fig1d.savefig(os.path.join(out_dir, f"spectrum1d_{step:06d}.png"), dpi=120)
    plt.close(fig1d)

    csv_path = os.path.join(out_dir, f"spectrum1d_{step:06d}.csv")
    with open(csv_path, "w", encoding="utf-8") as handle:
        handle.write("k,Pk\n")
        handle.writelines(f"{int(radius)},{power:.8e}\n" for radius, power in zip(k, prof))

def plot_map(img: np.ndarray, step: int, out_dir: str, title: str, name: str, cmap: str = "viridis"):
    """Save the scalar map *img* as ``<out_dir>/<name>_<step:06d>.png``.

    The plot is titled "<title> at step <step>". Like ``plot_field``, the array is
    transposed so axis 0 is horizontal and the origin is at the lower left.
    """
    ensure_dir(out_dir)
    fig, ax = plt.subplots(figsize=(6, 6))
    mappable = ax.imshow(img.T, origin="lower", cmap=cmap)
    fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    ax.set(title=f"{title} at step {step}", xlabel="x", ylabel="y")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"{name}_{step:06d}.png"), dpi=120)
    plt.close(fig)

# =========================
# Live viewer
# =========================
class LiveViewer:
    """Interactive matplotlib window showing φ and the total-energy history during a run.

    The viewer does nothing unless ``cfg.live`` is true.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.enabled = bool(cfg.live)
        if not self.enabled:
            self.fig = None
            return
        plt.ion()
        self.fig = plt.figure(figsize=(10, 4))
        self.ax_field = self.fig.add_subplot(1, 2, 1)
        self.ax_energy = self.fig.add_subplot(1, 2, 2)
        self.im = None
        self.cbar = None  # created once, on the first update
        self.energy_trace = []

    def update(self, phi: np.ndarray, step: int, energy: float):
        """Show *phi* at *step* and append ``(step, energy)`` to the energy curve.

        The colour limits are refitted to each new frame, like the autoscaled PNGs
        from ``plot_field``. The colorbar is created once and follows the image's
        norm afterwards.
        """
        if not self.enabled:
            return
        self.energy_trace.append((step, energy))
        if self.im is None:
            self.im = self.ax_field.imshow(phi.T, origin="lower", cmap=self.cfg.cmap)
            # A single colorbar for the viewer's lifetime. Adding one per update
            # keeps stealing space from the field axes and piles up artists.
            self.cbar = self.fig.colorbar(self.im, ax=self.ax_field, fraction=0.046, pad=0.04)
        else:
            self.im.set_data(phi.T)
            # set_data keeps the previous colour limits; refit them to the new frame.
            self.im.autoscale()
        self.ax_field.set_title(f"φ at step {step}")

        self.ax_energy.cla()
        steps, energies = zip(*self.energy_trace)
        self.ax_energy.plot(steps, energies, '-k')
        self.ax_energy.set_title("Total energy")
        self.ax_energy.set_xlabel("step")
        self.ax_energy.set_ylabel("E")
        self.fig.tight_layout()
        plt.pause(0.001)

# =========================
# Runner
# =========================
def _energy_csv_header(use_chi: bool) -> str:
    """Return the energy.csv header line; the χ columns appear only when χ is evolved."""
    if use_chi:
        return "step,E_total,Ek_phi,Eg_phi,Ep_phi,Ek_chi,Eg_chi,Ep_chi,E_couple,wall_length\n"
    return "step,E_total,Ek_phi,Eg_phi,Ep_phi,wall_length\n"


def _log_energy_row(sim: "FieldSimulator", cfg: "Config", step: int, energy_path: str) -> float:
    """Append one energy.csv row describing *sim* at *step* and return its total energy."""
    Ekφ, Egφ, Epφ = sim.energy_components_phi()
    L = domain_wall_length(sim.phi, cfg.dx)
    if cfg.use_chi:
        Ekχ, Egχ, Epχ = sim.energy_components_chi()
        E_c = sim.energy_coupling()
        E = Ekφ + Egφ + Epφ + Ekχ + Egχ + Epχ + E_c
        row = f"{step},{E:.8e},{Ekφ:.8e},{Egφ:.8e},{Epφ:.8e},{Ekχ:.8e},{Egχ:.8e},{Epχ:.8e},{E_c:.8e},{L:.8e}\n"
    else:
        E = Ekφ + Egφ + Epφ
        row = f"{step},{E:.8e},{Ekφ:.8e},{Egφ:.8e},{Epφ:.8e},{L:.8e}\n"
    with open(energy_path, "a", encoding="utf-8") as f:
        f.write(row)
    return E


def _relative_energy_slope(E_hist: list, S_hist: list, m: int) -> float:
    """Return |dE/dstep| over the last *m* samples, divided by their mean |E|.

    The slope is a least-squares fit of E − E₀ against step − step₀ for a line
    forced through the first sample of the window.
    """
    Ew = np.array(E_hist[-m:])
    Sw = np.array(S_hist[-m:], dtype=np.float64)
    Sw = Sw - Sw[0]
    denom = np.dot(Sw, Sw) + 1e-12
    slope = float(np.dot(Sw, Ew - Ew[0]) / denom)
    return abs(slope) / (abs(Ew.mean()) + 1e-12)


def run(cfg: Config):
    """Run a full simulation described by *cfg*, writing all outputs under ``cfg.out_dir``.

    Outputs:
    - ``config.json``.
    - ``energy.csv`` and φ/χ frames, every ``save_every`` steps.
    - Power spectra every ``struct_every`` steps.
    - Sign-entropy and gradient-complexity maps every ``entropy_every`` /
      ``complexity_every`` steps. A value <= 0 disables the matching output.

    Optionally drives the live viewer and raises the bulk damping when the total
    energy plateaus.

    Raises:
        ValueError: if ``save_every`` is not positive, or if ``live`` is set with a
            non-positive ``live_every``. Both are used as modulo divisors; this is
            checked before anything is written.
    """
    if cfg.save_every <= 0:
        raise ValueError(f"save_every must be a positive integer, got {cfg.save_every}")
    if cfg.live and cfg.live_every <= 0:
        raise ValueError(f"live_every must be a positive integer when live is set, got {cfg.live_every}")

    # Prepare output
    out = cfg.out_dir
    frames_dir = os.path.join(out, "frames")
    spec_dir = os.path.join(out, "spectrum")
    maps_dir = os.path.join(out, "maps")
    for directory in (out, frames_dir, spec_dir, maps_dir):
        ensure_dir(directory)
    save_json(os.path.join(out, "config.json"), asdict(cfg))

    # The Langevin noise in FieldSimulator.step draws from NumPy's global RNG.
    np.random.seed(cfg.seed)

    sim = FieldSimulator(cfg)
    viewer = LiveViewer(cfg)

    energy_path = os.path.join(out, "energy.csv")
    with open(energy_path, "w", encoding="utf-8") as f:
        f.write(_energy_csv_header(cfg.use_chi))

    info("[INFO] Starting simulation...")
    t0 = time.perf_counter()

    # Plateau detection: energy is sampled every save_every steps. The same sample
    # count gates the check and sizes the fit window. At least 3 points are used,
    # so the fit is never a bare two-point difference.
    plateau_pts = max(3, cfg.plateau_window // cfg.save_every)
    E_hist: list[float] = []
    S_hist: list[int] = []

    for step in range(cfg.steps + 1):
        # Energy log and field frames
        if step % cfg.save_every == 0:
            E = _log_energy_row(sim, cfg, step, energy_path)
            plot_field(sim.phi, step, frames_dir, cmap=cfg.cmap, name="phi")
            if cfg.use_chi:
                plot_field(sim.chi, step, frames_dir, cmap="coolwarm", name="chi")
            info(f"Step {step:6d} | Energy = {E:.6e}")

        # Structure factor
        if cfg.struct_every > 0 and step % cfg.struct_every == 0:
            P, k, prof = power_spectrum(sim.phi)
            plot_spectrum(P, k, prof, step, spec_dir)

        # Entropy and complexity maps
        if cfg.entropy_every > 0 and step % cfg.entropy_every == 0:
            H = sign_entropy_map(sim.phi, w=21)
            plot_map(H, step, maps_dir, "Sign entropy (w=21)", "entropy", cmap="viridis")
        if cfg.complexity_every > 0 and step % cfg.complexity_every == 0:
            C = grad_complexity_map(sim.phi, cfg.dx, w=21)
            plot_map(C, step, maps_dir, "Gradient complexity (w=21)", "complexity", cmap="inferno")

        # Live viewer
        if cfg.live and step % cfg.live_every == 0:
            viewer.update(sim.phi, step, sim.total_energy())

        # Plateau detection/action
        if cfg.plateau_action != "none" and step % cfg.save_every == 0:
            E_hist.append(sim.total_energy())
            S_hist.append(step)
            if len(E_hist) >= plateau_pts:
                rel = _relative_energy_slope(E_hist, S_hist, plateau_pts)
                if (
                    rel < cfg.plateau_tol
                    and cfg.plateau_action == "add_damping"
                    and sim.eta < cfg.plateau_eta_target
                ):
                    sim.eta = cfg.plateau_eta_target
                    info(f"[ADAPT] Plateau detected (rel slope {rel:.2e}). Setting η -> {sim.eta:.3g}")

        # Advance (the final iteration only records outputs)
        if step < cfg.steps:
            sim.step()

    dt_sec = time.perf_counter() - t0
    info(f"[DONE] Completed {cfg.steps} steps in {dt_sec:.2f}s. Outputs in: {out}")

# =========================
# CLI (notebook-safe)
# =========================
def build_parser():
    """Build the CLI parser; each option maps onto the ``Config`` field with the same name.

    Every option defaults to ``None``, including the two boolean flags. This lets
    ``merge_args_into_cfg`` distinguish "not given on the command line" from an
    explicit value, and keep the ``Config`` default in the first case.
    """
    p = argparse.ArgumentParser(
        description="2D φ⁴ field simulator with optional χ coupling, Langevin noise, analytics, and live viewer."
    )
    # Grid/time
    p.add_argument("--nx", type=int, help="grid points along x (axis 0)")
    p.add_argument("--ny", type=int, help="grid points along y (axis 1)")
    p.add_argument("--dx", type=float, help="lattice spacing")
    p.add_argument("--dt", type=float, help="time step")
    p.add_argument("--steps", type=int, help="number of time steps")

    # Physics φ
    p.add_argument("--c", type=float, help="wave speed of φ")
    p.add_argument("--lam", type=float, help="φ⁴ self-interaction λ")
    p.add_argument("--v", type=float, help="vacuum expectation value")
    p.add_argument("--eta", type=float, help="bulk damping of φ")

    # χ field. A store_true flag would default to False and always override the
    # Config value during the merge; default=None keeps "flag absent" distinguishable.
    p.add_argument("--use_chi", action="store_true", default=None, help="evolve the coupled χ field")
    p.add_argument("--c_chi", type=float, help="wave speed of χ")
    p.add_argument("--mu_chi", type=float, help="mass parameter of χ")
    p.add_argument("--eta_chi", type=float, help="damping of χ")
    p.add_argument("--g_couple", type=float, help="φ–χ coupling constant g")

    # Absorption
    p.add_argument("--absorb_width", type=int, help="absorbing edge width in cells (0 disables)")
    p.add_argument("--absorb_eta", type=float, help="damping at the outermost cells")

    # Noise
    p.add_argument("--langevin_sigma", type=float, help="std of the per-step acceleration noise on φ")

    # Init
    p.add_argument("--init", choices=["random", "blob", "bias"], help="initial condition for φ")
    p.add_argument("--random_amp", type=float, help="amplitude of the random initial condition")
    p.add_argument("--blob_sigma", type=float, help="width of the Gaussian blob initial condition")
    p.add_argument("--bias_value", type=float, help="uniform value of the bias initial condition")

    # Output
    p.add_argument("--out_dir", type=str, help="output directory")
    p.add_argument("--save_every", type=int, help="steps between energy rows and frames")
    p.add_argument("--cmap", type=str, help="colormap for φ images")
    p.add_argument("--seed", type=int, help="RNG seed")
    p.add_argument("--dtype", choices=["float32", "float64"], help="floating-point precision of the fields")

    # Analytics cadence
    p.add_argument("--struct_every", type=int, help="steps between power spectra (<=0 disables)")
    p.add_argument("--entropy_every", type=int, help="steps between sign-entropy maps (<=0 disables)")
    p.add_argument("--complexity_every", type=int, help="steps between gradient-complexity maps (<=0 disables)")

    # Live viewer
    p.add_argument("--live", action="store_true", default=None, help="show an interactive live viewer")
    p.add_argument("--live_every", type=int, help="steps between live viewer updates")

    # Plateau control
    p.add_argument("--plateau_window", type=int, help="steps spanned by the energy-slope fit")
    p.add_argument("--plateau_tol", type=float, help="relative slope threshold for a plateau")
    p.add_argument("--plateau_action", choices=["none", "add_damping"], help="action on plateau detection")
    p.add_argument("--plateau_eta_target", type=float, help="bulk damping applied on plateau")

    return p

def merge_args_into_cfg(args, cfg: Config) -> Config:
    """Return a new Config equal to *cfg*, overridden by the non-None attributes of *args*.

    Attributes of *args* that are not Config fields are ignored, and *cfg* itself
    is left untouched.
    """
    merged = asdict(cfg)
    merged.update(
        {name: value for name, value in vars(args).items() if value is not None and name in merged}
    )
    return Config(**merged)

def main(argv=None):
    """CLI entry point: parse the recognised options from *argv* and run the simulation.

    *argv* defaults to the process command line. Unrecognised arguments are
    dropped (``parse_known_args``), which keeps the entry point usable from
    notebooks where extra kernel arguments may be present.
    """
    known_args, _unknown = build_parser().parse_known_args(argv)
    run(merge_args_into_cfg(known_args, Config()))


if __name__ == "__main__":
    main()