In [41]:
# ======================================================================
#  - Advanced = Arakawa + FFT Poisson
#  - Naive    = Central-diff Jacobian + Jacobi Poisson
#  - CoSTA    = Naive + learned σ1 (Eq.1) and σ2 (Eq.2)
#  - Residual targets computed with NAIVE diagnostic operators
#  - Stability guard on σ1, σ2
#  - Final animation: (Advanced | Naive | Naive+CoSTA)
# ======================================================================

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
from scipy.fft import fft2, ifft2, fftfreq

In [42]:
# Try scikit-learn; otherwise fallback to tiny NumPy MLP
USE_SKLEARN = True
try:
    from sklearn.neural_network import MLPRegressor
except Exception:
    USE_SKLEARN = False

np.seterr(all='warn')
plt.rcParams['figure.dpi'] = 120
DTYPE = np.float64

# -----------------------------
# Parameters (you can tweak)
# -----------------------------
NX = NY = 64         # grid (interior = NX+1 by NY+1 because of ghost layout)
RE = 560.0
NU = 1.0/RE
DT = 0.01

TRAIN_CASES  = 1     # number of randomized ICs for training data
TRAIN_STEPS  = 10    # steps per IC to build samples
ROLL_STEPS   = 300   # final rollout steps for animation
JITTER       = 0.0  # small IC jitter during training to diversify data
SEED         = 7

# Stability guard knobs
ANNEAL_STEPS       = 50   # ramp in sigma over these many steps
ALPHA_REL          = 1.0  # cap |σ1| per-step ≤ ALPHA_REL * |ν ∇²ω| / dt
S2_SMOOTH_PASSES   = 1    # blur σ2 (avoid grid noise)
SNAP_EVERY         = 20   # store every k-th frame for animation

In [43]:
def bc(nx,ny,u):
    u[:,0] = u[:,ny]
    u[:,ny+2] = u[:,2]
    u[0,:] = u[nx,:]
    u[nx+2,:] = u[2,:]
    return u

def make_xy(nx, ny):
    x = np.linspace(0.0, 2*np.pi, nx+1)
    y = np.linspace(0.0, 2*np.pi, ny+1)
    return np.meshgrid(x, y, indexing='ij')

def vm_ic(nx,ny,x,y):
    # two Gaussians (exact IC used in both notebooks)
    w = np.empty((nx+3,ny+3))
    sigma = np.pi
    xc1, yc1 = np.pi-np.pi/4.0, np.pi
    xc2, yc2 = np.pi+np.pi/4.0, np.pi
    w[1:nx+2, 1:ny+2] = np.exp(-sigma*((x[0:nx+1, 0:ny+1]-xc1)**2 + (y[0:nx+1, 0:ny+1]-yc1)**2)) \
                       + np.exp(-sigma*((x[0:nx+1, 0:ny+1]-xc2)**2 + (y[0:nx+1, 0:ny+1]-yc2)**2))
    return bc(nx,ny,w)

In [44]:
# =============================
# Time stepping (SSP-RK3)
# =============================
def rk3_step(w, rhs_fun, nx, ny, dx, dy, Re, dt, x, y, t):
    r1 = rhs_fun(nx, ny, dx, dy, Re, w, None, x, y, t) if rhs_fun is rhs_naive else rhs_fun(nx,ny,dx,dy,Re,w,None,x,y,t)
    # above line calls with s=None to get shape; we'll pass real s via wrappers
    raise RuntimeError("Use wrappers rk3_advanced / rk3_naive for clarity.")

### Advanced solver

