In [None]:
import sys
sys.path.append('/home/aistudio/external-libraries')

import os
import numpy as np
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax
jax.config.update("jax_enable_x64", True)

from matplotlib import font_manager
import matplotlib.pyplot as plt

font_path = '/home/aistudio/Times_New_Roman.ttf'
prop = font_manager.FontProperties(fname=font_path)

In [2]:
class Params(NamedTuple):
    # PDE domain
    Lx: float
    Ly: float

    # Grid resolution
    Nx: int
    Ny: int

    # Time stepping
    dt: float
    T: float

    # Initial conditions
    A: float         # amplitude for cos(kx x) cos(ky y)
    v0_const: float  # constant initial velocity

    # PDE physical parameters
    p: float
    alpha: float
    beta: float
    lam: float
    gam: float
    eps: float

    # Output control
    record_ratio: float = 0.01   # default: record every 1%


# -------------------- Utility Functions --------------------
def make_wavenumbers(Nx, Ny, Lx, Ly, dtype=jnp.float32):
    """Generate wavenumber grids in spectral space"""
    kx = 2 * jnp.pi * jnp.fft.fftfreq(Nx, d=Lx / Nx)
    ky = 2 * jnp.pi * jnp.fft.fftfreq(Ny, d=Ly / Ny)
    KX, KY = jnp.meshgrid(kx.astype(dtype), ky.astype(dtype), indexing="ij")
    K2 = KX**2 + KY**2
    return KX, KY, K2