In [45]:
# ---- Advanced Poisson (FFT)
def poisson_fft(nx, ny, dx, dy, rhs_full):
    rhs = rhs_full[1:nx+2, 1:ny+2]
    f_hat = fft2(rhs)
    kx = fftfreq(nx+1, d=dx/(2*np.pi))
    ky = fftfreq(ny+1, d=dy/(2*np.pi))
    kx, ky = np.meshgrid(kx, ky, indexing='ij')
    ksq = kx**2 + ky**2
    ksq[0,0] = 1.0           # avoid divide by zero
    psi_hat = -f_hat/ksq     # (∇²ψ = rhs)  => ψ̂ = rhŝ/(-k²); with rhs = -ω we already pass rhs_full
    psi_hat[0,0] = 0.0       # zero-mean ψ
    ut = np.real(ifft2(psi_hat))
    u = np.empty((nx+3, ny+3))
    u[1:nx+2, 1:ny+2] = ut
    u[:, ny+1] = u[:, 1]
    u[nx+1, :] = u[1, :]
    return bc(nx, ny, u)

# ---- Advanced RHS (Arakawa)
def rhs_adv(nx,ny,dx,dy,re,w,s,x,y,ts):
    aa = 1.0/(dx*dx)
    bb = 1.0/(dy*dy)
    gg = 1.0/(4.0*dx*dy)
    hh = 1.0/3.0
    f = np.zeros((nx+3,ny+3))
    j1 = gg*((w[2:nx+3,1:ny+2]-w[0:nx+1,1:ny+2])*(s[1:nx+2,2:ny+3]-s[1:nx+2,0:ny+1]) \
             -(w[1:nx+2,2:ny+3]-w[1:nx+2,0:ny+1])*(s[2:nx+3,1:ny+2]-s[0:nx+1,1:ny+2]))
    j2 = gg*( w[2:nx+3,1:ny+2]*(s[2:nx+3,2:ny+3]-s[2:nx+3,0:ny+1]) \
            - w[0:nx+1,1:ny+2]*(s[0:nx+1,2:ny+3]-s[0:nx+1,0:ny+1]) \
            - w[1:nx+2,2:ny+3]*(s[2:nx+3,2:ny+3]-s[0:nx+1,2:ny+3]) \
            + w[1:nx+2,0:ny+1]*(s[2:nx+3,0:ny+1]-s[0:nx+1,0:ny+1]))
    j3 = gg*( w[2:nx+3,2:ny+3]*(s[1:nx+2,2:ny+3]-s[2:nx+3,1:ny+2]) \
            - w[0:nx+1,0:ny+1]*(s[0:nx+1,1:ny+2]-s[1:nx+2,0:ny+1]) \
            - w[0:nx+1,2:ny+3]*(s[1:nx+2,2:ny+3]-s[0:nx+1,1:ny+2]) \
            + w[2:nx+3,0:ny+1]*(s[2:nx+3,1:ny+2]-s[1:nx+2,0:ny+1]) )
    jac = (j1+j2+j3)*hh
    lap = aa*(w[2:nx+3,1:ny+2]-2.0*w[1:nx+2,1:ny+2]+w[0:nx+1,1:ny+2]) \
        + bb*(w[1:nx+2,2:ny+3]-2.0*w[1:nx+2,1:ny+2]+w[1:nx+2,0:ny+1])
    f[1:nx+2,1:ny+2] = -jac + lap/re 
    return f

def rk3_advanced(w, nx, ny, dx, dy, Re, dt, x, y, t):
    # Eq(2): ψ from FFT with rhs = -ω
    s = poisson_fft(nx, ny, dx, dy, -w)
    f1 = rhs_adv(nx, ny, dx, dy, Re, w, s, x, y, t)
    u  = bc(nx,ny, w + dt*f1)

    s = poisson_fft(nx, ny, dx, dy, -u)
    f2 = rhs_adv(nx, ny, dx, dy, Re, u, s, x, y, t+dt)
    v  = bc(nx,ny, 0.75*w + 0.25*(u + dt*f2))

    s = poisson_fft(nx, ny, dx, dy, -v)
    f3 = rhs_adv(nx, ny, dx, dy, Re, v, s, x, y, t+0.5*dt)
    w_new = bc(nx,ny, (1.0/3.0)*w + (2.0/3.0)*(v + dt*f3))
    return w_new

### Naive solver

In [46]:
# ---- Naive Poisson (Jacobi)
def poisson_jacobi(nx, ny, dx, dy, w):
    tol=1e-6
    max_iter=10000
    psi = np.zeros_like(w)
    rhs = -w
    dx2, dy2 = dx*dx, dy*dy
    denom = 2.0*(dx2 + dy2)
    for it in range(max_iter):
        psi_old = psi.copy()
        psi[1:nx+2,1:ny+2] = ((psi_old[2:nx+3,1:ny+2] + psi_old[0:nx+1,1:ny+2]) * dy2 +
                              (psi_old[1:nx+2,2:ny+3] + psi_old[1:nx+2,0:ny+1]) * dx2 -
                              rhs[1:nx+2,1:ny+2] * dx2 * dy2) / denom
        psi = bc(nx, ny, psi)
        err = np.linalg.norm(psi - psi_old) / (nx*ny)
        if err < tol:
            break
    return psi

# ---- Naive RHS (central-difference Jacobian)
def rhs_naive(nx, ny, dx, dy, Re, w, s, x, y, t):
    inv_dx = 1.0 / (2.0*dx)
    inv_dy = 1.0 / (2.0*dy)
    dpsi_dx   = (s[2:nx+3,1:ny+2] - s[0:nx+1,1:ny+2]) * inv_dx
    dpsi_dy   = (s[1:nx+2,2:ny+3] - s[1:nx+2,0:ny+1]) * inv_dy
    domega_dx = (w[2:nx+3,1:ny+2] - w[0:nx+1,1:ny+2]) * inv_dx
    domega_dy = (w[1:nx+2,2:ny+3] - w[1:nx+2,0:ny+1]) * inv_dy
    jac = dpsi_dx * domega_dy - dpsi_dy * domega_dx
    lap = ((w[2:nx+3,1:ny+2] - 2.0*w[1:nx+2,1:ny+2] + w[0:nx+1,1:ny+2]) / dx**2
         + (w[1:nx+2,2:ny+3] - 2.0*w[1:nx+2,1:ny+2] + w[1:nx+2,0:ny+1]) / dy**2)
    f = np.zeros_like(w)
    f[1:nx+2,1:ny+2] = -jac + lap/Re 
    return f

def rk3_naive(w, nx, ny, dx, dy, Re, dt, x, y, t):
    s = poisson_jacobi(nx, ny, dx, dy, -w)
    f1 = rhs_naive(nx, ny, dx, dy, Re, w, s, x, y, t)
    u  = bc(nx,ny, w + dt*f1)

    s = poisson_jacobi(nx, ny, dx, dy, -u)
    f2 = rhs_naive(nx, ny, dx, dy, Re, u, s, x, y, t+dt)
    v  = bc(nx,ny, 0.75*w + 0.25*(u + dt*f2))

    s = poisson_jacobi(nx, ny, dx, dy, -v)
    f3 = rhs_naive(nx, ny, dx, dy, Re, v, s, x, y, t+0.5*dt)
    w_new = bc(nx,ny, (1.0/3.0)*w + (2.0/3.0)*(v + dt*f3))
    return w_new

# Corrected (Naive + σ): add σ1 to Eq.1 RHS, and solve Poisson with rhs = -ω + σ2 (Jacobi)
def rk3_naive_with_sigma(w, nx, ny, dx, dy, Re, dt, x, y, t, sigma1_full, sigma2_full):
    s = poisson_jacobi(nx, ny, dx, dy, -w + sigma2_full)  # since ∇²ψ = -ω + σ2  -> rhs = -(ω - σ2)
    f1 = rhs_naive(nx, ny, dx, dy, Re, w, s, x, y, t) + sigma1_full
    u  = bc(nx,ny, w + dt*f1)

    s = poisson_jacobi(nx, ny, dx, dy, -u + sigma2_full)
    f2 = rhs_naive(nx, ny, dx, dy, Re, u, s, x, y, t+dt) + sigma1_full
    v  = bc(nx,ny, 0.75*w + 0.25*(u + dt*f2))

    s = poisson_jacobi(nx, ny, dx, dy, -v + sigma2_full)
    f3 = rhs_naive(nx, ny, dx, dy, Re, v, s, x, y, t+0.5*dt) + sigma1_full
    w_new = bc(nx,ny, (1.0/3.0)*w + (2.0/3.0)*(v + dt*f3))
    return w_new