def dealias_mask(Nx, Ny, frac=2/3):
    """2/3 dealiasing rule: retain low modes along each axis"""
    cx = int((frac * (Nx / 2)) // 1)
    cy = int((frac * (Ny / 2)) // 1)
    ax = jnp.arange(Nx)
    ay = jnp.arange(Ny)
    keep_x = (ax <= cx) | (ax >= Nx - cx)
    keep_y = (ay <= cy) | (ay >= Ny - cy)
    return keep_x[:, None] & keep_y[None, :]


def to_phys(hat_field):
    """Spectral → Physical space"""
    return jnp.fft.ifft2(hat_field).real


def to_spec(field):
    """Physical → Spectral space"""
    return jnp.fft.fft2(field)


# -------------------- Nonlinear Term N(u) (JAX) --------------------
def N_hat(u_hat, params: Params, KX, KY, K2, mask):
    """
    N(u) = Δ_p u + λ|u|^α + γ|∇u|^β
    where Δ_p u = div( a(|∇u|) ∇u ), a=(|∇u|^2 + eps^2)^((p-2)/2)
    """
    p     = params.p
    alpha = params.alpha
    beta  = params.beta
    lam   = params.lam
    gam   = params.gam
    eps   = params.eps

    u  = to_phys(u_hat)                 # Physical space field
    ux = to_phys(1j * KX * u_hat)       # x-derivative
    uy = to_phys(1j * KY * u_hat)       # y-derivative

    grad2 = ux * ux + uy * uy           # |∇u|^2

    # p-Laplacian flux q = a ∇u
    a  = (grad2 + eps**2) ** ((p - 2.0) / 2.0)
    qx = a * ux
    qy = a * uy

    qx_hat = to_spec(qx)
    qy_hat = to_spec(qy)
    zqx = jnp.zeros_like(qx_hat); zqy = jnp.zeros_like(qy_hat)
    qx_hat = jnp.where(mask, qx_hat, zqx)    # Apply dealiasing
    qy_hat = jnp.where(mask, qy_hat, zqy)

    divq_hat = 1j * KX * qx_hat + 1j * KY * qy_hat

    # Source terms
    src1_hat = to_spec(lam * jnp.abs(u) ** alpha)
    src1_hat = jnp.where(mask, src1_hat, jnp.zeros_like(src1_hat))

    grad_beta_hat = to_spec(gam * (grad2) ** (beta / 2.0))
    grad_beta_hat = jnp.where(mask, grad_beta_hat, jnp.zeros_like(grad_beta_hat))

    return divq_hat + src1_hat + grad_beta_hat


# -------------------- Single IF-RK4 Step --------------------
@partial(jax.jit, static_argnames=("params",))
def if_rk4_step(u_hat, v_hat, dt, params: Params, KX, KY, K2, mask, E, E2):
    """Integrating Factor RK4 step for second-order wave equation"""
    # --- Stage 1 ---
    N1  = N_hat(u_hat, params, KX, KY, K2, mask)
    v_a = E2 * (v_hat + 0.5 * dt * N1)
    u_a = jnp.where(mask, u_hat + 0.5 * dt * v_hat, jnp.zeros_like(u_hat))

    # --- Stage 2 ---
    N2  = N_hat(u_a, params, KX, KY, K2, mask)
    v_b = E2 * (v_hat + 0.5 * dt * N2)
    u_b = jnp.where(mask, u_hat + 0.5 * dt * v_a, jnp.zeros_like(u_hat))

    # --- Stage 3 ---
    N3  = N_hat(u_b, params, KX, KY, K2, mask)
    v_c = E * (v_hat + dt * N3)
    u_c = jnp.where(mask, u_hat + dt * v_b, jnp.zeros_like(u_hat))

    # --- Stage 4 ---
    N4  = N_hat(u_c, params, KX, KY, K2, mask)

    # Combine stages with integrating factors
    v_next = E * v_hat + (dt / 6.0) * (E * N1 + 2.0 * E2 * N2 + 2.0 * E2 * N3 + N4)
    u_next = u_hat + (dt / 6.0) * (v_hat + 2.0 * v_a + 2.0 * v_b + v_c)

    # Apply dealiasing mask
    u_next = jnp.where(mask, u_next, jnp.zeros_like(u_next))
    v_next = jnp.where(mask, v_next, jnp.zeros_like(v_next))
    return u_next, v_next


# -------------------- Main Solver --------------------
def simulate_2d(u0, v0, Lx, Ly, Nx, Ny, dt, nt, params: Params,
                dealias_frac=2/3, record_every=0):
    """Main simulation driver for 2D nonlinear wave equation"""
    # Precompute constants: grids, mask, integrating factors
    KX, KY, K2 = make_wavenumbers(Nx, Ny, Lx, Ly, dtype=jnp.float32)
    mask = dealias_mask(Nx, Ny, frac=dealias_frac)
    Lh   = -K2 * dt
    E    = jnp.exp(Lh)        # e^{hΔ}
    E2   = jnp.exp(Lh / 2.0)  # e^{hΔ/2}

    # Initial conditions: convert to spectral and apply dealiasing
    u_hat = to_spec(jnp.asarray(u0))
    v_hat = to_spec(jnp.asarray(v0))
    zc = jnp.zeros_like(u_hat)
    u_hat = jnp.where(mask, u_hat, zc)
    v_hat = jnp.where(mask, v_hat, zc)

    # Robust blow-up cutoff in physical space
    STOP_THRESHOLD = 1e12

    # ---------------------- no history recording: use fori_loop (memory efficient) ----------------------
    if record_every == 0:

        def body(i, carry):
            u_h, v_h, stop_flag, step_idx = carry

            # if already stopped, skip computation
            def compute_step(_):
                u_n, v_n = if_rk4_step(u_h, v_h, dt, params, KX, KY, K2, mask, E, E2)
                u_phys = to_phys(u_n)
                maxu = jnp.max(jnp.abs(u_phys))

                # robust blow-up detection: large, NaN or inf
                new_flag = jnp.logical_or(
                    maxu > STOP_THRESHOLD,
                    jnp.logical_or(jnp.isnan(maxu), jnp.isinf(maxu))
                )
                return (u_n, v_n, new_flag)

            u_new, v_new, new_flag = lax.cond(
                stop_flag,
                lambda _: (u_h, v_h, stop_flag),
                compute_step,
                operand=None
            )

            # step counter only increases when not yet stopped
            step_idx = lax.cond(stop_flag, lambda x: x, lambda x: x + 1, step_idx)

            return (u_new, v_new, new_flag, step_idx)

        u_hat, v_hat, stop_flag, stop_iter = lax.fori_loop(
            0, nt, body,
            (u_hat, v_hat, False, jnp.array(0, dtype=jnp.int32))
        )
        # Hist is None in no-history mode
        return u_hat, v_hat, None, int(stop_iter)

    # ---------------------- history recording mode: preallocate buffer and write periodically ----------------------
    n_records = nt // record_every + 1  # Include initial condition
    hist0 = to_phys(u_hat)[None, ...]   # (1, Nx, Ny)
    hist_buf = jnp.concatenate(
        [hist0, jnp.zeros((n_records-1, Nx, Ny), dtype=hist0.dtype)],
        axis=0
    )

    def scan_step(carry, i):
        u_h, v_h, rec_i, hist, stop_flag, step_idx = carry

        def compute_step(_):
            u_n, v_n = if_rk4_step(u_h, v_h, dt, params, KX, KY, K2, mask, E, E2)
            u_phys = to_phys(u_n)
            maxu = jnp.max(jnp.abs(u_phys))

            # robust blow-up detection: large, NaN or inf
            new_flag = jnp.logical_or(
                maxu > STOP_THRESHOLD,
                jnp.logical_or(jnp.isnan(maxu), jnp.isinf(maxu))
            )
            return (u_n, v_n, new_flag, u_phys)

        # if already stopped → freeze solution
        u_n, v_n, new_flag, u_phys = lax.cond(
            stop_flag,
            lambda _: (u_h, v_h, True, to_phys(u_h)),
            compute_step,
            operand=None
        )

        # Check if we should record at step i+1
        do_rec = ((i + 1) % record_every) == 0

        def _write(hist_in):
            return hist_in.at[rec_i].set(u_phys)

        hist = lax.cond(do_rec, _write, lambda x: x, hist)
        rec_i = rec_i + jnp.where(do_rec, 1, 0)

        # step index only increases before global stop
        step_idx = lax.cond(stop_flag, lambda x: x, lambda x: x + 1, step_idx)

        return (u_n, v_n, rec_i, hist, new_flag, step_idx), None

    (u_hat, v_hat, _, hist_buf, stop_flag, stop_iter), _ = lax.scan(
        scan_step,
        (u_hat, v_hat, jnp.array(1, dtype=jnp.int32), hist_buf,
         False, jnp.array(0, dtype=jnp.int32)),
        jnp.arange(nt)
    )
    return u_hat, v_hat, hist_buf, int(stop_iter)


# -------------------- Demonstration --------------------
def solve_pde(params: Params):
    """
    PDE solver using parameters contained in a Params object.
    """

    # Extract parameters
    Lx, Ly = params.Lx, params.Ly
    Nx, Ny = params.Nx, params.Ny
    dt, T  = params.dt, params.T
    A = params.A
    v0_const = params.v0_const
    record_ratio = params.record_ratio

    nt = int(T / dt)
    record_every = max(1, int(nt * record_ratio))

    print(f"Starting simulation: Nx={Nx}, Ny={Ny}, dt={dt}, T={T}")
    print(f"Physical parameters: p={params.p}, α={params.alpha}, β={params.beta}")

    # --- Build grid ---
    x = jnp.linspace(0, Lx, Nx, endpoint=False)
    y = jnp.linspace(0, Ly, Ny, endpoint=False)
    X, Y = jnp.meshgrid(x, y, indexing="ij")

    # --- Initial condition ---
    kxx = 2 * jnp.pi / Lx
    kyy = 2 * jnp.pi / Ly
    u0 = A * jnp.cos(kxx * X) * jnp.cos(kyy * Y)
    v0 = v0_const * jnp.ones_like(u0)

    # --- Run simulation ---
    u_hat, v_hat, hist, stop_iter = simulate_2d(
        u0, v0, Lx, Ly, Nx, Ny, dt, nt, params,
        dealias_frac=2/3,
        record_every=record_every
    )

    # Numerical blow-up / stopping time
    t_blow = stop_iter * dt

    uT = to_phys(u_hat)

    # --- Collect stats ---
    stats = {
        "u_min": float(uT.min().item()),
        "u_max": float(uT.max().item()),
        "u_mean": float(uT.mean().item()),
        "nt": nt,
        "Nx": Nx, "Ny": Ny,
        "dt": float(dt),
        "T": float(T),
        "Lx": float(Lx),
        "Ly": float(Ly),
        "p": float(params.p),
        "alpha": float(params.alpha),
        "beta": float(params.beta),
        "eps": float(params.eps),
        "lam": float(params.lam),
        "gam": float(params.gam),
        "record_every": int(record_every),
        "stop_iter": int(stop_iter),
        "t_blow": float(t_blow),
    }

    print("Simulation completed! Final field statistics:")
    for k, v in stats.items():
        if isinstance(v, float):
            print(f"{k}: {v:.6g}")
        else:
            print(f"{k}: {v}")

    print(f"Detected numerical blow-up (cutoff) at t ≈ {t_blow:.6g}, step = {stop_iter}")

    return uT, hist, stats


def visualize_blowup(hist, Lx, Ly, T, stats,
                     cmap='RdBu_r',
                     prefix='blowup',
                     save_dir="results"):

    hist = np.array(hist)

    # Grid construction
    Nx, Ny = hist[0].shape
    x = np.linspace(0, Lx, Nx, endpoint=False)
    y = np.linspace(0, Ly, Ny, endpoint=False)
    X, Y = np.meshgrid(x, y, indexing='ij')

    # Time array based on dt and record_every
    n_frames = len(hist)
    dt = stats["dt"]
    nt = stats["nt"]
    record_every = stats.get("record_every", max(1, int(nt * stats.get("record_ratio", 0.01))))

    times = np.arange(n_frames) * dt * record_every

    # Truncate if solver stopped early
    if "t_blow" in stats:
        t_blow = stats["t_blow"]
        idx_valid = np.where(times <= t_blow + 1e-12)[0]
        if len(idx_valid) > 0:
            last_idx = idx_valid[-1]
        else:
            last_idx = n_frames - 1
        hist = hist[:last_idx + 1]
        times = times[:last_idx + 1]
        n_frames = len(hist)
        print(f"Truncated history to {n_frames} frames up to t ≈ {t_blow:.6f}")
    else:
        t_blow = T

    # Compute diagnostics
    max_vals = np.array([np.max(np.abs(f)) for f in hist])

    # Proper L2 norm: sqrt( ∫ u^2 dx dy )
    dx = Lx / Nx
    dy = Ly / Ny
    l2_vals = np.array([np.sqrt(np.sum(f**2) * dx * dy) for f in hist])

    # Find nearest index to blow-up time
    idx_near = int(np.argmin(np.abs(times - t_blow)))
    blow_time = times[idx_near]

    print(f"Blow-up time used in plot: t = {blow_time:.6f} (frame {idx_near})")

    # ============================================================
    # 1. Time evolution plot (max|u| and L2 norm)
    # ============================================================
    fig, ax = plt.subplots(figsize=(4, 6))

    ax.plot(times, max_vals, 'r-', label=r'$\max|u|$', linewidth=2.5)
    ax.plot(times, l2_vals, 'b--', label=r'$L_2 \text{ norm}$', linewidth=2)
    ax.axvline(blow_time, color='k', linestyle=':', label=f'Blow-up t={blow_time:.3f}')

    ax.set_yscale("log")
    y_min = max(max_vals.min() * 0.5, 1e-12)
    y_max = 1e10*5
    ax.set_ylim(y_min, y_max)

    # ------------------------
    # Apply fontproperties=prop
    # ------------------------
    ax.set_xlabel("Time", fontproperties=prop)
    ax.set_ylabel("Magnitude", fontproperties=prop)
    ax.set_title(
        f"Evolution (p={stats['p']}, α={stats['alpha']}, β={stats['beta']})",
        fontproperties=prop
    )

    # Axis tick fonts
    ax.tick_params(axis='both', labelsize=10)
    for label in ax.get_xticklabels():
        label.set_fontproperties(prop)
    for label in ax.get_yticklabels():
        label.set_fontproperties(prop)

    ax.grid(which="both", alpha=0.5)
    legend = ax.legend()
    for t in legend.get_texts():
        t.set_fontproperties(prop)

    plt.tight_layout()

    fname = os.path.join(save_dir, f"{prefix}_time_plot.png")
    plt.savefig(fname, dpi=300, bbox_inches='tight')
    plt.close()
    print("Saved:", fname)

    # ============================================================
    # 2. Initial field (t = 0)
    # ============================================================
    fig, ax = plt.subplots(figsize=(8, 6))
    f0 = hist[0]
    vmax = np.max(np.abs(f0))
    vmin = -vmax

    pcm = ax.pcolormesh(X, Y, f0, cmap=cmap, shading='auto', vmin=vmin, vmax=vmax)
    plt.colorbar(pcm)

    ax.contour(X, Y, f0, levels=12, colors='k', linewidths=0.6, alpha=0.5)

    ax.set_title("Initial state (t=0)", fontproperties=prop)
    ax.set_aspect("equal")

    # Tick font
    for label in ax.get_xticklabels():
        label.set_fontproperties(prop)
    for label in ax.get_yticklabels():
        label.set_fontproperties(prop)

    fname = os.path.join(save_dir, f"{prefix}_initial.png")
    plt.savefig(fname, dpi=300, bbox_inches='tight')
    plt.close()
    print("Saved:", fname)

    # ============================================================
    # 3. Field immediately before blow-up
    # ============================================================
    near_blow_idx = np.where(max_vals > 100 * max_vals[0])[0]

    if len(near_blow_idx) > 0:
        blow_frame = near_blow_idx[0]
        idx_pre = max(0, blow_frame - 1)
    else:
        idx_pre = n_frames // 2

    f_pre = hist[idx_pre]
    vmax = np.max(np.abs(f_pre))
    vmin = -vmax

    fig, ax = plt.subplots(figsize=(8, 6))

    pcm = ax.pcolormesh(X, Y, f_pre, shading='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(pcm)

    if vmax < 1e6:
        ax.contour(X, Y, f_pre, levels=12, colors='k', linewidths=0.6, alpha=0.5)

    ax.set_title(f"Near blow-up (t={times[idx_pre]:.3f})", fontproperties=prop)
    ax.set_aspect("equal")

    for label in ax.get_xticklabels():
        label.set_fontproperties(prop)
    for label in ax.get_yticklabels():
        label.set_fontproperties(prop)

    fname = os.path.join(save_dir, f"{prefix}_pre_blowup.png")
    plt.savefig(fname, dpi=300, bbox_inches='tight')
    plt.close()
    print("Saved:", fname)

    return blow_time

In [3]:
# [p, alpha, beta]
param_groups = [
    # First four groups: fixed alpha=2.0, beta=3.0, varying p
    [2.9, 2.0, 3.0],  # Condition 2
    [3.0, 2.0, 3.0],  # Condition 2
    [3.5, 2.0, 3.0],  # Condition 2
    [3.9, 2.0, 3.0],  # Condition 1
    
    # Next two groups: fixed p=3.0, beta=3.0, varying alpha
    [3.0, 1.5, 3.0],  # Condition 2
    [3.0, 2.5, 3.0],  # Condition 2
    
    # Last two groups: fixed p=3.0, alpha=2.0, varying beta
    [3.0, 2.0, 2.1],  # Condition 1
    [3.0, 2.0, 2.5]   # Condition 2
]

for group in param_groups:
    p_val, alpha_val, beta_val = group
    
    print(f"\n====== Running group: p={p_val}, alpha={alpha_val}, beta={beta_val} ======")

    params = Params(
        Lx=8*jnp.pi, 
        Ly=8*jnp.pi,
        Nx=256,
        Ny=256,
        dt=1e-3, 
        T=10,
        A=1.2,
        v0_const=1e-3,
        p=p_val,
        alpha=alpha_val,
        beta=beta_val,
        lam=2.5,
        gam=2.5,
        eps=1e-10,
        record_ratio=1e-4
    )

    uT, hist, stats = solve_pde(params)

    save_dir = rf"/home/aistudio/PDE(25.11.5)/result/p_{p_val}_alpha_{alpha_val}_beta_{beta_val}"
    os.makedirs(save_dir, exist_ok=True)

    blow_t = visualize_blowup(
        hist,
        stats['Lx'],
        stats['Ly'],
        stats['T'],
        stats,
        save_dir=save_dir
    )

    print(f"=== Finished group p={p_val}, α={alpha_val}, β={beta_val}, blow_t={blow_t} ===")


Starting simulation: Nx=256, Ny=256, dt=0.001, T=10
Physical parameters: p=2.9, α=2.0, β=3.0
Simulation completed! Final field statistics:
u_min: -9.22256e+86
u_max: 1.75661e+88
u_mean: 1.70644e+85
nt: 10000
Nx: 256
Ny: 256
dt: 0.001
T: 10
Lx: 25.1327
Ly: 25.1327
p: 2.9
alpha: 2
beta: 3
eps: 1e-10
lam: 2.5
gam: 2.5
record_every: 1
stop_iter: 1606
t_blow: 1.606
Detected numerical blow-up (cutoff) at t ≈ 1.606, step = 1606
Truncated history to 1607 frames up to t ≈ 1.606000
Blow-up time used in plot: t = 1.606000 (frame 1606)
Saved: /home/aistudio/PDE(25.11.5)/result/p_2.9_alpha_2.0_beta_3.0/blowup_time_plot.png
Saved: /home/aistudio/PDE(25.11.5)/result/p_2.9_alpha_2.0_beta_3.0/blowup_initial.png
Saved: /home/aistudio/PDE(25.11.5)/result/p_2.9_alpha_2.0_beta_3.0/blowup_pre_blowup.png
=== Finished group p=2.9, α=2.0, β=3.0, blow_t=1.606 ===

Starting simulation: Nx=256, Ny=256, dt=0.001, T=10
Physical parameters: p=3.0, α=2.0, β=3.0
Simulation completed! 