# =============================
# Diagnostic residuals (NAIVE ops) for training targets
# =============================
def laplacian_cd_full(nx, ny, dx, dy, a):
    lap = ((a[2:nx+3,1:ny+2] - 2.0*a[1:nx+2,1:ny+2] + a[0:nx+1,1:ny+2]) / dx**2
         + (a[1:nx+2,2:ny+3] - 2.0*a[1:nx+2,1:ny+2] + a[1:nx+2,0:ny+1]) / dy**2)
    out = np.zeros_like(a)
    out[1:nx+2,1:ny+2] = lap
    return out

def jacobian_cd_full(nx, ny, dx, dy, psi, w):
    inv_dx = 1.0/(2*dx); inv_dy = 1.0/(2*dy)
    dpsi_dx   = (psi[2:nx+3,1:ny+2] - psi[0:nx+1,1:ny+2]) * inv_dx
    dpsi_dy   = (psi[1:nx+2,2:ny+3] - psi[1:nx+2,0:ny+1]) * inv_dy
    domega_dx = (w[2:nx+3,1:ny+2]   - w[0:nx+1,1:ny+2])   * inv_dx
    domega_dy = (w[1:nx+2,2:ny+3]   - w[1:nx+2,0:ny+1])   * inv_dy
    jac = dpsi_dx*domega_dy - dpsi_dy*domega_dx
    out = np.zeros_like(w); out[1:nx+2,1:ny+2] = jac
    return out

def residuals_naive_diag(nx, ny, dx, dy, nu, w_n, w_np1, dt):
    # ψ^n via Naive Poisson
    psi_n = poisson_jacobi(nx, ny, dx, dy, w_n)
    J_cd  = jacobian_cd_full(nx, ny, dx, dy, psi_n, w_n)
    Lapw  = laplacian_cd_full(nx, ny, dx, dy, w_n)
    R1 = (w_np1 - w_n)/dt + J_cd - nu*Lapw              
    Lappsi = laplacian_cd_full(nx, ny, dx, dy, psi_n)
    R2 = Lappsi + w_n                                    
    return R1, R2, psi_n

In [47]:
# =============================
# Build dataset 
# =============================
def build_dataset(nx=64, ny=64, Re=560.0, dt=0.01, n_cases=4, steps=10, jitter=0.1, seed=7):
    rng = np.random.RandomState(seed)
    nu = 1.0/Re
    dx = (2*np.pi)/nx; dy = (2*np.pi)/ny
    x, y = make_xy(nx, ny)
    Xs = []; Ys = []

    for c in range(n_cases):
        wA = vm_ic(nx,ny,x,y); wN = wA.copy()
        t = 0.0
        for s in range(steps):
            # One step Advanced & Naive (exact solvers)
            wA_next = rk3_advanced(wA, nx, ny, dx, dy, Re, dt, x, y, t)
            wN_next = rk3_naive(    wN, nx, ny, dx, dy, Re, dt, x, y, t)

            # Residuals with NAIVE operators at time n
            R1_N, R2_N, psiN = residuals_naive_diag(nx,ny,dx,dy,nu,wN,wN_next,dt)
            R1_A, R2_A, _    = residuals_naive_diag(nx,ny,dx,dy,nu,wA,wA_next,dt)

            sigma1 = R1_N - R1_A 
            sigma2 = R2_N - R2_A

            # 8-neighbor stencil features (from NAIVE fields at time n)
            wm = wN[0:nx+1, 1:ny+2];  wp = wN[2:nx+3, 1:ny+2]
            wn = wN[1:nx+2, 2:ny+3];  ws = wN[1:nx+2, 0:ny+1]
            pm = psiN[0:nx+1, 1:ny+2]; pp = psiN[2:nx+3, 1:ny+2]
            pn = psiN[1:nx+2, 2:ny+3]; ps = psiN[1:nx+2, 0:ny+1]

            s1I = sigma1[1:nx+2,1:ny+2]; s2I = sigma2[1:nx+2,1:ny+2]
            feats = np.stack([wm, wp, wn, ws, pm, pp, pn, ps], axis=-1)
            Xs.append(feats.reshape(-1, 8))
            Ys.append(np.stack([s1I, s2I], axis=-1).reshape(-1,2))

            wA, wN = wA_next, wN_next
            t += dt

    X = np.concatenate(Xs, axis=0).astype(DTYPE)
    Y = np.concatenate(Ys, axis=0).astype(DTYPE)
    # Normalize
    X_mu = X.mean(axis=0, keepdims=True); X_sd = X.std(axis=0, keepdims=True) + 1e-8
    Y_mu = Y.mean(axis=0, keepdims=True); Y_sd = Y.std(axis=0, keepdims=True) + 1e-8
    Xn = (X - X_mu)/X_sd
    Yn = (Y - Y_mu)/Y_sd
    return (X, Y, Xn, Yn, X_mu, X_sd, Y_mu, Y_sd), (nx,ny,dx,dy,nu,dt,x,y)

# =============================
# Tiny MLP (fallback)
# =============================
class TinyMLP:
    def __init__(self, in_dim=8, hidden=96, out_dim=2, lr=3e-3, seed=0):
        rng = np.random.RandomState(seed)
        k1 = np.sqrt(2/in_dim); k2 = np.sqrt(2/hidden)
        self.W1 = rng.randn(in_dim, hidden)*k1; self.b1 = np.zeros((1, hidden))
        self.W2 = rng.randn(hidden, out_dim)*k2; self.b2 = np.zeros((1, out_dim))
        self.lr = lr
    @staticmethod
    def relu(z):  return np.maximum(0.0, z)
    @staticmethod
    def drelu(z): return (z>0.0).astype(z.dtype)
    def fit(self, X, Y, epochs=15, batch=8192, verbose=True):
        N = X.shape[0]
        for ep in range(epochs):
            idx = np.random.permutation(N); loss = 0.0
            for k in range(0, N, batch):
                j = idx[k:k+batch]; x = X[j]; y = Y[j]
                z1 = x@self.W1 + self.b1; h1 = self.relu(z1); yhat = h1@self.W2 + self.b2
                e = yhat - y; loss += float((e*e).mean())
                dy = 2.0*e/len(j); dW2 = h1.T@dy; db2 = dy.sum(axis=0, keepdims=True)
                dh1 = dy@self.W2.T; dz1 = dh1*self.drelu(z1)
                dW1 = x.T@dz1; db1 = dz1.sum(axis=0, keepdims=True)
                self.W2 -= self.lr*dW2; self.b2 -= self.lr*db2
                self.W1 -= self.lr*dW1; self.b1 -= self.lr*db1
            if verbose: print(f"[TinyMLP] epoch {ep+1}: MSE={loss*batch/N:.4e}")
    def predict(self, X):
        z1 = X@self.W1 + self.b1; h1 = self.relu(z1); return h1@self.W2 + self.b2

In [48]:
# =============================
# Train (Steps 1–11)
# =============================
(X_raw, Y_raw, Xn, Yn, X_mu, X_sd, Y_mu, Y_sd), meta = build_dataset(
    nx=NX, ny=NY, Re=RE, dt=DT, n_cases=TRAIN_CASES, steps=TRAIN_STEPS, jitter=JITTER, seed=SEED
)
print("Dataset:", Xn.shape, Yn.shape)

if USE_SKLEARN:
    print("Training scikit-learn MLPRegressor …")
    nn = MLPRegressor(hidden_layer_sizes=(128,64), activation='relu',
                      solver='adam', learning_rate_init=1e-3, max_iter=60,
                      batch_size=4096, random_state=0, verbose=True)
    nn.fit(Xn, Yn)
    def predict_sigma_batch(X_in):
        Xn_in = (X_in - X_mu)/X_sd
        Yn_hat = nn.predict(Xn_in)
        return Yn_hat*Y_sd + Y_mu
else:
    print("Training TinyMLP (NumPy) …")
    nn = TinyMLP(in_dim=8, hidden=96, out_dim=2, lr=3e-3, seed=0)
    nn.fit(Xn, Yn, epochs=15, batch=8192, verbose=True)
    def predict_sigma_batch(X_in):
        Xn_in = (X_in - X_mu)/X_sd
        Yn_hat = nn.predict(Xn_in)
        return Yn_hat*Y_sd + Y_mu

# =============================
# Stability guard (sanity checks for σ1, σ2)
# =============================
s1_cap_abs = float(np.quantile(np.abs(Y_raw[:,0]), 0.995))
s2_cap_abs = float(np.quantile(np.abs(Y_raw[:,1]), 0.995))
print(f"[Stab] training |σ1|_q99.5={s1_cap_abs:.3e}, |σ2|_q99.5={s2_cap_abs:.3e}")

def hard_cap_sigma(s_full, cap_abs, nx, ny):
    sI = s_full[1:nx+2,1:ny+2]
    sI = np.clip(sI, -cap_abs, cap_abs)
    out = np.zeros_like(s_full); out[1:nx+2,1:ny+2] = sI
    return bc(nx,ny,out)

def box_blur_full(a_full, nx, ny, passes=1):
    out = a_full.copy()
    for _ in range(passes):
        ain = out[1:nx+2,1:ny+2]
        nb = (out[0:nx+1,1:ny+2] + out[2:nx+3,1:ny+2] +
              out[1:nx+2,0:ny+1] + out[1:nx+2,2:ny+3] +
              out[0:nx+1,0:ny+1] + out[2:nx+3,0:ny+1] +
              out[0:nx+1,2:ny+3] + out[2:nx+3,2:ny+3])
        sm = (ain*1.0 + nb)/9.0
        out[1:nx+2,1:ny+2] = sm
        out = bc(nx,ny,out)
    return out

def dissipative_gate_sigma1(s1_full, w_full, nx, ny, dx, dy, nu, dt, alpha=1.0):
    lap_full = laplacian_cd_full(nx, ny, dx, dy, w_full)
    lapI = lap_full[1:nx+2,1:ny+2]
    s1I  = s1_full[1:nx+2,1:ny+2]
    # keep only same-sign as Laplacian (adds diffusion)
    s1I = np.where(lapI * s1I > 0.0, s1I, 0.0)
    # relative cap to physical diffusion per-step
    cap_rel = alpha * np.abs(nu * lapI) / (DT + 1e-12)
    s1I = np.clip(s1I, -cap_rel, cap_rel)
    out = np.zeros_like(s1_full); out[1:nx+2,1:ny+2] = s1I
    return bc(nx,ny,out)

def stabilize_sigma2(s2_full, nx, ny, cap_abs, smooth_passes=1):
    s2I = s2_full[1:nx+2,1:ny+2]
    s2I = s2I - s2I.mean()  # remove k=0 to avoid psi drift
    out = np.zeros_like(s2_full); out[1:nx+2,1:ny+2] = s2I
    out = bc(nx,ny,out)
    out = box_blur_full(out, nx, ny, passes=smooth_passes)
    out = hard_cap_sigma(out, cap_abs, nx, ny)
    return out

class SigmaGuard:
    def __init__(self, nx, ny, dx, dy, nu, dt, s1_cap_abs, s2_cap_abs,
                 alpha_rel=1.0, s2_smooth=1, anneal_steps=50):
        self.nx, self.ny, self.dx, self.dy = nx, ny, dx, dy
        self.nu, self.dt = nu, dt
        self.s1_cap_abs, self.s2_cap_abs = s1_cap_abs, s2_cap_abs
        self.alpha_rel = alpha_rel; self.s2_smooth = s2_smooth
        self.k = 0; self.anneal_steps = max(1, anneal_steps)
    def factor(self):
        return min(1.0, self.k / self.anneal_steps)
    def predict_safe(self, w_nav, psi_nav, predict_sigma_batch):
        nx, ny = self.nx, self.ny
        # 8-neighbor stencil
        wm = w_nav[0:nx+1, 1:ny+2];  wp = w_nav[2:nx+3, 1:ny+2]
        wn = w_nav[1:nx+2, 2:ny+3];  ws = w_nav[1:nx+2, 0:ny+1]
        pm = psi_nav[0:nx+1, 1:ny+2]; pp = psi_nav[2:nx+3, 1:ny+2]
        pn = psi_nav[1:nx+2, 2:ny+3]; ps = psi_nav[1:nx+2, 0:ny+1]
        feats = np.stack([wm, wp, wn, ws, pm, pp, pn, ps], axis=-1).reshape(-1, 8)
        sig12 = predict_sigma_batch(feats).reshape(nx+1, ny+1, 2)
        s1_full = np.zeros_like(w_nav); s2_full = np.zeros_like(w_nav)
        s1_full[1:nx+2,1:ny+2] = sig12[...,0]; s2_full[1:nx+2,1:ny+2] = sig12[...,1]
        s1_full = bc(nx,ny,s1_full); s2_full = bc(nx,ny,s2_full)
        # hard caps + s2 stabilize
        s1_full = hard_cap_sigma(s1_full, self.s1_cap_abs, nx, ny)
        s2_full = stabilize_sigma2(s2_full, nx, ny, self.s2_cap_abs, smooth_passes=self.s2_smooth)
        # dissipativity + relative cap
        s1_full = dissipative_gate_sigma1(s1_full, w_nav, nx, ny, self.dx, self.dy, self.nu, self.dt, alpha=self.alpha_rel)
        # anneal
        a = self.factor(); self.k += 1
        if a < 1.0: s1_full = a*s1_full; s2_full = a*s2_full
        return s1_full, s2_full

# =============================
# Rollout (Step 12) + Animation
# =============================
nx, ny = NX, NY
dx = (2*np.pi)/nx; dy = (2*np.pi)/ny
x, y = make_xy(nx, ny)

rng = np.random.RandomState(123)
w_adv = vm_ic(nx,ny,x,y)
w_nav = w_adv.copy()
w_cos = w_adv.copy()

guard = SigmaGuard(nx, ny, dx, dy, NU, DT,
                   s1_cap_abs=s1_cap_abs, s2_cap_abs=s2_cap_abs,
                   alpha_rel=ALPHA_REL, s2_smooth=S2_SMOOTH_PASSES,
                   anneal_steps=ANNEAL_STEPS)

adv_snaps, nav_snaps, cos_snaps = [], [], []


Dataset: (42250, 8) (42250, 2)
Training scikit-learn MLPRegressor …
Iteration 1, loss = 0.26574608
Iteration 2, loss = 0.25184466
Iteration 3, loss = 0.24907127
Iteration 4, loss = 0.24779596
Iteration 5, loss = 0.24680953
Iteration 6, loss = 0.24622369
Iteration 7, loss = 0.24458332
Iteration 8, loss = 0.24304102
Iteration 9, loss = 0.24156472
Iteration 10, loss = 0.23964198
Iteration 11, loss = 0.23735733
Iteration 12, loss = 0.23496662
Iteration 13, loss = 0.23200234
Iteration 14, loss = 0.22883370
Iteration 15, loss = 0.22545463
Iteration 16, loss = 0.22167980
Iteration 17, loss = 0.21832157
Iteration 18, loss = 0.21489365
Iteration 19, loss = 0.21113459
Iteration 20, loss = 0.20762700
Iteration 21, loss = 0.20400301
Iteration 22, loss = 0.20078437
Iteration 23, loss = 0.19732376
Iteration 24, loss = 0.19492836
Iteration 25, loss = 0.19179300
Iteration 26, loss = 0.18844173
Iteration 27, loss = 0.18554985
Iteration 28, loss = 0.18330917
Iteration 29, loss = 0.17872648
Iteration 30,



In [49]:
t = 0.0
for k in range(ROLL_STEPS):
    # Advanced (FFT + Arakawa)
    w_adv = rk3_advanced(w_adv, nx, ny, dx, dy, RE, DT, x, y, t)
    t += DT

    if k % SNAP_EVERY == 0:
        adv_snaps.append(w_adv[1:nx+2,1:ny+2].copy())

In [50]:
t = 0.0
for k in range(ROLL_STEPS):
    # Naive (Jacobi + central Jacobian)
    w_nav_next = rk3_naive(w_nav, nx, ny, dx, dy, RE, DT, x, y, t)
    # Build ψ^n (naive) ONCE for features
    psi_nav_n = poisson_jacobi(nx, ny, dx, dy, w_nav)
    w_nav = w_nav_next
    t += DT

    if k % SNAP_EVERY == 0:
        nav_snaps.append(w_nav[1:nx+2,1:ny+2].copy())

In [53]:
t = 0.0
for k in range(ROLL_STEPS):
    # Predict σ safely from (ω_naive^n, ψ_naive^n)
    sigma1_full, sigma2_full = guard.predict_safe(w_nav, psi_nav_n, predict_sigma_batch)

    # Corrected step (still Naive solver, adds σ1 and uses -ω+σ2 in Poisson)
    w_cos = rk3_naive_with_sigma(w_cos, nx, ny, dx, dy, RE, DT, x, y, t, sigma1_full, sigma2_full)

    w_nav = w_nav_next
    t += DT

    if k % SNAP_EVERY == 0:
        cos_snaps.append(w_cos[1:nx+2,1:ny+2].copy())

In [55]:
 nav_snaps=cos_snaps

In [56]:
# -----------------------------
# Triple animation (Advanced | Naive | Naive+CoSTA)
# -----------------------------
def animate_triptych(adv_snaps, nav_snaps, cos_snaps, interval=40, cmap='jet'):
    T = len(adv_snaps)
    assert T>0 and T==len(nav_snaps)==len(cos_snaps)
    vmin = min(np.min(s) for s in adv_snaps+nav_snaps+cos_snaps)
    vmax = max(np.max(s) for s in adv_snaps+nav_snaps+cos_snaps)
    fig, axs = plt.subplots(1, 3, figsize=(11, 3.6), constrained_layout=True)
    titles = ["Advanced (FFT+Arakawa)", "Naive (Jacobi+CD)", "Naive + CoSTA (Jacobi+CD)"]
    ims = []
    for ax, snap, ttl in zip(axs, [adv_snaps[0], nav_snaps[0], cos_snaps[0]], titles):
        im = ax.imshow(snap.T, origin='lower', extent=[0,2*np.pi,0,2*np.pi],
                       vmin=vmin, vmax=vmax, cmap=cmap, animated=True)
        ax.set_title(ttl); ims.append(im)
    cbar = fig.colorbar(ims[0], ax=axs, shrink=0.9, pad=0.02); cbar.set_label(r'$\omega$')

    def update(i):
        ims[0].set_array(adv_snaps[i].T)
        ims[1].set_array(nav_snaps[i].T)
        ims[2].set_array(cos_snaps[i].T)
        return ims

    ani = FuncAnimation(fig, update, frames=T, interval=interval, blit=True)
    plt.close(fig)
    return HTML(ani.to_jshtml())

display(animate_triptych(adv_snaps, nav_snaps, cos_snaps, interval=35))
