<a href="https://colab.research.google.com/github/Misha-private/Demo-repo/blob/main/Shallow_WaterModel_KleinBeta_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# gen_klein_ics_cgrid_kelvin.py
# Build a Klein-β IC bundle with C-grid winds (u: ny×(nx+1), v: (ny+1)×nx) and h at centers.
# Extended set (10 ICs): RH/modons (your originals) + Kelvin and Kelvin-mixed cases.
# Also saves labeled plots for each IC.

import os
import numpy as np
import matplotlib.pyplot as plt

# ---------------- Paths ----------------
OUT_BUNDLE = "/content/drive/MyDrive/klein_ics/bundle_cgrid.npz"      # output .npz bundle (all ICs + grid metadata)
OUT_PLOTS_DIR = "/content/drive/MyDrive/klein_ics/plots"              # one <key>_maps.png per IC goes here

# --------------- Phys / grid ----------------
g     = 9.81                  # gravitational acceleration (m/s^2)
H     = 1000.0                # rest depth (m); stored total depth is h = H + eta
fp    = 8.0e-5                # Coriolis amplitude: f(y) = fp*sin(phi) on the Klein plane
nx    = 256                   # cells in x (periodic direction: wrap/roll stencils below)
ny    = 128                   # cells in y (Klein-twist parity imposed on first/last rows)
Lx    = 2.0e7                 # domain length in x (m)
Ly    = 8.0e6                 # domain length in y (m)
dx    = Lx/nx
dy    = Ly/ny

# ---------- Klein geometry ----------
# phi maps y in [0, Ly] linearly onto [-pi/2, pi/2], so f changes sign at y = Ly/2.
x_c = np.linspace(0.5*dx, Lx-0.5*dx, nx)               # cell-centre x coordinates
y_c = np.linspace(0.5*dy, Ly-0.5*dy, ny)               # cell-centre y coordinates
Xc, Yc = np.meshgrid(x_c, y_c)                          # (ny,nx)
phi_c  = np.pi*((Yc/Ly) - 0.5)
f_c    = fp * np.sin(phi_c)                             # (ny,nx)

# Face grids: not used elsewhere in this script; kept so the IC geometry mirrors the model grid.
x_u = np.linspace(0.0, Lx, nx+1)  # u on (ny,nx+1); x=0 and x=Lx are the same periodic face
y_v = np.linspace(0.0, Ly, ny+1)  # v on (ny+1,nx)
Xu, Yu = np.meshgrid(x_u, y_c)
Xv, Yv = np.meshgrid(x_c, y_v)
phi_u = np.pi*((Yu/Ly) - 0.5)
phi_v = np.pi*((Yv/Ly) - 0.5)
f_u   = fp * np.sin(phi_u)
f_v   = fp * np.sin(phi_v)

# ---------- Helpers ----------
def ensure_dir(p):
    """Create directory *p* (with any missing parents); a no-op if it already exists."""
    if not os.path.isdir(p):
        os.makedirs(p, exist_ok=True)

HMIN = 0.5  # meters; minimum thickness enforced by the floor below (small, so it rarely triggers)

def enforce_floor_ke_preserving(u, v, h, hmin=HMIN):
    """
    Floor thickness at ``hmin`` while keeping the local kinetic energy unchanged.

    Cells with ``h < hmin`` are raised to ``hmin``. Because KE ~ (1/2) h |u|^2,
    velocities there are multiplied by s = sqrt(h_old / hmin), with h_old
    clipped below at 1e-12. The cell-centre factor (float32) reaches the C-grid
    faces by averaging the two neighbouring cells: wrap-around in x for u,
    edge-replicated rows in y for v.

    Returns (u, v, h, n_hit). n_hit is the number of floored cells. When no
    cell is below ``hmin``, the inputs are returned unchanged with n_hit = 0.
    """
    below = h < hmin
    n_hit = int(below.sum())
    if n_hit == 0:
        return u, v, h, 0

    scale_c = np.ones_like(h, dtype=np.float32)
    scale_c[below] = np.sqrt(np.maximum(h[below], 1e-12) / hmin)

    # west/east neighbours for x-faces (periodic), south/north for y-faces (edge)
    scale_u = 0.5*(np.pad(scale_c, ((0, 0), (1, 0)), mode='wrap')
                   + np.pad(scale_c, ((0, 0), (0, 1)), mode='wrap'))
    scale_v = 0.5*(np.pad(scale_c, ((1, 0), (0, 0)), mode='edge')
                   + np.pad(scale_c, ((0, 1), (0, 0)), mode='edge'))

    return u * scale_u, v * scale_v, np.maximum(h, hmin), n_hit


# ---------- C-grid helpers ----------
def avg_x(a):
    """Average centre values onto x-faces: (ny, nx) -> (ny, nx+1), periodic in x."""
    west = np.pad(a, ((0, 0), (1, 0)), mode='wrap')   # west[:, j] = a[:, j-1]
    east = np.pad(a, ((0, 0), (0, 1)), mode='wrap')   # east[:, j] = a[:, j]
    return 0.5*(west + east)


def avg_y(a):
    """Average centre values onto y-faces: (ny, nx) -> (ny+1, nx); boundary rows replicated."""
    south = np.pad(a, ((1, 0), (0, 0)), mode='edge')  # south[i] = a[i-1]
    north = np.pad(a, ((0, 1), (0, 0)), mode='edge')  # north[i] = a[i]
    return 0.5*(south + north)

def ddx_center(phi):
    """Centred x-derivative at cell centres, (ny, nx); periodic in x."""
    east = np.roll(phi, -1, axis=1)
    west = np.roll(phi, 1, axis=1)
    return (east - west) / (2.0*dx)


def ddy_center(phi):
    """Centred y-derivative at cell centres, (ny, nx); boundary rows replicated (edge padding)."""
    padded = np.pad(phi, ((1, 1), (0, 0)), mode='edge')
    north, south = padded[2:, :], padded[:-2, :]
    return (north - south) / (2.0*dy)

# ---------- Klein twist boundary helpers ----------
def twist_reflect_x(arr):
    """Mirror *arr* along its last axis (index i -> n-1-i, i.e. x -> Lx - x); returns a view."""
    return arr[..., ::-1]


def enforce_klein_center(eta):
    """
    Impose even Klein-twist parity on the first and last rows of a centre field.

    Each boundary row becomes 0.5*(inner + mirror(inner)), where inner is the
    adjacent interior row. Row 0 is set before row -1. Modifies *eta* in place
    and returns it.
    """
    for edge, inner in ((0, 1), (-1, -2)):
        eta[edge, :] = 0.5*(eta[inner, :] + twist_reflect_x(eta[inner, :]))
    return eta


def enforce_klein_faces(u, v):
    """
    Impose Klein-twist parity on the boundary rows of face velocities, in place.

    u becomes odd: 0.5*(inner - mirror(inner)). v becomes even:
    0.5*(inner + mirror(inner)). Returns (u, v).
    """
    for edge, inner in ((0, 1), (-1, -2)):
        u[edge, :] = 0.5*(u[inner, :] - twist_reflect_x(u[inner, :]))
    for edge, inner in ((0, 1), (-1, -2)):
        v[edge, :] = 0.5*(v[inner, :] + twist_reflect_x(v[inner, :]))
    return u, v

def faces_to_centers(u_face, v_face):
    """
    Interpolate C-grid face velocities to cell centres (used for plotting).

    u_face (ny, nx+1) -> (ny, nx): mean of the west and east faces.
    v_face (ny+1, nx) -> (ny, nx): mean of the south and north faces.
    Returns (u_c, v_c).
    """
    west, east = u_face[:, :-1], u_face[:, 1:]
    south, north = v_face[:-1, :], v_face[1:, :]
    return 0.5*(west + east), 0.5*(south + north)

# ---------- Build eta ICs (original set) ----------
def rh_wave(m=4, amp=50.0, zonal_phase=0.0):
    """
    Rossby–Haurwitz-like height anomaly: amp * cos(kx*x + zonal_phase) * cos(phi).

    kx = 2*pi*m/Lx gives m zonal wavelengths across the domain. The cos(phi)
    factor tapers the field toward zero near the y-boundaries. Returns float32
    (ny, nx) with Klein centre parity applied to the boundary rows.
    """
    wavenumber = 2*np.pi * m / Lx
    zonal = np.cos(wavenumber*Xc + zonal_phase)
    field = amp * zonal * np.cos(phi_c)
    return enforce_klein_center(field.astype(np.float32))

def gaussian_modon(cx, cy, R=6e5, amp=80.0, sign=+1):
    """
    Gaussian height bump  sign * amp * exp(-r^2 / R^2)  centred at (cx, cy).

    The x-separation uses the minimum-image convention of the x-periodic domain.
    A bump placed near x = 0 or x = Lx therefore continues smoothly across the
    seam instead of being truncated. y is not periodic, so the plain distance is
    used there. Returns float32 (ny, nx) with Klein centre parity on the
    boundary rows.
    """
    sep_x = (Xc - cx + 0.5*Lx) % Lx - 0.5*Lx   # periodic separation in [-Lx/2, Lx/2)
    r2 = sep_x**2 + (Yc - cy)**2
    eta = sign * amp * np.exp(-r2/(R*R))
    return enforce_klein_center(eta.astype(np.float32))

def build_eta_set():
    """
    Return the seven original eta ICs as a dict of name -> float32 (ny, nx) array.

    Keys, in insertion order: rh4, rh5, modon, twin_modons, mixed_RH2_modon,
    mixed_RH3_modon, mixed_RH4_2mod.
    """
    # The opposite-signed modon pair is shared by "twin_modons" and "mixed_RH4_2mod".
    pos_modon = gaussian_modon(0.35*Lx, 0.45*Ly, R=5e5, amp=65.0, sign=+1)
    neg_modon = gaussian_modon(0.65*Lx, 0.55*Ly, R=5e5, amp=65.0, sign=-1)

    return {
        "rh4": rh_wave(m=4, amp=60.0, zonal_phase=0.0),
        "rh5": rh_wave(m=5, amp=45.0, zonal_phase=0.7),
        "modon": gaussian_modon(0.55*Lx, 0.55*Ly, R=5.5e5, amp=70.0, sign=+1),
        "twin_modons": (pos_modon + neg_modon).astype(np.float32),
        "mixed_RH2_modon": (rh_wave(m=2, amp=40.0)
                            + gaussian_modon(0.5*Lx, 0.6*Ly, R=6e5, amp=50)).astype(np.float32),
        "mixed_RH3_modon": (rh_wave(m=3, amp=35.0)
                            + gaussian_modon(0.6*Lx, 0.4*Ly, R=6e5, amp=45, sign=-1)).astype(np.float32),
        "mixed_RH4_2mod": (rh_wave(m=4, amp=30.0) + pos_modon + neg_modon).astype(np.float32),
    }

# ---------- Geostrophic winds (centers) -> map to faces ----------
def geo_winds_from_eta(eta, f_floor=1e-5):
    """
    Geostrophic winds at cell centres from a height anomaly:

        uc = -(g / f) d(eta)/dy,    vc = (g / f) d(eta)/dx

    Near the equator |f| is raised to ``f_floor``, keeping the sign of f, to
    tame the 1/f singularity. Cells where f is exactly zero are treated as
    positive so the floor never becomes 0. Boundary rows get Klein centre
    parity. Returns float32 (uc, vc).
    """
    deta_dx = ddx_center(eta)
    deta_dy = ddy_center(eta)
    # np.sign(0) == 0 would make the floored value 0 and give g/0 = inf on an
    # exact-equator row (e.g. when ny is odd); use +1 there instead.
    f_sign = np.where(f_c >= 0.0, 1.0, -1.0)
    f_reg = np.where(np.abs(f_c) < f_floor, f_sign*f_floor, f_c)
    uc = -(g / f_reg) * deta_dy
    vc = (g / f_reg) * deta_dx
    # enforce Klein center parity
    enforce_klein_center(uc)
    enforce_klein_center(vc)
    return uc.astype(np.float32), vc.astype(np.float32)

def centers_to_faces(uc, vc):
    """
    Interpolate centre winds to C-grid faces and apply Klein face parity.

    u, shape (ny, nx+1): interior faces average the two adjacent centres.
    Columns 0 (x = 0) and nx (x = Lx) are the same physical face of the
    x-periodic domain, so both receive the wrap-around average
    0.5*(uc[:, -1] + uc[:, 0]).
    v, shape (ny+1, nx): interior faces average the adjacent centre rows. The
    two y-boundary rows copy the nearest centre row.

    Returns float32 (u, v) after ``enforce_klein_faces``.
    """
    n_rows, n_cols = uc.shape
    u = np.zeros((n_rows, n_cols + 1), dtype=np.float32)
    v = np.zeros((vc.shape[0] + 1, vc.shape[1]), dtype=np.float32)

    u[:, 1:-1] = 0.5*(uc[:, :-1] + uc[:, 1:])
    # Periodic seam: copying uc[:, 0] and uc[:, -1] separately would let the two
    # copies of this face disagree, so fluxes through x=0 and x=Lx would not match.
    seam = 0.5*(uc[:, -1] + uc[:, 0])
    u[:, 0] = seam
    u[:, -1] = seam

    v[1:-1, :] = 0.5*(vc[:-1, :] + vc[1:, :])
    v[0, :] = vc[0, :]
    v[-1, :] = vc[-1, :]
    return enforce_klein_faces(u, v)

# ---------- Kelvin mode (centers) ----------
def kelvin_centers(m=3, A_eta=30.0, phase=0.0, packet_sigma_x=None, x0=None):
    """
    Equatorially trapped Kelvin IC on Klein-β, at cell centres:

      eta_K = A * cos(kx*x + phase) * exp(-(y')^2 / (2 Ld^2)),   y' = y - Ly/2
      u_K   = (c/H) * eta_K,   v_K = 0

    with β_eff = df/dy|eq = fp * (π / Ly), c = sqrt(gH) and Ld = sqrt(c/β_eff).

    If ``packet_sigma_x`` is given, the carrier is multiplied by a Gaussian
    envelope of that width centred at ``x0`` (default Lx/2). The envelope uses
    the minimum-image x-separation, so it is continuous across the periodic
    x seam.

    Returns float32 (uK, vK, etaK) with Klein centre parity on the boundary rows.
    """
    c  = np.sqrt(g*H)
    beta_eff = fp * (np.pi / Ly)
    Ld = np.sqrt(c / beta_eff)

    kx = 2*np.pi * m / Lx
    yprime = Yc - 0.5*Ly
    carrier = np.cos(kx*Xc + phase)
    merid = np.exp(-0.5*(yprime**2)/(Ld**2))

    if packet_sigma_x is not None:
        if x0 is None:
            x0 = 0.5*Lx
        # A plain (Xc - x0) jumps at the seam. With x0=0.3*Lx and sigma=0.12*Lx,
        # the envelope would be ~0.04 just right of x=0 but ~0 just left of x=Lx.
        sep_x = (Xc - x0 + 0.5*Lx) % Lx - 0.5*Lx
        carrier = carrier * np.exp(-0.5*(sep_x/packet_sigma_x)**2)

    etaK = (A_eta * carrier * merid).astype(np.float32)
    uK   = ((c/H) * etaK).astype(np.float32)
    vK   = np.zeros_like(uK, dtype=np.float32)

    enforce_klein_center(etaK)
    enforce_klein_center(uK)
    enforce_klein_center(vK)
    return uK, vK, etaK

# ---------- Plotting ----------
def _sym_lims(a, pr=99.5):
    """Symmetric limits around zero using percentile clip."""
    p = np.percentile(np.abs(a), pr)
    if p == 0:
        p = np.max(np.abs(a)) if np.max(np.abs(a))>0 else 1.0
    return (-p, p)

def plot_ic(key, h_ctr, u_face, v_face, out_dir):
    """
    Save a 3-panel PNG for one IC: eta anomaly (h - H), then u and v at cell centres.

    Parameters:
        key: IC name, used in panel titles and the file name ``{key}_maps.png``.
        h_ctr: total depth at centres (ny, nx).
        u_face, v_face: face velocities, (ny, nx+1) and (ny+1, nx).
        out_dir: output directory, created if missing.

    Returns the path of the written PNG. The figure is always closed, even when
    plotting or saving raises.
    """
    eta = h_ctr - H
    u_c, v_c = faces_to_centers(u_face, v_face)
    panels = (
        (eta, "η anomaly (m)"),
        (u_c, "u (m/s) @ centers"),
        (v_c, "v (m/s) @ centers"),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 4), constrained_layout=True)
    try:
        for ax, (field, label) in zip(axs, panels):
            vmin, vmax = _sym_lims(field)
            im = ax.imshow(field, origin='lower', extent=[0, Lx, 0, Ly], vmin=vmin, vmax=vmax)
            ax.set_title(f"{key} — {label}")
            ax.set_xlabel("x (m)"); ax.set_ylabel("y (m)")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Let constrained_layout place the suptitle. An explicit y=1.03 puts it
        # above the canvas, where savefig (no bbox_inches) clips it.
        fig.suptitle("Klein-β Initial Condition", fontsize=12)

        ensure_dir(out_dir)
        out_path = os.path.join(out_dir, f"{key}_maps.png")
        fig.savefig(out_path, dpi=160)
    finally:
        plt.close(fig)
    return out_path

# ---------- Build bundle (10 ICs total) ----------
def main():
    """
    Build the 10 Klein-β C-grid ICs, save them to one bundle, plot each, and print ranges.

    Bundle contents (np.savez_compressed at OUT_BUNDLE):
      * per IC key: ``<key>_h`` total depth H + eta (ny, nx), ``<key>_u``
        (ny, nx+1) and ``<key>_v`` (ny+1, nx) face winds; all float32.
      * scalar metadata: grid_nx, grid_ny, grid_dx, grid_dy, H, fp.
    One ``<key>_maps.png`` per IC is written to OUT_PLOTS_DIR.
    """
    ensure_dir(os.path.dirname(OUT_BUNDLE))
    ensure_dir(OUT_PLOTS_DIR)

    # 1) The seven original η fields (RH waves, modons and their mixtures)
    eta_set = build_eta_set()

    # 2) Kelvin components (centers)
    uK1, vK1, eK1 = kelvin_centers(m=3, A_eta=30.0, phase=0.0, packet_sigma_x=None)            # plane-wave m=3
    uK2, vK2, eK2 = kelvin_centers(m=2, A_eta=25.0, phase=0.7, packet_sigma_x=0.12*Lx, x0=0.3*Lx)  # localized packet

    # 3) Assemble 10 ICs: 7 originals + 3 Kelvin-related; each entry is (key, h, u, v)
    definitions = []

    # Originals: winds diagnosed geostrophically from η
    for key in ["rh4","rh5","modon","twin_modons","mixed_RH2_modon","mixed_RH3_modon","mixed_RH4_2mod"]:
        eta = eta_set[key]
        uc, vc = geo_winds_from_eta(eta)
        u, v   = centers_to_faces(uc, vc)
        h      = (H + eta).astype(np.float32)
        definitions.append((key, h, u, v))

    # 8) Pure Kelvin m=3
    key = "kelvin_m3"
    uc, vc, eta = uK1, vK1, eK1
    u, v = centers_to_faces(uc, vc)
    h    = (H + eta).astype(np.float32)
    definitions.append((key, h, u, v))

    # 9) RH4 + Kelvin m=3 (linear superposition of geostrophic and Kelvin fields)
    key = "rh4_plus_kelvin_m3"
    eta_bg = eta_set["rh4"]
    uc_bg, vc_bg = geo_winds_from_eta(eta_bg)
    uc_mix = (uc_bg + uK1).astype(np.float32)
    vc_mix = (vc_bg + vK1).astype(np.float32)  # vK1 is zero; kept for symmetry with uc_mix
    u, v = centers_to_faces(uc_mix, vc_mix)
    h    = (H + (eta_bg + eK1)).astype(np.float32)
    definitions.append((key, h, u, v))

    # 10) Twin modons + localized Kelvin packet
    key = "twin_modons_plus_kelvin_pkt"
    eta_bg = eta_set["twin_modons"]
    uc_bg, vc_bg = geo_winds_from_eta(eta_bg)
    uc_mix = (uc_bg + uK2).astype(np.float32)
    vc_mix = (vc_bg + vK2).astype(np.float32)  # vK2 is zero
    u, v = centers_to_faces(uc_mix, vc_mix)
    h    = (H + (eta_bg + eK2)).astype(np.float32)
    definitions.append((key, h, u, v))

    # 4) Pack & save: three arrays per IC, keyed "<key>_h", "<key>_u", "<key>_v"
    pack = {}
    for key, h, u, v in definitions:
        pack[f"{key}_h"] = h
        pack[f"{key}_u"] = u
        pack[f"{key}_v"] = v

    # metadata so readers can check grid/physics consistency
    pack["grid_nx"] = np.int32(nx)
    pack["grid_ny"] = np.int32(ny)
    pack["grid_dx"] = np.float32(dx)
    pack["grid_dy"] = np.float32(dy)
    pack["H"]       = np.float32(H)
    pack["fp"]      = np.float32(fp)

    np.savez_compressed(OUT_BUNDLE, **pack)
    print(f"Saved C-grid IC bundle -> {OUT_BUNDLE}")

    # 5) Plots (one PNG per IC)
    saved_pngs = []
    for key, h, u, v in definitions:
        p = plot_ic(key, h, u, v, OUT_PLOTS_DIR)
        saved_pngs.append(p)

    print("Saved plots:")
    for p in saved_pngs:
        print(" -", p)

    # Quick diagnostics on ranges (eta recovered as h - H)
    def rng(a):
        return (float(np.min(a)), float(np.max(a)))
    for key, h, u, v in definitions:
        eta = h - H
        print(f"[{key}] eta[min,max]={rng(eta)}, u[min,max]={rng(u)}, v[min,max]={rng(v)}")

# Script entry point: build and plot the bundle when run directly.
if __name__ == "__main__":
    main()


Saved C-grid IC bundle -> /content/drive/MyDrive/klein_ics/bundle_cgrid.npz
Saved plots:
 - /content/drive/MyDrive/klein_ics/plots/rh4_maps.png
 - /content/drive/MyDrive/klein_ics/plots/rh5_maps.png
 - /content/drive/MyDrive/klein_ics/plots/modon_maps.png
 - /content/drive/MyDrive/klein_ics/plots/twin_modons_maps.png
 - /content/drive/MyDrive/klein_ics/plots/mixed_RH2_modon_maps.png
 - /content/drive/MyDrive/klein_ics/plots/mixed_RH3_modon_maps.png
 - /content/drive/MyDrive/klein_ics/plots/mixed_RH4_2mod_maps.png
 - /content/drive/MyDrive/klein_ics/plots/kelvin_m3_maps.png
 - /content/drive/MyDrive/klein_ics/plots/rh4_plus_kelvin_m3_maps.png
 - /content/drive/MyDrive/klein_ics/plots/twin_modons_plus_kelvin_pkt_maps.png
[rh4] eta[min,max]=(-59.9232177734375, 59.9232177734375), u[min,max]=(-2.885535717010498, 2.885535717010498), v[min,max]=(-73.73017883300781, 73.73017883300781)
[rh5] eta[min,max]=(-44.99658203125, 44.99658203125), u[min,max]=(-2.162522792816162, 2.162522792816162), v[mi

# Two-layer finite-difference (FD) model

In [None]:
# fd_klein_cgrid_2layer_AL_run.py
# Two-layer Klein-β SWE on Arakawa C-grid, vector-invariant (Arakawa–Lamb), RK4.
# INPUT: 1-layer IC bundle (h centers, u faces, v faces) per IC.
# SPLIT: barotropic default (eta_i=0, u1=u2, v1=v2), keeping H1+H2 = H (from bundle).
# OUTPUT: centered snapshots (eta, etai, uc1, vc1, uc2, vc2) at steps 0,200,...,1200.
# Also: plots of fields and time series of mass/energy.

import os
import numpy as np
import matplotlib.pyplot as plt

# --------- Paths ---------
IC_BUNDLE = "/content/drive/MyDrive/klein_ics/bundle_cgrid.npz"   # 1-layer IC bundle written by the generator cell
ROOT_OUT  = "/content/drive/MyDrive/klein_ckpt_2L_centers"        # one subfolder per IC key

# ---------- Phys / grid (must match bundle) ----------
g   = 9.81                   # gravitational acceleration (m/s^2)
# Two-layer parameters (EDIT ME)
H1  = 600.0                  # rest depth of layer 1, the upper layer (h1 = H1 + eta - etai) (m)
H2  = 400.0                  # rest depth of layer 2, the lower layer (h2 = H2 + etai) (m); H1+H2 should equal bundle H
gprime = 0.02 * g            # reduced gravity g' = g*Δρ/ρ0 ≈ 0.196 m/s^2, i.e. Δρ/ρ0 = 0.02 (≈20 kg/m^3 if ρ0 ≈ 1000 kg/m^3)
# Time & grid
nx  = 256
ny  = 128
Lx  = 2.0e7                  # domain length in x (m), periodic
Ly  = 8.0e6                  # domain length in y (m), Klein-twist boundaries
dx  = Lx/nx
dy  = Ly/ny
fp  = 8.0e-5                 # Coriolis amplitude: f = fp*sin(phi)
dt  = 30.0                   # RK4 time step (s)
nt  = 1200                   # number of steps (1200 * 30 s = 10 h)
SAVE_STEPS = {0,200,400,600,800,1000,1200}   # steps at which snapshots/plots are written
Htot = H1 + H2                # total rest depth; eta is recovered as h_bundle - Htot

# ---------- geometry ----------
# Cell centres; phi maps y in [0, Ly] onto [-pi/2, pi/2], so f = 0 at y = Ly/2.
x_c = np.linspace(0.5*dx, Lx-0.5*dx, nx)
y_c = np.linspace(0.5*dy, Ly-0.5*dy, ny)
Xc, Yc = np.meshgrid(x_c, y_c)
phi_c  = np.pi*((Yc/Ly) - 0.5)
f_c    = fp*np.sin(phi_c)

# Face grids: u on (ny, nx+1), v on (ny+1, nx); f_u / f_v are used by the momentum RHS.
x_u = np.linspace(0.0, Lx, nx+1)
y_v = np.linspace(0.0, Ly, ny+1)
Xu, Yu = np.meshgrid(x_u, y_c)
Xv, Yv = np.meshgrid(x_c, y_v)
phi_u  = np.pi*((Yu/Ly) - 0.5)
phi_v  = np.pi*((Yv/Ly) - 0.5)
f_u    = fp*np.sin(phi_u)
f_v    = fp*np.sin(phi_v)

# ---------- Klein twist BCs ----------
def twist_reflect_x(arr):
    """Mirror along the last (x) axis: index i -> n-1-i. Returns a view."""
    return arr[..., ::-1]


def _mirror_boundary_rows(fld, odd):
    """Overwrite rows 0 then -1 with the (anti)symmetrised neighbouring interior row."""
    for edge, inner in ((0, 1), (-1, -2)):
        src = fld[inner, :]
        flipped = twist_reflect_x(src)
        fld[edge, :] = 0.5*(src - flipped) if odd else 0.5*(src + flipped)


def apply_bc_2l(u1, v1, h1, u2, v2, h2):
    """
    Apply Klein-twist parity to the first and last rows of both layers, in place.

    Thicknesses (centres) and v (y-faces) are made even; u (x-faces) is made
    odd. Each boundary row is rebuilt from its adjacent interior row and that
    row's x-mirror. Returns the same six arrays.
    """
    for fld in (h1, h2):
        _mirror_boundary_rows(fld, odd=False)
    for fld in (u1, u2):
        _mirror_boundary_rows(fld, odd=True)
    for fld in (v1, v2):
        _mirror_boundary_rows(fld, odd=False)
    return u1, v1, h1, u2, v2, h2

# ---------- C-grid helpers ----------
def center_from_u(u):
    """x-face field (ny, nx+1) -> cell centres (ny, nx): mean of west/east faces."""
    west, east = u[:, :-1], u[:, 1:]
    return 0.5*(west + east)


def center_from_v(v):
    """y-face field (ny+1, nx) -> cell centres (ny, nx): mean of south/north faces."""
    south, north = v[:-1, :], v[1:, :]
    return 0.5*(south + north)


def avg_x(a):
    """Centres (ny, nx) -> x-faces (ny, nx+1), wrap-around (periodic) in x."""
    west = np.pad(a, ((0, 0), (1, 0)), mode='wrap')
    east = np.pad(a, ((0, 0), (0, 1)), mode='wrap')
    return 0.5*(west + east)


def avg_y(a):
    """Centres (ny, nx) -> y-faces (ny+1, nx); boundary rows replicated."""
    south = np.pad(a, ((1, 0), (0, 0)), mode='edge')
    north = np.pad(a, ((0, 1), (0, 0)), mode='edge')
    return 0.5*(south + north)

def ddx_c_to_u(phi):
    """
    x-gradient of a centre field, evaluated on the u-faces: (ny, nx) -> (ny, nx+1).

    Face j lies between centres j-1 and j (wrap-around in x), exactly one dx
    apart. The difference is therefore divided by dx, the same spacing the
    divergence ``ddx_u_to_c`` uses. A 2*dx denominator would halve every
    pressure and kinetic-energy gradient.
    """
    west = np.pad(phi, ((0, 0), (1, 0)), mode='wrap')   # west[:, j] = phi[:, j-1]
    east = np.pad(phi, ((0, 0), (0, 1)), mode='wrap')   # east[:, j] = phi[:, j]
    return (east - west) / dx


def ddy_c_to_v(phi):
    """
    y-gradient of a centre field, evaluated on the v-faces: (ny, nx) -> (ny+1, nx).

    Interior face i lies between centres i-1 and i, one dy apart. The two
    boundary faces see replicated edge rows and so get a zero gradient.
    """
    south = np.pad(phi, ((1, 0), (0, 0)), mode='edge')  # south[i] = phi[i-1]
    north = np.pad(phi, ((0, 1), (0, 0)), mode='edge')  # north[i] = phi[i]
    return (north - south) / dy


def ddx_u_to_c(phi_u):
    """Divergence-style x-difference from u-faces to centres: (ny, nx+1) -> (ny, nx)."""
    return (phi_u[:,1:] - phi_u[:,:-1]) / dx


def ddy_v_to_c(phi_v):
    """Divergence-style y-difference from v-faces to centres: (ny+1, nx) -> (ny, nx)."""
    return (phi_v[1:,:] - phi_v[:-1,:]) / dy

# Laplacians (aligned to native grids)
def _d2x_wrap(a):
    """Second x-difference / dx^2 for a field whose nx columns form one full period."""
    ae = np.pad(a, ((0, 0), (1, 1)), mode='wrap')
    return (ae[:, :-2] - 2*ae[:, 1:-1] + ae[:, 2:]) / dx**2


def _d2y_edge(a):
    """Second y-difference / dy^2 with replicated (zero-gradient) boundary rows."""
    ae = np.pad(a, ((1, 1), (0, 0)), mode='edge')
    return (ae[:-2, :] - 2*ae[1:-1, :] + ae[2:, :]) / dy**2


def lap_u(u):
    """
    Laplacian of an x-face field u, shape (ny, nx+1).

    Columns 0 and nx are the same physical face (x = 0 == x = Lx), so the
    x-period is nx columns, not nx+1. The west neighbour of face 0 is face nx-1,
    and the east neighbour of face nx is face 1. np.pad(..., mode='wrap') on
    the (nx+1)-column array would instead pair face 0 with face nx, i.e. with
    itself.
    """
    ue = np.concatenate([u[:, -2:-1], u, u[:, 1:2]], axis=1)
    u_xx = (ue[:, :-2] - 2*ue[:, 1:-1] + ue[:, 2:]) / dx**2
    return u_xx + _d2y_edge(u)


def lap_v(v):
    """Laplacian of a y-face field v (ny+1, nx): periodic in x, edge-replicated in y."""
    return _d2x_wrap(v) + _d2y_edge(v)


def lap_c(h):
    """Laplacian of a centre field (ny, nx): periodic in x, edge-replicated in y."""
    return _d2x_wrap(h) + _d2y_edge(h)


def bih_u(u): return lap_u(lap_u(u))
def bih_v(v): return lap_v(lap_v(v))
def bih_c(h): return lap_c(lap_c(h))

# ---------- Vorticity ----------
def compute_corner_vort(u, v):
    """
    Relative vorticity dv/dx - du/dy at cell corners, shape (ny+1, nx+1).

    Corner (i, j) lies between v-faces j-1 and j in x, and between u-faces i-1
    and i in y, each pair one grid spacing apart. The differences are therefore
    divided by dx and dy; 2*dx and 2*dy would halve the vorticity. x is periodic
    (wrap). In y the boundary rows are replicated, which gives du/dy = 0 on the
    first and last corner rows.
    """
    v_w = np.pad(v, ((0, 0), (1, 0)), mode='wrap'); v_e = np.pad(v, ((0, 0), (0, 1)), mode='wrap')
    dv_dx = (v_e - v_w) / dx
    u_s = np.pad(u, ((1, 0), (0, 0)), mode='edge'); u_n = np.pad(u, ((0, 1), (0, 0)), mode='edge')
    du_dy = (u_n - u_s) / dy
    return dv_dx - du_dy


def to_u_from_corners(a):
    """Corners (ny+1, nx+1) -> u-faces (ny, nx+1) by averaging in y."""
    return 0.5*(a[:-1,:] + a[1:,:])


def to_v_from_corners(a):
    """Corners (ny+1, nx+1) -> v-faces (ny+1, nx) by averaging in x."""
    return 0.5*(a[:,:-1] + a[:,1:])

# ---------- Reconstruct free surface & interface from h1,h2 ----------
def reconstruct_eta_etai(h1, h2):
    """
    Free-surface and interface displacements from the layer thicknesses.

        eta  = h1 + h2 - (H1 + H2)   total column anomaly
        etai = h2 - H2               interface displacement (layer-2 thickness anomaly)

    Returns (eta, etai).
    """
    rest_total = H1 + H2
    return h1 + h2 - rest_total, h2 - H2

# ---------- Thickness floor (kinetic-energy preserving) ------------
HMIN = 0.5  # meters; minimum thickness enforced by the floor below (small, so it rarely triggers)

def enforce_floor_ke_preserving(u, v, h, hmin=HMIN):
    """
    Floor thickness at ``hmin`` while keeping the local kinetic energy unchanged.

    Cells with ``h < hmin`` are raised to ``hmin``. Because KE ~ (1/2) h |u|^2,
    velocities there are multiplied by s = sqrt(h_old / hmin), with h_old
    clipped below at 1e-12. The cell-centre factor (float32) reaches the C-grid
    faces by averaging the two neighbouring cells: wrap-around in x for u,
    edge-replicated rows in y for v.

    Returns (u, v, h, n_hit). n_hit is the number of floored cells. When no
    cell is below ``hmin``, the inputs are returned unchanged with n_hit = 0.
    """
    below = h < hmin
    n_hit = int(below.sum())
    if n_hit == 0:
        return u, v, h, 0

    scale_c = np.ones_like(h, dtype=np.float32)
    scale_c[below] = np.sqrt(np.maximum(h[below], 1e-12) / hmin)

    # west/east neighbours for x-faces (periodic), south/north for y-faces (edge)
    scale_u = 0.5*(np.pad(scale_c, ((0, 0), (1, 0)), mode='wrap')
                   + np.pad(scale_c, ((0, 0), (0, 1)), mode='wrap'))
    scale_v = 0.5*(np.pad(scale_c, ((1, 0), (0, 0)), mode='edge')
                   + np.pad(scale_c, ((0, 1), (0, 0)), mode='edge'))

    return u * scale_u, v * scale_v, np.maximum(h, hmin), n_hit


# ---------- RHS (AL-style per layer) ----------
# diffusivities (tune as needed)
nu2_u, nu2_v, nu2_h = 1.0e4, 1.0e4, 5.0e3       # Laplacian coefficients (m^2/s)
nu4_u, nu4_v, nu4_h = 5.0e10, 5.0e10, 2.5e10    # biharmonic coefficients (m^4/s)

def rhs_2l(u1, v1, h1, u2, v2, h2):
    """
    Tendencies of the two-layer Klein-β SWE in vector-invariant (Arakawa–Lamb) form.

    Layer 1 is the upper layer (h1 = H1 + eta - etai); layer 2 is the lower
    layer (h2 = H2 + etai). For each layer k:

        du_k/dt =  (zeta_k + f) v_k - d/dx(Phi_k + K_k) + diffusion - drag
        dv_k/dt = -(zeta_k + f) u_k - d/dy(Phi_k + K_k) + diffusion - drag

    The Boussinesq layer potentials are

        Phi_1 = g*eta,    Phi_2 = g*eta + g'*etai.

    These are the potentials consistent with the diagnostic energy
    PE = 1/2 g eta^2 + 1/2 g' etai^2 used by ``total_energy_2l``. In the
    continuous limit, pressure work then balances the PE change from the
    continuity equations. Continuity is in flux form, with Laplacian +
    biharmonic thickness diffusion.

    Note: ``apply_bc_2l`` rewrites the boundary rows of the input arrays in place.
    Returns (du1, dv1, dh1dt, du2, dv2, dh2dt) with boundary parity applied.
    """
    u1,v1,h1,u2,v2,h2 = apply_bc_2l(u1,v1,h1,u2,v2,h2)

    # --- centers & diagnostics per layer
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)
    K1 = 0.5*(uc1**2 + vc1**2)
    K2 = 0.5*(uc2**2 + vc2**2)

    # --- layer potentials from eta, etai
    # The upper layer feels only the free surface. The lower layer also feels
    # the interface via g'. Giving layer 1 +g'*etai and layer 2 -g'*etai makes
    # the interface force anti-restoring and breaks energy conservation.
    eta, etai = reconstruct_eta_etai(h1, h2)
    Phi1 = g*eta
    Phi2 = g*eta + gprime*etai

    # --- pressure & KE gradients on faces (per layer)
    dPhidx_u1 = ddx_c_to_u(Phi1); dPhidy_v1 = ddy_c_to_v(Phi1)
    dPhidx_u2 = ddx_c_to_u(Phi2); dPhidy_v2 = ddy_c_to_v(Phi2)
    dKdx_u1   = ddx_c_to_u(K1);   dKdy_v1   = ddy_c_to_v(K1)
    dKdx_u2   = ddx_c_to_u(K2);   dKdy_v2   = ddy_c_to_v(K2)

    # --- absolute vorticity to faces (per layer)
    z_corners_1 = compute_corner_vort(u1, v1)
    z_corners_2 = compute_corner_vort(u2, v2)
    eta_u1 = to_u_from_corners(z_corners_1) + f_u
    eta_v1 = to_v_from_corners(z_corners_1) + f_v
    eta_u2 = to_u_from_corners(z_corners_2) + f_u
    eta_v2 = to_v_from_corners(z_corners_2) + f_v

    # --- transverse velocities (per layer): 4-point averages onto the other face type
    v_tu1 = avg_x(center_from_v(v1))
    u_tv1 = avg_y(center_from_u(u1))
    v_tu2 = avg_x(center_from_v(v2))
    u_tv2 = avg_y(center_from_u(u2))

    # --- momentum tendencies (AL form, inviscid + diffusion)
    du1 = -(dPhidx_u1 + dKdx_u1) + eta_u1 * v_tu1 + nu2_u*lap_u(u1) + nu4_u*bih_u(u1)
    dv1 = -(dPhidy_v1 + dKdy_v1) - eta_v1 * u_tv1 + nu2_v*lap_v(v1) + nu4_v*bih_v(v1)

    du2 = -(dPhidx_u2 + dKdx_u2) + eta_u2 * v_tu2 + nu2_u*lap_u(u2) + nu4_u*bih_u(u2)
    dv2 = -(dPhidy_v2 + dKdy_v2) - eta_v2 * u_tv2 + nu2_v*lap_v(v2) + nu4_v*bih_v(v2)

    # --- gentle interfacial drag to control baroclinic shear
    r_int = 3.0e-6  # s^-1, e-folding time 1/r_int ≈ 3.9 days (1/day would be ~1.2e-5)

    du1 -= r_int * (u1 - u2)
    dv1 -= r_int * (v1 - v2)
    du2 -= r_int * (u2 - u1)
    dv2 -= r_int * (v2 - v1)

    # --- continuity (flux form) + diffusion (per layer)
    h1_u = avg_x(h1);  h1_v = avg_y(h1)
    h2_u = avg_x(h2);  h2_v = avg_y(h2)

    F_u1 = h1_u * u1;  F_v1 = h1_v * v1
    F_u2 = h2_u * u2;  F_v2 = h2_v * v2

    dh1dt = -(ddx_u_to_c(F_u1) + ddy_v_to_c(F_v1)) + nu2_h*lap_c(h1) + nu4_h*bih_c(h1)
    dh2dt = -(ddx_u_to_c(F_u2) + ddy_v_to_c(F_v2)) + nu2_h*lap_c(h2) + nu4_h*bih_c(h2)

    return apply_bc_2l(du1, dv1, dh1dt, du2, dv2, dh2dt)

# ---------- RK4 ----------
def rk4_2l(u1,v1,h1, u2,v2,h2, dt):
    """
    Advance the two-layer state by one classical 4th-order Runge–Kutta step.

    Each intermediate state (s + c*k) gets ``apply_bc_2l`` before its RHS is
    evaluated, and so does the final state. ``rhs_2l`` applies boundary parity
    to its arguments in place, so the incoming arrays may have their first and
    last rows rewritten. Returns (u1, v1, h1, u2, v2, h2).
    """
    state = (u1, v1, h1, u2, v2, h2)

    def stage(coef, slopes):
        # trial state s + coef*k for every field, with boundary parity applied
        return apply_bc_2l(*(s + coef*k for s, k in zip(state, slopes)))

    k1 = rhs_2l(*state)
    k2 = rhs_2l(*stage(0.5*dt, k1))
    k3 = rhs_2l(*stage(0.5*dt, k2))
    k4 = rhs_2l(*stage(dt, k3))

    new_state = tuple(
        s + (dt/6.0)*(a + 2*b + 2*c + d)
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )
    return apply_bc_2l(*new_state)

# ---------- Diagnostics / save ----------
def total_mass_layer(h):
    """Layer volume, sum(h) * dx * dy, as a Python float."""
    total = np.sum(h) * dx * dy
    return float(total)


def total_ke_layer(u, v, h):
    """Layer kinetic energy: sum over cells of 0.5*h*|u_c|^2 * dx * dy (float)."""
    speed2 = center_from_u(u)**2 + center_from_v(v)**2
    ke_density = 0.5*h*speed2
    return float(np.sum(ke_density * dx * dy))


def total_energy_2l(u1,v1,h1, u2,v2,h2):
    """
    Total two-layer energy, summed over cells:
    0.5*h1*|u1|^2 + 0.5*h2*|u2|^2 + 0.5*g*eta^2 + 0.5*g'*etai^2, times dx*dy.
    """
    eta, etai = reconstruct_eta_etai(h1, h2)
    speed2_1 = center_from_u(u1)**2 + center_from_v(v1)**2
    speed2_2 = center_from_u(u2)**2 + center_from_v(v2)**2
    kinetic = 0.5*h1*speed2_1 + 0.5*h2*speed2_2
    potential = 0.5*g*(eta**2) + 0.5*gprime*(etai**2)
    return float(np.sum((kinetic + potential) * dx * dy))

def save_centered_2L(ic_key, step, u1,v1,h1, u2,v2,h2, t):
    """
    Write one centred snapshot to ROOT_OUT/<ic_key>/klein_step_<step:06d>.npz.

    Stored as float32 unless noted: eta, etai; centre velocities uc1, vc1,
    uc2, vc2; thicknesses h1, h2; f and y_m. Scalar metadata: H1, H2, gprime,
    dt, t, plus nx and ny (int32), dx, dy and fp. Returns the file path.
    """
    ic_dir = os.path.join(ROOT_OUT, ic_key)
    os.makedirs(ic_dir, exist_ok=True)

    eta, etai = reconstruct_eta_etai(h1, h2)
    f32 = np.float32
    fields = {
        # centred fields
        "eta": eta.astype(f32),                    # free surface
        "etai": etai.astype(f32),                  # interface
        "uc1": center_from_u(u1).astype(f32),
        "vc1": center_from_v(v1).astype(f32),
        "uc2": center_from_u(u2).astype(f32),
        "vc2": center_from_v(v2).astype(f32),
        # extras for downstream use
        "h1": h1.astype(f32),
        "h2": h2.astype(f32),
        "f": f_c.astype(f32),
        "y_m": y_c.astype(f32),
        # metadata
        "H1": f32(H1), "H2": f32(H2), "gprime": f32(gprime),
        "dt": f32(dt), "t": f32(t),
        "nx": np.int32(nx), "ny": np.int32(ny),
        "dx": f32(dx), "dy": f32(dy), "fp": f32(fp),
    }
    path = os.path.join(ic_dir, f"klein_step_{step:06d}.npz")
    np.savez_compressed(path, **fields)
    return path

def quick_plot_2L(ic_key, step, u1,v1,h1, u2,v2,h2):
    """
    Save a 2x3 panel PNG of the two-layer state at one step.

    Top row: free surface eta, u1 and v1 at centres. Bottom row: interface
    etai, u2 and v2. Axes are x in km and y in km relative to y = Ly/2. Written
    to ROOT_OUT/<ic_key>/plots/fields2L_step_<step:06d>.png.
    """
    eta, etai = reconstruct_eta_etai(h1,h2)
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)
    pdir = os.path.join(ROOT_OUT, ic_key, "plots"); os.makedirs(pdir, exist_ok=True)
    # plot coordinates: km, with y measured from the domain mid-line (f = 0)
    xkm = Xc/1e3; ykm = (Yc - 0.5*Ly)/1e3
    plt.figure(figsize=(14,7)); plt.suptitle(f"{ic_key} step={step}  t={step*dt/3600:.2f} h")
    ax=plt.subplot(2,3,1); im=ax.pcolormesh(xkm,ykm,eta,  shading="auto"); plt.colorbar(im,ax=ax,label="η (m)"); ax.set_title("Free surface η")
    ax=plt.subplot(2,3,2); im=ax.pcolormesh(xkm,ykm,uc1,  shading="auto"); plt.colorbar(im,ax=ax,label="u1_c (m/s)")
    ax=plt.subplot(2,3,3); im=ax.pcolormesh(xkm,ykm,vc1,  shading="auto"); plt.colorbar(im,ax=ax,label="v1_c (m/s)")
    ax=plt.subplot(2,3,4); im=ax.pcolormesh(xkm,ykm,etai, shading="auto"); plt.colorbar(im,ax=ax,label="η_i (m)"); ax.set_title("Interface η_i")
    ax=plt.subplot(2,3,5); im=ax.pcolormesh(xkm,ykm,uc2,  shading="auto"); plt.colorbar(im,ax=ax,label="u2_c (m/s)")
    ax=plt.subplot(2,3,6); im=ax.pcolormesh(xkm,ykm,vc2,  shading="auto"); plt.colorbar(im,ax=ax,label="v2_c (m/s)")
    # plt.close() (no argument) closes the current figure created above
    plt.tight_layout(); plt.savefig(os.path.join(pdir,f"fields2L_step_{step:06d}.png"),dpi=120); plt.close()

def plot_mass_energy(ic_key, steps, m1_series, m2_series, Etot_series):
    """
    Plot and export the relative change of total mass and energy over time.

    Writes ROOT_OUT/<ic_key>/plots/mass_energy_timeseries.png, which shows
    (M1+M2 - M0)/M0 and (E - E0)/E0 against time in days. Also writes
    mass_energy_timeseries.csv with columns time_days, mass1, mass2 and
    total_energy.
    """
    pdir = os.path.join(ROOT_OUT, ic_key, "plots")
    os.makedirs(pdir, exist_ok=True)

    tdays = np.asarray(steps) * dt / 86400.0
    mass1 = np.array(m1_series)
    mass2 = np.array(m2_series)
    energy = np.array(Etot_series)
    M0 = m1_series[0] + m2_series[0]
    E0 = Etot_series[0]

    fig, ax = plt.subplots(figsize=(9, 4))
    ax.plot(tdays, (mass1 + mass2 - M0)/M0, label="Δ(M1+M2)/M0")
    ax.plot(tdays, (energy - E0)/E0, label="ΔE/E0")
    ax.set_xlabel("time (days)")
    ax.set_ylabel("relative change")
    ax.grid(True, alpha=0.3)
    ax.legend()
    out = os.path.join(pdir, "mass_energy_timeseries.png")
    fig.savefig(out, dpi=120)
    plt.close(fig)

    csvp = os.path.join(pdir, "mass_energy_timeseries.csv")
    table = np.column_stack((tdays, mass1, mass2, energy))
    np.savetxt(csvp, table, delimiter=",", header="time_days,mass1,mass2,total_energy", comments="")
    print(f"[diag] saved {out} and {csvp}")

# ---------- Load 1-layer IC from bundle and split to 2-layer ----------
def load_1L_and_split_to_2L(ic_key, bundle_path, add_internal=False, ai_amp=0.0):
    """
    Load one 1-layer IC from the bundle and split it into two layers.

    Default (barotropic) split: interface displacement etai = 0, and both layers
    get the bundle velocities (u1 = u2 = u_bar, v1 = v2 = v_bar).

    With ``add_internal=True`` and a non-zero ``ai_amp`` (m), a y-Gaussian
    interface bump of that amplitude is added, centred at y = Ly/2. Its width is
    Ld_int = sqrt(c_i / beta_eff), with c_i = sqrt(g' H1 H2 / (H1 + H2)) and
    beta_eff = fp*pi/Ly. Velocities stay barotropic in that case too.

    Thicknesses: h1 = H1 + eta - etai (upper), h2 = H2 + etai (lower), where
    eta = h_bundle - Htot. A warning is issued if the bundle's stored H differs
    from Htot, because eta would then carry a spurious offset.

    Returns float32 (h1, u1, v1, h2, u2, v2).
    """
    import warnings

    # NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(bundle_path) as d:
        h_tot = d[f"{ic_key}_h"].astype(np.float32)     # (ny,nx)
        u_bar = d[f"{ic_key}_u"].astype(np.float32)     # (ny,nx+1)
        v_bar = d[f"{ic_key}_v"].astype(np.float32)     # (ny+1,nx)
        H_bundle = float(d["H"]) if "H" in d.files else None

    if H_bundle is not None and not np.isclose(H_bundle, Htot):
        warnings.warn(
            f"bundle H={H_bundle} differs from H1+H2={Htot}; eta = h - Htot will be offset",
            stacklevel=2,
        )

    # Reconstruct 1-layer eta using Htot (kept equal to 1-layer H)
    eta = h_tot - Htot
    etai = np.zeros_like(eta, dtype=np.float32)
    if add_internal and ai_amp != 0.0:
        # Gaussian internal bump centred on the f = 0 line (toy test)
        y0 = 0.5*Ly
        Ld_int = np.sqrt((np.sqrt(gprime*H1*H2/(H1+H2))) / (fp*np.pi/Ly))  # internal deformation radius
        etai = ai_amp * np.exp(-0.5*((Yc - y0)/Ld_int)**2).astype(np.float32)

    u1, v1 = u_bar.copy(), v_bar.copy()
    u2, v2 = u_bar.copy(), v_bar.copy()
    # Build layer thicknesses from (eta, etai):
    h1 = (H1 + eta - etai).astype(np.float32)
    h2 = (H2 + etai).astype(np.float32)
    return h1, u1, v1, h2, u2, v2

def list_ic_keys(bundle_path):
    """
    Return the sorted IC names in a bundle: each key ``<name>_h`` contributes ``<name>``.

    The archive is opened as a context manager so its file handle is released.
    """
    with np.load(bundle_path) as d:
        return sorted({k[:-2] for k in d.files if k.endswith("_h")})

# ---------- Run one IC ----------
def run_ic(ic_key):
    """
    Integrate one bundle IC with the two-layer model and write its outputs.

    Steps:
      1. Load the IC with a barotropic split.
      2. Apply boundary parity and a 5 m thickness floor.
      3. Take ``nt`` RK4 steps of size ``dt``. After each step, apply the
         KE-preserving HMIN floor to both layers and record total mass and
         energy.
      4. Write snapshots and field plots at SAVE_STEPS. Print a progress line
         at step 1 and every 100 steps, including the cumulative number of
         floored cells.
      5. Plot and save the mass/energy history at the end.
    """
    print(f"\n=== 2L IC: {ic_key} | nx={nx}, ny={ny}, dt={dt:.1f}s, nt={nt} ===")
    h1,u1,v1,h2,u2,v2 = load_1L_and_split_to_2L(ic_key, IC_BUNDLE, add_internal=False, ai_amp=0.0)
    u1,v1,h1,u2,v2,h2 = apply_bc_2l(u1,v1,h1,u2,v2,h2)
    # initial safety floor before the first step
    h1[:] = np.maximum(h1, 5.0); h2[:] = np.maximum(h2, 5.0)

    # diagnostics storage
    diag_steps = []
    m1_series, m2_series, E_series = [], [], []

    M1_0 = total_mass_layer(h1)
    M2_0 = total_mass_layer(h2)
    M0   = M1_0 + M2_0
    E0   = total_energy_2l(u1,v1,h1, u2,v2,h2)
    # dE/E0 is undefined for an IC at rest (E0 == 0); report the absolute change then
    E_ref = E0 if E0 != 0.0 else 1.0
    t = 0.0
    floor_hits = 0   # cumulative count of cells floored at HMIN, both layers

    # record & save step 0
    diag_steps.append(0); m1_series.append(M1_0); m2_series.append(M2_0); E_series.append(E0)
    if 0 in SAVE_STEPS:
        p = save_centered_2L(ic_key, 0, u1,v1,h1, u2,v2,h2, t)
        quick_plot_2L(ic_key, 0, u1,v1,h1, u2,v2,h2)
        print(f"[save] {ic_key} step=0 -> {p}")

    for n in range(1, nt+1):
        u1,v1,h1, u2,v2,h2 = rk4_2l(u1,v1,h1, u2,v2,h2, dt)
        t += dt

        # positivity: raise h to HMIN where needed, rescaling velocities to keep local KE
        u1, v1, h1, hit1 = enforce_floor_ke_preserving(u1, v1, h1, HMIN)
        u2, v2, h2, hit2 = enforce_floor_ke_preserving(u2, v2, h2, HMIN)
        floor_hits += hit1 + hit2

        # record diagnostics
        diag_steps.append(n)
        m1_series.append(total_mass_layer(h1))
        m2_series.append(total_mass_layer(h2))
        E_series.append(total_energy_2l(u1,v1,h1, u2,v2,h2))

        if n in SAVE_STEPS:
            p = save_centered_2L(ic_key, n, u1,v1,h1, u2,v2,h2, t)
            quick_plot_2L(ic_key, n, u1,v1,h1, u2,v2,h2)
            print(f"[save] {ic_key} step={n} -> {p}")

        if (n % 100) == 0 or n == 1:
            uc1,vc1 = center_from_u(u1), center_from_v(v1)
            uc2,vc2 = center_from_u(u2), center_from_v(v2)
            umax1 = float(np.max(np.sqrt(uc1*uc1+vc1*vc1)))
            umax2 = float(np.max(np.sqrt(uc2*uc2+vc2*vc2)))
            dM = (m1_series[-1]+m2_series[-1] - M0)/M0
            dE = (E_series[-1]-E0)/E_ref
            print(f"[{n:5d}] d(M1+M2)/M0={dM:+.3e}  dE/E0={dE:+.3e}  umax1={umax1:6.2f}  umax2={umax2:6.2f}  floor_hits={floor_hits}")

    # plot & save mass/energy history
    plot_mass_energy(ic_key, diag_steps, m1_series, m2_series, E_series)
    print(f"Done (2L): {ic_key}")

# ---------- Main ----------
# Script entry point: run every IC found in the bundle, one after another.
# `keys` and `k` are module-level names (shared with other notebook cells).
if __name__ == "__main__":
    os.makedirs(ROOT_OUT, exist_ok=True)
    keys = list_ic_keys(IC_BUNDLE)
    print("ICs in bundle:", keys)
    for k in keys:
        run_ic(k)


ICs in bundle: ['kelvin_m3', 'mixed_RH2_modon', 'mixed_RH3_modon', 'mixed_RH4_2mod', 'modon', 'rh4', 'rh4_plus_kelvin_m3', 'rh5', 'twin_modons', 'twin_modons_plus_kelvin_pkt']

=== 2L IC: kelvin_m3 | nx=256, ny=128, dt=30.0s, nt=1200 ===
[save] kelvin_m3 step=0 -> /content/drive/MyDrive/klein_ckpt_2L_centers/kelvin_m3/klein_step_000000.npz
[    1] d(M1+M2)/M0=-2.913e-10  dE/E0=+7.439e-07  umax1=  2.97  umax2=  2.97
[  100] d(M1+M2)/M0=+2.063e-08  dE/E0=+1.109e-02  umax1=  2.94  umax2=  2.94
[save] kelvin_m3 step=200 -> /content/drive/MyDrive/klein_ckpt_2L_centers/kelvin_m3/klein_step_000200.npz
[  200] d(M1+M2)/M0=+8.514e-08  dE/E0=+4.201e-02  umax1=  2.86  umax2=  2.87
[  300] d(M1+M2)/M0=+1.795e-07  dE/E0=+8.609e-02  umax1=  2.75  umax2=  2.76
[save] kelvin_m3 step=400 -> /content/drive/MyDrive/klein_ckpt_2L_centers/kelvin_m3/klein_step_000400.npz
[  400] d(M1+M2)/M0=+2.638e-07  dE/E0=+1.339e-01  umax1=  2.62  umax2=  2.63
[  500] d(M1+M2)/M0=+2.838e-07  dE/E0=+1.752e-01  umax1=  2.4

# Two-layer model with more diagnostics

In [None]:
# fd_klein_cgrid_2layer_AL_run.py
# Two-layer Klein-β SWE on Arakawa C-grid, vector-invariant (Arakawa–Lamb), RK4.
# INPUT: 1-layer IC bundle (h centers, u faces, v faces) per IC.
# SPLIT: barotropic default (eta_i=0, u1=u2, v1=v2), keeping H1+H2 = H (from bundle).
# OUTPUT: centered snapshots (eta, etai, uc1, vc1, uc2, vc2) at steps 0,200,...,1200.
# Also: plots of fields and time series of mass/energy.

import os
import numpy as np
import matplotlib.pyplot as plt

# --------- Paths ---------
# IC_BUNDLE: read by list_ic_keys() and load_1L_and_split_to_2L() (keys "<ic>_h", "<ic>_u", "<ic>_v").
# ROOT_OUT : snapshots go to ROOT_OUT/<ic>/klein_step_XXXXXX.npz, plots to ROOT_OUT/<ic>/plots/.
IC_BUNDLE = "/content/drive/MyDrive/klein_ics/bundle_cgrid.npz"   # 1-layer bundle you already created
ROOT_OUT  = "/content/drive/MyDrive/klein_ckpt_2L_centers"        # per-IC subfolders

# ---- Diagnostics toggles ----
# uc_bc / vc_bc hold the layer shear (u1-u2, v1-v2) computed by bt_bc_fields().
SAVE_BTBC_FIELDS = True   # also store uc_bt, vc_bt, uc_bc, vc_bc in each npz
PLOT_BTBC_MAPS   = True   # save BT/BC maps at SAVE_STEPS

# ---------- Phys / grid (must match bundle) ----------
# NOTE(review): nx, ny, Lx, Ly, fp are not checked against IC_BUNDLE here;
# the bundle arrays must have shapes (ny,nx), (ny,nx+1), (ny+1,nx) for this grid.
g   = 9.81
# Two-layer parameters (EDIT ME)
# Layer 1 is the upper layer (h1 = H1 + eta - etai), layer 2 the lower one (h2 = H2 + etai).
H1  = 600.0                  # rest depth layer 1 (m)
H2  = 400.0                  # rest depth layer 2 (m) -> H1+H2 matches 1-layer H by default
gprime = 0.02 * g            # reduced gravity g' = g*Δρ/ρ0 (m/s^2); Δρ/ρ0 = 0.02 (~20 kg/m^3 for ρ0≈1000 kg/m^3) -> ~0.196 m/s^2
# Time & grid
nx  = 256
ny  = 128
Lx  = 2.0e7
Ly  = 8.0e6
dx  = Lx/nx
dy  = Ly/ny
fp  = 8.0e-5                 # f(y) = fp*sin(phi)
dt  = 30.0                   # RK4 step (s)
nt  = 1200                   # number of steps per IC
SAVE_STEPS = {0,200,400,600,800,1000,1200}   # steps at which run_ic writes npz snapshots + field plots
Htot = H1 + H2                # total rest depth (kept equal to 1-layer H)

# ---------- geometry ----------
# Cell centers: x_c (nx,), y_c (ny,); meshgrid(x, y) returns (ny, nx) arrays.
x_c = np.linspace(0.5*dx, Lx-0.5*dx, nx)
y_c = np.linspace(0.5*dy, Ly-0.5*dy, ny)
Xc, Yc = np.meshgrid(x_c, y_c)
# latitude-like coordinate: phi runs from -pi/2 (y=0) to +pi/2 (y=Ly)
phi_c  = np.pi*((Yc/Ly) - 0.5)
f_c    = fp*np.sin(phi_c)

# Face coordinates: u-faces at x = 0..Lx (nx+1 points; x=0 and x=Lx are the same
# face under the x-periodic wrap used by the operators), v-faces at y = 0..Ly (ny+1 points).
x_u = np.linspace(0.0, Lx, nx+1)
y_v = np.linspace(0.0, Ly, ny+1)
Xu, Yu = np.meshgrid(x_u, y_c)   # (ny, nx+1)
Xv, Yv = np.meshgrid(x_c, y_v)   # (ny+1, nx)
phi_u  = np.pi*((Yu/Ly) - 0.5)
phi_v  = np.pi*((Yv/Ly) - 0.5)
f_u    = fp*np.sin(phi_u)        # Coriolis on u-faces (added to vorticity in rhs_2l)
f_v    = fp*np.sin(phi_v)        # Coriolis on v-faces (added to vorticity in rhs_2l)

# ---------- Klein twist BCs ----------
def twist_reflect_x(arr):
    """Mirror an array along its last (x) axis; returns a reversed view, not a copy."""
    return arr[..., ::-1]

def apply_bc_2l(u1, v1, h1, u2, v2, h2):
    """
    Impose the Klein twist on the first/last y-rows of both layers, in place.

    Each ghost row (index 0 and -1) is rebuilt from its interior neighbour
    (index 1 and -2) as the part that is even (h, v) or odd (u) under the
    x-mirror:
        even: 0.5*(row + mirror(row))      odd: 0.5*(row - mirror(row))
    Row 0 is written before row -1 for every array.

    Returns the same six array objects in the order (u1, v1, h1, u2, v2, h2).
    """
    def _fill(arr, even):
        for ghost, inner in ((0, 1), (-1, -2)):
            row = arr[inner, :]
            mirrored = twist_reflect_x(row)
            arr[ghost, :] = 0.5*(row + mirrored) if even else 0.5*(row - mirrored)

    for thickness in (h1, h2):   # centers: even
        _fill(thickness, even=True)
    for uface in (u1, u2):       # u-faces: odd
        _fill(uface, even=False)
    for vface in (v1, v2):       # v-faces: even
        _fill(vface, even=True)
    return u1, v1, h1, u2, v2, h2

# ---------- C-grid helpers ----------
def center_from_u(u):
    """Average adjacent u-faces (ny, nx+1) onto cell centers (ny, nx)."""
    west, east = u[:, :-1], u[:, 1:]
    return 0.5*(west + east)

def center_from_v(v):
    """Average adjacent v-faces (ny+1, nx) onto cell centers (ny, nx)."""
    south, north = v[:-1, :], v[1:, :]
    return 0.5*(south + north)

def avg_x(a):
    """Centers (ny, nx) -> u-faces (ny, nx+1) by averaging x-neighbours, periodic in x."""
    shifted_right = np.pad(a, ((0, 0), (1, 0)), mode='wrap')
    shifted_left = np.pad(a, ((0, 0), (0, 1)), mode='wrap')
    return 0.5*(shifted_right + shifted_left)

def avg_y(a):
    """Centers (ny, nx) -> v-faces (ny+1, nx) by averaging y-neighbours, edge-padded in y."""
    shifted_down = np.pad(a, ((1, 0), (0, 0)), mode='edge')
    shifted_up = np.pad(a, ((0, 1), (0, 0)), mode='edge')
    return 0.5*(shifted_down + shifted_up)

def ddx_c_to_u(phi, spacing=None):
    """
    x-derivative of a cell-centered field, evaluated on u-faces.

    Face i lies between centers i-1 and i (periodic in x), so the compact C-grid
    difference is (phi[i] - phi[i-1]) / dx. The padded arrays below already form
    that one-cell difference. The previous version divided it by 2*dx, halving
    every pressure/KE gradient and breaking consistency with ddx_u_to_c.

    phi     : (ny, nx) centered array.
    spacing : x grid spacing; defaults to the module-level dx.
    Returns : (ny, nx+1) array on u-faces (column nx duplicates column 0).
    """
    step = dx if spacing is None else spacing
    west = np.pad(phi, ((0, 0), (1, 0)), mode='wrap')   # west[:, i] = phi[:, i-1]
    east = np.pad(phi, ((0, 0), (0, 1)), mode='wrap')   # east[:, i] = phi[:, i]
    return (east - west) / step
def ddy_c_to_v(phi, spacing=None):
    """
    y-derivative of a cell-centered field, evaluated on v-faces.

    Interior face j lies between centers j-1 and j, so the C-grid difference is
    (phi[j] - phi[j-1]) / dy. The previous version divided this one-cell
    difference by 2*dy, halving every y-gradient. The edge padding makes the
    two boundary faces (j = 0 and j = ny) return zero.

    phi     : (ny, nx) centered array.
    spacing : y grid spacing; defaults to the module-level dy.
    Returns : (ny+1, nx) array on v-faces.
    """
    step = dy if spacing is None else spacing
    south = np.pad(phi, ((1, 0), (0, 0)), mode='edge')   # south[j] = phi[j-1]
    north = np.pad(phi, ((0, 1), (0, 0)), mode='edge')   # north[j] = phi[j]
    return (north - south) / step
def ddx_u_to_c(phi_u):
    """x-difference of a u-face field (ny, nx+1) onto centers (ny, nx), divided by dx."""
    east, west = phi_u[:, 1:], phi_u[:, :-1]
    return (east - west) / dx

def ddy_v_to_c(phi_v):
    """y-difference of a v-face field (ny+1, nx) onto centers (ny, nx), divided by dy."""
    north, south = phi_v[1:, :], phi_v[:-1, :]
    return (north - south) / dy

# Laplacians (aligned to native grids)
def lap_u(u, spacing_x=None, spacing_y=None):
    """
    5-point Laplacian of a u-face field of shape (ny, nx+1).

    Columns 0 and nx are the same periodic face (x = 0 == x = Lx), so the true
    x-neighbours are column nx-1 (left of column 0) and column 1 (right of
    column nx). np.pad(mode='wrap') instead pairs column 0 with column nx,
    i.e. with itself. That turns the seam into a one-sided difference and lets
    u[:, 0] and u[:, nx] drift apart, which in turn breaks the telescoping
    mass balance of the flux-form continuity equation.

    y uses edge (zero-gradient) padding, as before.
    spacing_x / spacing_y default to the module-level dx / dy.
    """
    sx = dx if spacing_x is None else spacing_x
    sy = dy if spacing_y is None else spacing_y
    # periodic extension that skips the duplicated seam column
    ue = np.concatenate([u[:, -2:-1], u, u[:, 1:2]], axis=1)
    u_xx = (ue[:, :-2] - 2*ue[:, 1:-1] + ue[:, 2:]) / sx**2
    ue2 = np.pad(u, ((1, 1), (0, 0)), mode='edge')
    u_yy = (ue2[:-2, :] - 2*ue2[1:-1, :] + ue2[2:, :]) / sy**2
    return u_xx + u_yy
def _d2x_periodic(a):
    """Unscaled second difference along x (columns) with periodic wrap padding."""
    ext = np.pad(a, ((0, 0), (1, 1)), mode='wrap')
    return ext[:, :-2] - 2*ext[:, 1:-1] + ext[:, 2:]

def _d2y_edge(a):
    """Unscaled second difference along y (rows) with edge (zero-gradient) padding."""
    ext = np.pad(a, ((1, 1), (0, 0)), mode='edge')
    return ext[:-2, :] - 2*ext[1:-1, :] + ext[2:, :]

def lap_v(v):
    """5-point Laplacian of a v-face field (ny+1, nx): periodic in x, edge-padded in y."""
    return _d2x_periodic(v) / dx**2 + _d2y_edge(v) / dy**2

def lap_c(h):
    """5-point Laplacian of a cell-centered field (ny, nx): periodic in x, edge-padded in y."""
    return _d2x_periodic(h) / dx**2 + _d2y_edge(h) / dy**2
def bih_u(u):
    """Biharmonic operator on u-faces: lap_u applied twice."""
    once = lap_u(u)
    return lap_u(once)

def bih_v(v):
    """Biharmonic operator on v-faces: lap_v applied twice."""
    once = lap_v(v)
    return lap_v(once)

def bih_c(h):
    """Biharmonic operator on cell centers: lap_c applied twice."""
    once = lap_c(h)
    return lap_c(once)

# ---------- Vorticity ----------
def compute_corner_vort(u, v, spacing_x=None, spacing_y=None):
    """
    Relative vorticity dv/dx - du/dy on cell corners, shape (ny+1, nx+1).

    Corner (j, i) lies between v[:, i-1] and v[:, i] (periodic in x) and
    between u[j-1, :] and u[j, :] (edge-padded in y). Both padded differences
    are therefore one-cell differences and must be divided by dx and dy. The
    previous version divided by 2*dx and 2*dy, which halved the vorticity.

    u : (ny, nx+1) u-faces;  v : (ny+1, nx) v-faces.
    spacing_x / spacing_y default to the module-level dx / dy.
    """
    sx = dx if spacing_x is None else spacing_x
    sy = dy if spacing_y is None else spacing_y
    v_w = np.pad(v, ((0, 0), (1, 0)), mode='wrap'); v_e = np.pad(v, ((0, 0), (0, 1)), mode='wrap')
    dv_dx = (v_e - v_w) / sx
    u_s = np.pad(u, ((1, 0), (0, 0)), mode='edge'); u_n = np.pad(u, ((0, 1), (0, 0)), mode='edge')
    du_dy = (u_n - u_s) / sy
    return dv_dx - du_dy
def to_u_from_corners(a):
    """Corners (ny+1, nx+1) -> u-faces (ny, nx+1): mean of the corners below/above."""
    lower, upper = a[:-1, :], a[1:, :]
    return 0.5*(lower + upper)

def to_v_from_corners(a):
    """Corners (ny+1, nx+1) -> v-faces (ny+1, nx): mean of the corners left/right."""
    left, right = a[:, :-1], a[:, 1:]
    return 0.5*(left + right)

# ---------- Reconstruct free surface & interface from h1,h2 ----------
def reconstruct_eta_etai(h1, h2):
    """
    Recover (eta, etai) from layer thicknesses.

    etai = h2 - H2            (interface displacement, from the lower layer)
    eta  = h1 + h2 - (H1+H2)  (free-surface displacement)
    Returns (eta, etai).
    """
    interface = h2 - H2
    surface = h1 + h2 - (H1 + H2)
    return surface, interface

# ---------- Enforce KE preserving ------------
HMIN = 0.5  # meters; choose small enough to avoid frequent hits

def enforce_floor_ke_preserving(u, v, h, hmin=HMIN):
    """
    Raise thicknesses below ``hmin`` to ``hmin`` while damping the velocity so
    local KE is approximately unchanged:
        (1/2) h_old |u|^2 == (1/2) h_new |u'|^2  ->  u' = u * sqrt(h_old / h_new)

    The scale factor is computed at centers (float32, 1 where not floored) and
    carried to faces by averaging the two adjacent centers (wrap in x, edge in y).
    Faces next to a floored cell are therefore damped by a partial factor.

    Returns (u, v, h, n_hit). When nothing is floored, the inputs are returned
    unchanged with n_hit = 0.
    """
    floored = (h < hmin)
    if not floored.any():
        return u, v, h, 0

    scale_c = np.ones_like(h, dtype=np.float32)
    scale_c[floored] = np.sqrt(np.maximum(h[floored], 1e-12) / hmin)

    west = np.pad(scale_c, ((0, 0), (1, 0)), mode='wrap')
    east = np.pad(scale_c, ((0, 0), (0, 1)), mode='wrap')
    south = np.pad(scale_c, ((1, 0), (0, 0)), mode='edge')
    north = np.pad(scale_c, ((0, 1), (0, 0)), mode='edge')

    u_scaled = u * (0.5*(west + east))      # (ny, nx+1)
    v_scaled = v * (0.5*(south + north))    # (ny+1, nx)
    return u_scaled, v_scaled, np.maximum(h, hmin), int(floored.sum())


# ---------- RHS (AL-style per layer) ----------
# diffusivities (tune as needed)
nu2_u, nu2_v, nu2_h = 1.0e4, 1.0e4, 5.0e3
nu4_u, nu4_v, nu4_h = 5.0e10, 5.0e10, 2.5e10

def rhs_2l(u1, v1, h1, u2, v2, h2):
    """
    Tendencies of the two-layer Klein-β SWE (vector-invariant / AL form per layer).

    Layer 1 is the upper layer (h1 = H1 + eta - etai) and layer 2 the lower one
    (h2 = H2 + etai), so etai is the interface displacement. The layer pressure
    heads are
        Phi1 = g*eta
        Phi2 = g*eta + g'*etai
    which is the pairing consistent with PE = g*eta^2/2 + g'*etai^2/2 used in
    total_energy_2l. The former choice, Phi1 = g*eta + g'*etai and
    Phi2 = g*eta - g'*etai, gave the baroclinic mode an anti-restoring
    interfacial force (exponentially growing short waves) and did not conserve
    that energy.

    Momentum: -grad(Phi + K) + (zeta + f) x-terms + Laplacian/biharmonic
    diffusion + linear interfacial drag. Continuity: flux form + diffusion.

    NOTE: the input arrays are passed through apply_bc_2l, which overwrites
    their ghost rows in place.
    Returns (du1, dv1, dh1dt, du2, dv2, dh2dt) with BCs applied.
    """
    u1,v1,h1,u2,v2,h2 = apply_bc_2l(u1,v1,h1,u2,v2,h2)

    # --- centers & diagnostics per layer
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)
    K1 = 0.5*(uc1**2 + vc1**2)
    K2 = 0.5*(uc2**2 + vc2**2)

    # --- pressure heads from eta, etai
    # upper layer feels only the free surface; lower layer also feels the
    # interface, weighted by the reduced gravity g'
    eta, etai = reconstruct_eta_etai(h1, h2)
    Phi1 = g*eta
    Phi2 = g*eta + gprime*etai

    # --- pressure & KE gradients on faces (per layer)
    dPhidx_u1 = ddx_c_to_u(Phi1); dPhidy_v1 = ddy_c_to_v(Phi1)
    dPhidx_u2 = ddx_c_to_u(Phi2); dPhidy_v2 = ddy_c_to_v(Phi2)
    dKdx_u1   = ddx_c_to_u(K1);   dKdy_v1   = ddy_c_to_v(K1)
    dKdx_u2   = ddx_c_to_u(K2);   dKdy_v2   = ddy_c_to_v(K2)

    # --- absolute vorticity to faces (per layer)
    z_corners_1 = compute_corner_vort(u1, v1)
    z_corners_2 = compute_corner_vort(u2, v2)
    eta_u1 = to_u_from_corners(z_corners_1) + f_u
    eta_v1 = to_v_from_corners(z_corners_1) + f_v
    eta_u2 = to_u_from_corners(z_corners_2) + f_u
    eta_v2 = to_v_from_corners(z_corners_2) + f_v

    # --- transverse velocities (per layer)
    v_tu1 = avg_x(center_from_v(v1))
    u_tv1 = avg_y(center_from_u(u1))
    v_tu2 = avg_x(center_from_v(v2))
    u_tv2 = avg_y(center_from_u(u2))

    # --- momentum tendencies (AL form, inviscid + diffusion)
    du1 = -(dPhidx_u1 + dKdx_u1) + eta_u1 * v_tu1 + nu2_u*lap_u(u1) + nu4_u*bih_u(u1)
    dv1 = -(dPhidy_v1 + dKdy_v1) - eta_v1 * u_tv1 + nu2_v*lap_v(v1) + nu4_v*bih_v(v1)

    du2 = -(dPhidx_u2 + dKdx_u2) + eta_u2 * v_tu2 + nu2_u*lap_u(u2) + nu4_u*bih_u(u2)
    dv2 = -(dPhidy_v2 + dKdy_v2) - eta_v2 * u_tv2 + nu2_v*lap_v(v2) + nu4_v*bih_v(v2)

    # --- gentle interfacial drag to control baroclinic shear
    r_int = 3.0e-6  # s^-1 (1/r_int ~ 3.9 days); start tiny (1e-6 to 3e-6)

    du1 -= r_int * (u1 - u2)
    dv1 -= r_int * (v1 - v2)
    du2 -= r_int * (u2 - u1)
    dv2 -= r_int * (v2 - v1)

    # --- continuity (flux form) + diffusion (per layer)
    h1_u = avg_x(h1);  h1_v = avg_y(h1)
    h2_u = avg_x(h2);  h2_v = avg_y(h2)

    F_u1 = h1_u * u1;  F_v1 = h1_v * v1
    F_u2 = h2_u * u2;  F_v2 = h2_v * v2

    dh1dt = -(ddx_u_to_c(F_u1) + ddy_v_to_c(F_v1)) + nu2_h*lap_c(h1) + nu4_h*bih_c(h1)
    dh2dt = -(ddx_u_to_c(F_u2) + ddy_v_to_c(F_v2)) + nu2_h*lap_c(h2) + nu4_h*bih_c(h2)

    return apply_bc_2l(du1, dv1, dh1dt, du2, dv2, dh2dt)

# ---------- RK4 ----------
def rk4_2l(u1,v1,h1, u2,v2,h2, dt):
    """
    Advance the two-layer state by one classical RK4 step of size ``dt``.

    Every provisional stage state and the final state pass through apply_bc_2l.
    Returns (u1, v1, h1, u2, v2, h2) at t + dt.
    """
    start = (u1, v1, h1, u2, v2, h2)

    def _stage(slopes, coeff):
        # provisional state start + coeff*slopes with the twist BCs re-imposed
        return apply_bc_2l(*[y + coeff*k for y, k in zip(start, slopes)])

    k1 = rhs_2l(*start)
    k2 = rhs_2l(*_stage(k1, 0.5*dt))
    k3 = rhs_2l(*_stage(k2, 0.5*dt))
    k4 = rhs_2l(*_stage(k3, dt))

    weight = dt/6.0
    advanced = [y + weight*(a + 2*b + 2*c + d)
                for y, a, b, c, d in zip(start, k1, k2, k3, k4)]
    return apply_bc_2l(*advanced)

# ---------- Diagnostics / save ----------
def total_mass_layer(h):
    """Domain-integrated layer volume sum(h)*dx*dy, as a Python float."""
    summed = np.sum(h)
    return float(summed * dx * dy)
def total_ke_layer(u,v,h):
    """Domain-integrated KE of one layer, sum(0.5*h*|u_c|^2 * dx*dy), from center-averaged velocity."""
    speed2 = center_from_u(u)**2 + center_from_v(v)**2
    cell_ke = 0.5*h*speed2
    return float(np.sum(cell_ke * dx * dy))
def total_energy_2l(u1,v1,h1, u2,v2,h2):
    """
    Total (KE + PE) energy of the two-layer state, as a Python float.

    KE = 0.5*h1*|u1_c|^2 + 0.5*h2*|u2_c|^2
    PE = 0.5*g*eta^2 + 0.5*g'*etai^2
    integrated cellwise with area dx*dy.
    """
    speed2_1 = center_from_u(u1)**2 + center_from_v(v1)**2
    speed2_2 = center_from_u(u2)**2 + center_from_v(v2)**2
    eta, etai = reconstruct_eta_etai(h1,h2)
    kinetic = 0.5*h1*speed2_1 + 0.5*h2*speed2_2
    potential = 0.5*g*(eta**2) + 0.5*gprime*(etai**2)
    return float(np.sum((kinetic + potential) * dx * dy))

def save_centered_2L_old(ic_key, step, u1,v1,h1, u2,v2,h2, t):
    """
    Legacy snapshot writer; run_ic calls save_centered_2L instead.

    Writes <ROOT_OUT>/<ic_key>/klein_step_XXXXXX.npz with the centered fields
    (float32) plus scalar metadata, never the BT/BC fields. Returns the file path.
    """
    ic_dir = os.path.join(ROOT_OUT, ic_key)
    os.makedirs(ic_dir, exist_ok=True)
    eta, etai = reconstruct_eta_etai(h1,h2)
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)
    path = os.path.join(ic_dir, f"klein_step_{step:06d}.npz")

    f32 = np.float32
    arrays = dict(
        eta=eta, etai=etai,                          # free surface, interface
        uc1=uc1, vc1=vc1, uc2=uc2, vc2=vc2,          # centered velocities
        h1=h1, h2=h2, f=f_c, y_m=y_c,                # helpers for downstream use
    )
    meta = dict(
        H1=f32(H1), H2=f32(H2), gprime=f32(gprime),
        dt=f32(dt), t=f32(t),
        nx=np.int32(nx), ny=np.int32(ny), dx=f32(dx), dy=f32(dy), fp=f32(fp),
    )
    np.savez_compressed(path, **{name: arr.astype(f32) for name, arr in arrays.items()}, **meta)
    return path
def save_centered_2L(ic_key, step, u1,v1,h1, u2,v2,h2, t):
    """
    Write one centered snapshot to <ROOT_OUT>/<ic_key>/klein_step_XXXXXX.npz.

    Stores eta, etai, uc1, vc1, uc2, vc2, h1, h2, f, y_m as float32 plus scalar
    metadata. When SAVE_BTBC_FIELDS is set, it also stores the mass-weighted
    barotropic velocity (uc_bt, vc_bt) and the layer shear (uc_bc, vc_bc).
    Returns the file path.
    """
    ic_dir = os.path.join(ROOT_OUT, ic_key)
    os.makedirs(ic_dir, exist_ok=True)
    eta, etai = reconstruct_eta_etai(h1,h2)
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)

    f32 = np.float32
    pack = {name: arr.astype(f32) for name, arr in (
        ("eta", eta), ("etai", etai),
        ("uc1", uc1), ("vc1", vc1), ("uc2", uc2), ("vc2", vc2),
        ("h1", h1), ("h2", h2), ("f", f_c), ("y_m", y_c),
    )}
    pack.update(
        H1=f32(H1), H2=f32(H2), gprime=f32(gprime),
        dt=f32(dt), t=f32(t),
        nx=np.int32(nx), ny=np.int32(ny), dx=f32(dx), dy=f32(dy), fp=f32(fp),
    )

    if SAVE_BTBC_FIELDS:
        bt_u, bt_v, shear_u, shear_v, _, _ = bt_bc_fields(u1,v1,h1, u2,v2,h2)
        pack.update(
            uc_bt=bt_u.astype(f32), vc_bt=bt_v.astype(f32),
            uc_bc=shear_u.astype(f32), vc_bc=shear_v.astype(f32),
        )

    path = os.path.join(ic_dir, f"klein_step_{step:06d}.npz")
    np.savez_compressed(path, **pack)
    return path


def quick_plot_2L(ic_key, step, u1,v1,h1, u2,v2,h2):
    """Save a 2x3 panel map of eta/u1/v1 (top) and etai/u2/v2 (bottom) to <ROOT_OUT>/<ic_key>/plots/."""
    eta, etai = reconstruct_eta_etai(h1,h2)
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)
    pdir = os.path.join(ROOT_OUT, ic_key, "plots")
    os.makedirs(pdir, exist_ok=True)
    xkm = Xc/1e3
    ykm = (Yc - 0.5*Ly)/1e3

    # (field, colorbar label, panel title or None)
    panels = (
        (eta,  "η (m)",      "Free surface η"),
        (uc1,  "u1_c (m/s)", None),
        (vc1,  "v1_c (m/s)", None),
        (etai, "η_i (m)",    "Interface η_i"),
        (uc2,  "u2_c (m/s)", None),
        (vc2,  "v2_c (m/s)", None),
    )
    plt.figure(figsize=(14,7))
    plt.suptitle(f"{ic_key} step={step}  t={step*dt/3600:.2f} h")
    for slot, (field, cbar_label, title) in enumerate(panels, start=1):
        ax = plt.subplot(2, 3, slot)
        im = ax.pcolormesh(xkm, ykm, field, shading="auto")
        plt.colorbar(im, ax=ax, label=cbar_label)
        if title is not None:
            ax.set_title(title)
    plt.tight_layout()
    plt.savefig(os.path.join(pdir, f"fields2L_step_{step:06d}.png"), dpi=120)
    plt.close()

def quick_plot_BTBC(ic_key, step, u1,v1,h1, u2,v2,h2):
    """Save a 2x2 map of barotropic velocity and layer shear to <ROOT_OUT>/<ic_key>/plots/btbc_step_XXXXXX.png."""
    uc_bt, vc_bt, uc_sh, vc_sh, _, _ = bt_bc_fields(u1,v1,h1, u2,v2,h2)
    pdir = os.path.join(ROOT_OUT, ic_key, "plots")
    os.makedirs(pdir, exist_ok=True)
    xkm = Xc/1e3
    ykm = (Yc - 0.5*Ly)/1e3

    panels = (
        (uc_bt, "u_BT (m/s)",  "Barotropic u_BT"),
        (vc_bt, "v_BT (m/s)",  "Barotropic v_BT"),
        (uc_sh, "u1-u2 (m/s)", "Baroclinic shear u1-u2"),
        (vc_sh, "v1-v2 (m/s)", "Baroclinic shear v1-v2"),
    )
    plt.figure(figsize=(12,6))
    plt.suptitle(f"{ic_key} BT/BC step={step}")
    for slot, (field, cbar_label, title) in enumerate(panels, start=1):
        ax = plt.subplot(2, 2, slot)
        im = ax.pcolormesh(xkm, ykm, field, shading="auto")
        plt.colorbar(im, ax=ax, label=cbar_label)
        ax.set_title(title)
    plt.tight_layout()
    plt.savefig(os.path.join(pdir, f"btbc_step_{step:06d}.png"), dpi=120)
    plt.close()


def plot_mass_energy_old(ic_key, steps, m1_series, m2_series, Etot_series):
    """
    Legacy mass/energy plot (no KE split).

    Writes mass_energy_timeseries.png (relative drift of M1+M2 and E) and a CSV
    with time_days, mass1, mass2, total_energy to <ROOT_OUT>/<ic_key>/plots/.
    """
    pdir = os.path.join(ROOT_OUT, ic_key, "plots")
    os.makedirs(pdir, exist_ok=True)
    tdays = np.asarray(steps) * dt / 86400.0
    mass1, mass2, energy = np.array(m1_series), np.array(m2_series), np.array(Etot_series)
    M0 = m1_series[0] + m2_series[0]
    E0 = Etot_series[0]

    plt.figure(figsize=(9,4))
    ax = plt.gca()
    ax.plot(tdays, (mass1 + mass2 - M0)/M0, label="Δ(M1+M2)/M0")
    ax.plot(tdays, (energy - E0)/E0, label="ΔE/E0")
    ax.set_xlabel("time (days)")
    ax.set_ylabel("relative change")
    ax.grid(True, alpha=0.3)
    ax.legend()
    out = os.path.join(pdir, "mass_energy_timeseries.png")
    plt.savefig(out, dpi=120)
    plt.close()

    csvp = os.path.join(pdir, "mass_energy_timeseries.csv")
    table = np.stack([tdays, mass1, mass2, energy], axis=1)
    np.savetxt(csvp, table, delimiter=",", header="time_days,mass1,mass2,total_energy", comments="")
    print(f"[diag] saved {out} and {csvp}")

def plot_mass_energy(ic_key, steps, m1_series, m2_series, Etot_series,
                     KE_series, KEbt_series, KEbc_series):
    """
    Plot and save the mass/energy history of one IC.

    Left axis : relative drift of total mass (M1+M2) and total energy E.
    Right axis: absolute KE total, KE_BT and KE_BC (twinx).
    Writes mass_energy_timeseries.png and a self-contained CSV with every
    series to <ROOT_OUT>/<ic_key>/plots/. Each series must have one entry per
    element of ``steps``.
    """
    pdir = os.path.join(ROOT_OUT, ic_key, "plots"); os.makedirs(pdir, exist_ok=True)
    steps = np.asarray(steps)
    tdays = steps * dt / 86400.0

    M0 = m1_series[0] + m2_series[0]
    E0 = Etot_series[0]

    fig, ax = plt.subplots(figsize=(9,4))
    # relative changes on left
    ax.plot(tdays, (np.array(m1_series)+np.array(m2_series) - M0)/M0, label="Δ(M1+M2)/M0")
    ax.plot(tdays, (np.array(Etot_series)-E0)/E0,                 label="ΔE/E0")
    ax.set_xlabel("time (days)"); ax.set_ylabel("relative change")
    ax.grid(True, alpha=0.3)

    # absolute KE on right
    ax2 = ax.twinx()
    ax2.plot(tdays, KE_series,      ls="--", label="KE total (J)")
    ax2.plot(tdays, KEbt_series,    ls=":",  label="KE_BT (J)")
    ax2.plot(tdays, KEbc_series,    ls="-.", label="KE_BC (J)")
    ax2.set_ylabel("energy (J)")

    # one legend for both axes
    lines = ax.get_lines() + ax2.get_lines()
    ax.legend(lines, [l.get_label() for l in lines], loc="best")

    out = os.path.join(pdir, "mass_energy_timeseries.png")
    plt.tight_layout(); plt.savefig(out, dpi=120); plt.close()

    # CSV (self-contained)
    csvp = os.path.join(pdir, "mass_energy_timeseries.csv")
    arr = np.stack([tdays,
                    np.array(m1_series), np.array(m2_series),
                    np.array(Etot_series),
                    np.array(KE_series), np.array(KEbt_series), np.array(KEbc_series)], axis=1)
    np.savetxt(csvp, arr, delimiter=",",
               header="time_days,mass1,mass2,total_energy,KE_total,KE_BT,KE_BC",
               comments="")
    print(f"[diag] saved {out} and {csvp}")


def total_ke_2l(u1,v1,h1, u2,v2,h2):
    """Domain-integrated KE of both layers, sum((0.5*h1|u1_c|^2 + 0.5*h2|u2_c|^2) * dx*dy)."""
    speed2_1 = center_from_u(u1)**2 + center_from_v(v1)**2
    speed2_2 = center_from_u(u2)**2 + center_from_v(v2)**2
    cell_ke = 0.5*h1*speed2_1 + 0.5*h2*speed2_2
    return float(np.sum(cell_ke * dx * dy))

def bt_bc_fields(u1,v1,h1, u2,v2,h2):
    """
    Barotropic / baroclinic decomposition on cell centers.

    Returns (uc_bt, vc_bt, uc_sh, vc_sh, KE_BT_tot, KE_BC_tot) where
        uc_bt = (h1*uc1 + h2*uc2)/(h1+h2)      (mass-weighted mean velocity)
        uc_sh = uc1 - uc2                       (layer shear)
        KE_BT = 0.5*(h1+h2)*|u_bt|^2
        KE_BC = 0.5*(h1*h2/(h1+h2))*|u_sh|^2
    Cellwise, KE_BT + KE_BC equals the two-layer KE exactly. The totals are
    integrated with area dx*dy. h1+h2 is clamped at 1e-12 in denominators.
    """
    uc1, vc1 = center_from_u(u1), center_from_v(v1)
    uc2, vc2 = center_from_u(u2), center_from_v(v2)

    depth = h1 + h2
    depth_safe = np.maximum(depth, 1e-12)

    bt_u = (h1*uc1 + h2*uc2) / depth_safe
    bt_v = (h1*vc1 + h2*vc2) / depth_safe
    shear_u = uc1 - uc2
    shear_v = vc1 - vc2

    ke_bt_cell = 0.5*depth*(bt_u**2 + bt_v**2)
    ke_bc_cell = 0.5*(h1*h2/depth_safe)*(shear_u**2 + shear_v**2)

    return (bt_u, bt_v, shear_u, shear_v,
            float(np.sum(ke_bt_cell * dx * dy)),
            float(np.sum(ke_bc_cell * dx * dy)))


# ---------- Load 1-layer IC from bundle and split to 2-layer ----------
def load_1L_and_split_to_2L(ic_key, bundle_path, add_internal=False, ai_amp=0.0):
    """
    Load a 1-layer IC from the bundle and split it into two layers.

    Reads "<ic_key>_h" (ny,nx), "<ic_key>_u" (ny,nx+1) and "<ic_key>_v" (ny+1,nx)
    as float32. The free surface is eta = h - Htot. By default the split is
    barotropic (etai = 0), and both layers get copies of the 1-layer velocity.

    add_internal=True with ai_amp != 0 adds a Gaussian interface bump of
    amplitude ai_amp (m) centered on the equator (y = Ly/2). Its width is a crude
    equatorial internal deformation radius sqrt(c_int/beta), with
    c_int = sqrt(g'*H1*H2/(H1+H2)) and beta = fp*pi/Ly. Velocities are left
    unchanged (no baroclinic shear is added).

    Returns (h1, u1, v1, h2, u2, v2) with h1 = H1 + eta - etai, h2 = H2 + etai.
    """
    # context manager closes the NpzFile handle (np.load keeps it open otherwise)
    with np.load(bundle_path) as d:
        h_tot = d[f"{ic_key}_h"].astype(np.float32)     # (ny,nx)
        u_bar = d[f"{ic_key}_u"].astype(np.float32)     # (ny,nx+1)
        v_bar = d[f"{ic_key}_v"].astype(np.float32)     # (ny+1,nx)

    # 1-layer free surface relative to the total rest depth
    eta = h_tot - Htot
    etai = np.zeros_like(eta, dtype=np.float32)
    if add_internal and ai_amp != 0.0:
        y0 = 0.5*Ly
        c_int = np.sqrt(gprime*H1*H2/(H1+H2))
        beta_eq = fp*np.pi/Ly
        Ld_int = np.sqrt(c_int / beta_eq)   # crude internal equatorial deformation radius
        etai = ai_amp * np.exp(-0.5*((Yc - y0)/Ld_int)**2).astype(np.float32)

    # both layers start with the 1-layer velocity
    u1, v1 = u_bar.copy(), v_bar.copy()
    u2, v2 = u_bar.copy(), v_bar.copy()

    h1 = (H1 + eta - etai).astype(np.float32)
    h2 = (H2 + etai).astype(np.float32)
    return h1, u1, v1, h2, u2, v2

def list_ic_keys(bundle_path):
    """
    Return the sorted IC names in a bundle, i.e. every key "<name>_h" stripped of "_h".

    The NpzFile is opened in a context manager so its file handle is released.
    """
    with np.load(bundle_path) as d:
        return sorted({k[:-2] for k in d.files if k.endswith("_h")})

# ---------- Run one IC ----------
def run_ic(ic_key):
    """
    Run the two-layer model for one IC of IC_BUNDLE and write all outputs.

    Loads the IC and splits it barotropically, then integrates nt RK4 steps of
    size dt. After every step it applies the KE-preserving floor (HMIN) and
    records mass, energy and KE (total / BT / BC). Snapshots and plots are saved
    at SAVE_STEPS, and a progress line is printed every 100 steps (and at
    step 1). The diagnostic time series are plotted at the end.

    Thickness-floor hits reported by enforce_floor_ke_preserving are
    accumulated and printed under the progress line whenever any occurred since
    the previous report. They used to be discarded silently.
    """
    print(f"\n=== 2L IC: {ic_key} | nx={nx}, ny={ny}, dt={dt:.1f}s, nt={nt} ===")
    h1,u1,v1,h2,u2,v2 = load_1L_and_split_to_2L(ic_key, IC_BUNDLE, add_internal=False, ai_amp=0.0)
    u1,v1,h1,u2,v2,h2 = apply_bc_2l(u1,v1,h1,u2,v2,h2)
    # one-off initial safeguard (5 m), stricter than the per-step HMIN floor
    h1[:] = np.maximum(h1, 5.0); h2[:] = np.maximum(h2, 5.0)

    # diagnostics storage (one entry per recorded step, starting at step 0)
    diag_steps = []
    m1_series, m2_series, E_series = [], [], []
    KE_series, KEbt_series, KEbc_series = [], [], []

    M1_0 = total_mass_layer(h1)
    M2_0 = total_mass_layer(h2)
    E0   = total_energy_2l(u1,v1,h1, u2,v2,h2)
    KE0  = total_ke_2l(u1,v1,h1, u2,v2,h2)
    _, _, _, _, KEbt0, KEbc0 = bt_bc_fields(u1,v1,h1, u2,v2,h2)
    KE_series.append(KE0); KEbt_series.append(KEbt0); KEbc_series.append(KEbc0)

    t = 0.0

    # record & save step 0
    diag_steps.append(0); m1_series.append(M1_0); m2_series.append(M2_0); E_series.append(E0)
    if 0 in SAVE_STEPS:
        p = save_centered_2L(ic_key, 0, u1,v1,h1, u2,v2,h2, t)
        quick_plot_2L(ic_key, 0, u1,v1,h1, u2,v2,h2)
        print(f"[save] {ic_key} step=0 -> {p}")

    hits1_pending = hits2_pending = 0   # floor hits since the last progress line
    for n in range(1, nt+1):
        u1,v1,h1, u2,v2,h2 = rk4_2l(u1,v1,h1, u2,v2,h2, dt)
        t += dt

        # positivity floors (KE-preserving)
        u1, v1, h1, hit1 = enforce_floor_ke_preserving(u1, v1, h1, HMIN)
        u2, v2, h2, hit2 = enforce_floor_ke_preserving(u2, v2, h2, HMIN)
        hits1_pending += hit1
        hits2_pending += hit2

        # record diagnostics
        diag_steps.append(n)
        m1_series.append(total_mass_layer(h1))
        m2_series.append(total_mass_layer(h2))
        E_series.append(total_energy_2l(u1,v1,h1, u2,v2,h2))
        KEt = total_ke_2l(u1,v1,h1, u2,v2,h2)
        _, _, _, _, KEbt, KEbc = bt_bc_fields(u1,v1,h1, u2,v2,h2)
        KE_series.append(KEt); KEbt_series.append(KEbt); KEbc_series.append(KEbc)

        if n in SAVE_STEPS:
            p = save_centered_2L(ic_key, n, u1,v1,h1, u2,v2,h2, t)
            quick_plot_2L(ic_key, n, u1,v1,h1, u2,v2,h2)
            print(f"[save] {ic_key} step={n} -> {p}")
            if PLOT_BTBC_MAPS:
                quick_plot_BTBC(ic_key, n, u1,v1,h1, u2,v2,h2)

        if (n % 100) == 0 or n == 1:
            uc1,vc1 = center_from_u(u1), center_from_v(v1)
            uc2,vc2 = center_from_u(u2), center_from_v(v2)
            umax1 = float(np.max(np.sqrt(uc1*uc1+vc1*vc1)))
            umax2 = float(np.max(np.sqrt(uc2*uc2+vc2*vc2)))
            dM = (m1_series[-1]+m2_series[-1] - (M1_0+M2_0))/(M1_0+M2_0)
            dE = (E_series[-1]-E0)/E0
            frac_bt = KEbt / max(KEt, 1e-30)
            print(f"[{n:5d}] d(M1+M2)/M0={dM:+.3e}  dE/E0={dE:+.3e}  umax1={umax1:6.2f}  umax2={umax2:6.2f}  KE_BT%={100*frac_bt:5.1f}")
            if hits1_pending or hits2_pending:
                print(f"        floor hits (h<HMIN) since last report: layer1={hits1_pending}  layer2={hits2_pending}")
            hits1_pending = hits2_pending = 0

    # plot & save mass/energy history
    plot_mass_energy(ic_key, diag_steps, m1_series, m2_series, E_series,
                     KE_series, KEbt_series, KEbc_series)

    print(f"Done (2L): {ic_key}")

# ---------- Main ----------
if __name__ == "__main__":
    # Create the output root once, then integrate every IC found in the bundle.
    os.makedirs(ROOT_OUT, exist_ok=True)
    available = list_ic_keys(IC_BUNDLE)
    print("ICs in bundle:", available)
    for ic_name in available:
        run_ic(ic_name)


ICs in bundle: ['kelvin_m3', 'mixed_RH2_modon', 'mixed_RH3_modon', 'mixed_RH4_2mod', 'modon', 'rh4', 'rh4_plus_kelvin_m3', 'rh5', 'twin_modons', 'twin_modons_plus_kelvin_pkt']

=== 2L IC: kelvin_m3 | nx=256, ny=128, dt=30.0s, nt=1200 ===
[save] kelvin_m3 step=0 -> /content/drive/MyDrive/klein_ckpt_2L_centers/kelvin_m3/klein_step_000000.npz
[    1] d(M1+M2)/M0=-2.913e-10  dE/E0=+7.439e-07  umax1=  2.97  umax2=  2.97  KE_BT%=100.0
[  100] d(M1+M2)/M0=+2.063e-08  dE/E0=+1.109e-02  umax1=  2.94  umax2=  2.94  KE_BT%=100.0
[save] kelvin_m3 step=200 -> /content/drive/MyDrive/klein_ckpt_2L_centers/kelvin_m3/klein_step_000200.npz
[  200] d(M1+M2)/M0=+8.514e-08  dE/E0=+4.201e-02  umax1=  2.86  umax2=  2.87  KE_BT%=100.0
[  300] d(M1+M2)/M0=+1.795e-07  dE/E0=+8.609e-02  umax1=  2.75  umax2=  2.76  KE_BT%=100.0
[save] kelvin_m3 step=400 -> /content/drive/MyDrive/klein_ckpt_2L_centers/kelvin_m3/klein_step_000400.npz
[  400] d(M1+M2)/M0=+2.638e-07  dE/E0=+1.339e-01  umax1=  2.62  umax2=  2.63  KE_B

In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


# GNN emulator for the two-layer shallow-water model

In [None]:
# ==============================================
# train_gnn_klein2L_8nbr.py
# GraphSAGE trainer for 2-layer Klein-β SWE (centered snapshots).
# Features @ t:  [eta, etai, uc1, vc1, uc2, vc2, f, y_norm]  -> in_ch=8
# Targets @ t+H: [eta, etai, uc1, vc1, uc2, vc2]             -> out_ch=6
#
# Expects files saved by your 2-layer FD driver, e.g.:
# npz keys: eta, etai, uc1, vc1, uc2, vc2, h1, h2, f, y_m, (H1,H2,gprime,dt,t,...)
#
# Saves:
#   - weights: ckpt_dir/gnn2L_rollout.pt
#   - norms  : ckpt_dir/norm_stats_2L.pth  (torch.save with {'mean','std','ny','nx'})
#   - training curves: loss_history.csv/.png
# ==============================================

import os, glob, math, csv
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
try:
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader as GeoDataLoader
    from torch_geometric.nn import SAGEConv
except (ImportError, OSError) as err:
    # Missing package or broken binary extension: fail fast with an install
    # hint, but chain the original error so the real cause stays visible.
    raise RuntimeError(
        "This script needs torch_geometric. In Colab run:\n"
        "!pip -q install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cpu.html"
    ) from err

# -------------------
# graph construction
# -------------------
def build_edge_index_8(ny:int, nx:int):
    """
    Directed 8-neighbour edge list for an (ny, nx) grid flattened row-major.

    Neighbours wrap periodically in x. Neighbours beyond the first or last row
    are dropped (no wrap in y). Edges are ordered by source node, then by
    (dj, di) over {-1, 0, 1}^2 minus (0, 0).
    Returns a LongTensor of shape (2, E): row 0 = source, row 1 = destination.
    """
    offsets = [(dj, di) for dj in (-1, 0, 1) for di in (-1, 0, 1) if (dj, di) != (0, 0)]
    sources, targets = [], []
    for node in range(ny*nx):
        j, i = divmod(node, nx)
        for dj, di in offsets:
            jj = j + dj
            if 0 <= jj < ny:
                sources.append(node)
                targets.append(jj*nx + (i + di) % nx)
    return torch.tensor([sources, targets], dtype=torch.long)

# -------------------
# dataset & pairs
# -------------------
def scan_pairs(root:str, horizon:int):
    """
    Collect (src_path, tgt_path) snapshot pairs separated by ``horizon`` steps.

    Each sub-directory of ``root`` is one IC. Inside it, files named
    klein_step_<step>.npz are matched s -> s+horizon only when both exist in
    the same IC directory. Pairs are ordered by IC folder, then by file name.
    """
    ic_dirs = sorted(d for d in glob.glob(os.path.join(root, "*")) if os.path.isdir(d))
    print("[scan] IC folders:", [os.path.basename(d) for d in ic_dirs])
    pairs = []
    total_files = 0
    for icd in ic_dirs:
        files = sorted(glob.glob(os.path.join(icd, "klein_step_*.npz")))
        total_files += len(files)
        steps = [int(os.path.basename(p).split("_")[-1].split(".")[0]) for p in files]
        by_step = dict(zip(steps, files))
        pairs.extend((by_step[s], by_step[s + horizon]) for s in steps if s + horizon in by_step)
    print(f"[scan] found {len(pairs)} pairs across {len(ic_dirs)} ICs ({total_files} files).")
    return pairs

def load_npz_2L_centered(path:str):
    """
    Load one centered 2-layer snapshot written by the 2L FD driver.

    Returns (eta, etai, uc1, vc1, uc2, vc2, f, y2) as float32 arrays. y2 is
    (ny, nx): a 1-D ``y_m`` is broadcast along x, a 2-D one is returned as is.
    The NpzFile is closed on return (astype copies the data out first).
    """
    with np.load(path) as d:
        eta  = d["eta" ].astype(np.float32)   # (ny,nx)
        etai = d["etai"].astype(np.float32)
        uc1  = d["uc1" ].astype(np.float32)
        vc1  = d["vc1" ].astype(np.float32)
        uc2  = d["uc2" ].astype(np.float32)
        vc2  = d["vc2" ].astype(np.float32)
        f    = d["f"   ].astype(np.float32)
        y_m  = d["y_m" ].astype(np.float32)   # (ny,) or (ny,nx)
    if y_m.ndim == 1:
        y2 = np.repeat(y_m[:,None], eta.shape[1], axis=1)
    else:
        y2 = y_m
    return eta, etai, uc1, vc1, uc2, vc2, f, y2

class Klein2LPairs(Dataset):
    """
    Node-wise dataset of (state at t, state at t+H) snapshot pairs.

      x(t)   = [eta,etai,uc1,vc1,uc2,vc2,f,y_norm]  -> standardized per channel
      y(t+H) = [eta,etai,uc1,vc1,uc2,vc2]           -> physical units

    Each item is a torch_geometric ``Data`` with x: (N,8), y: (N,6) and the
    shared 8-neighbour ``edge_index`` (N = ny*nx).

    Standardization statistics come from the source snapshots of the first
    ``max(1, sample_stats)`` pairs. The per-channel std is the pooled std over
    all sampled cells (law of total variance: mean within-file variance plus
    the spread of the per-file means). The previous mean of per-file stds
    ignored the between-file spread and underestimated the std.
    """
    def __init__(self, pairs, standardize:bool=True, sample_stats:int=64):
        super().__init__()
        if not pairs:
            raise ValueError("Klein2LPairs: 'pairs' is empty (no snapshot pairs found).")
        self.pairs = pairs
        # peek shapes from the first source snapshot
        eta = load_npz_2L_centered(pairs[0][0])[0]
        self.ny, self.nx = eta.shape
        self.edge_index = build_edge_index_8(self.ny, self.nx)

        if standardize:
            m_acc, v_acc = [], []
            for (src, _) in pairs[:max(1, sample_stats)]:
                x = self._input_stack(src).astype(np.float64)   # (8,ny,nx)
                m_acc.append(x.mean(axis=(1,2)))
                v_acc.append(x.var(axis=(1,2)))
            m_acc = np.stack(m_acc, 0); v_acc = np.stack(v_acc, 0)   # (n_files, 8)
            mean = m_acc.mean(0)
            # equal per-file weights assume every snapshot shares one grid
            # (edge_index above already requires that)
            var = v_acc.mean(0) + ((m_acc - mean)**2).mean(0)
            std = np.sqrt(var)
            std = np.where(std<1e-8, 1.0, std)
            self.mean = torch.tensor(mean, dtype=torch.float32)     # (8,)
            self.std  = torch.tensor(std,  dtype=torch.float32)
        else:
            self.mean = torch.zeros(8, dtype=torch.float32)
            self.std  = torch.ones (8, dtype=torch.float32)

    @staticmethod
    def _input_stack(path):
        """Load one snapshot and stack the 8 input channels -> (8, ny, nx) float32."""
        e,ei,u1,v1,u2,v2,ff,yy = load_npz_2L_centered(path)
        y_norm = (yy - yy.min())/(yy.max()-yy.min()+1e-9)
        return np.stack([e,ei,u1,v1,u2,v2,ff,y_norm], axis=0)

    def __len__(self): return len(self.pairs)

    def __getitem__(self, idx):
        sp, tp = self.pairs[idx]
        x = self._input_stack(sp)                               # (8,ny,nx)

        e2,ei2,u12,v12,u22,v22,_,_ = load_npz_2L_centered(tp)
        y = np.stack([e2,ei2,u12,v12,u22,v22], axis=0)          # (6,ny,nx)

        # to tensors [N,C]
        x = torch.from_numpy(x).view(8, -1).T   # (N,8)
        y = torch.from_numpy(y).view(6, -1).T   # (N,6)

        # standardize inputs (per-channel)
        x = (x - self.mean) / self.std

        return Data(x=x, y=y, edge_index=self.edge_index)

# -------------
# the model
# -------------
class SAGEBlock(nn.Module):
    """
    Residual GraphSAGE layer: x + Dropout(ReLU(BatchNorm(SAGEConv(x)))).

    The residual add requires in_dim == out_dim.
    """
    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        # attribute names are kept stable so saved state_dicts still load
        self.conv = SAGEConv(in_dim, out_dim, normalize=True)
        self.bn   = nn.BatchNorm1d(out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        update = self.drop(torch.relu(self.bn(self.conv(x, edge_index))))
        return x + update

class ResidGraphSAGE(nn.Module):
    """Per-node linear lift -> ``layers`` residual SAGE blocks -> linear read-out.

    Maps (N, in_ch) node features to (N, out_ch). Attribute names (inp/blocks/out)
    are state_dict keys of saved checkpoints; keep them.
    """
    def __init__(self, in_ch=8, out_ch=6, hidden=256, layers=6, dropout=0.05):
        super().__init__()
        self.inp = nn.Linear(in_ch, hidden)
        self.blocks = nn.ModuleList(
            SAGEBlock(hidden, hidden, dropout) for _ in range(layers)
        )
        self.out = nn.Linear(hidden, out_ch)

    def forward(self, x, edge_index):
        feats = self.inp(x).relu()
        for block in self.blocks:
            feats = block(feats, edge_index)
        return self.out(feats)

# -----------------
# training routine
# -----------------
def split_by_ic(pairs):
    """Split (src, tgt) pairs into train/val lists by IC folder.

    The IC of a pair is the name of the directory containing ``src``. IC names
    are sorted and the first ``max(1, int(0.8 * n_ics))`` go to training, the
    rest to validation. Within each output list, pairs stay grouped by IC in
    order of first appearance, each group keeping its original order.
    """
    groups = {}
    for src, tgt in pairs:
        ic_name = os.path.basename(os.path.dirname(src))
        if ic_name not in groups:
            groups[ic_name] = []
        groups[ic_name].append((src, tgt))

    ordered = sorted(groups)
    n_train = max(1, int(0.8 * len(ordered)))
    train_set = set(ordered[:n_train])

    train_out, val_out = [], []
    for ic_name, members in groups.items():
        bucket = train_out if ic_name in train_set else val_out
        bucket.extend(members)
    return train_out, val_out

def train_one_epoch(model, loader, opt, device):
    """Run one optimization pass over ``loader`` and return the node-weighted mean MSE.

    Each batch's MSE is weighted by its ``num_nodes``; an empty loader yields 0.0.
    Gradients are clipped to a global L2 norm of 1.0 before every optimizer step.
    """
    model.train()
    loss_sum, node_count = 0.0, 0
    for batch in loader:
        batch = batch.to(device)
        opt.zero_grad(set_to_none=True)
        loss = nn.functional.mse_loss(model(batch.x, batch.edge_index), batch.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        n_nodes = batch.num_nodes
        loss_sum += loss.item() * n_nodes
        node_count += n_nodes
    return loss_sum / max(node_count, 1)

@torch.no_grad()
def eval_epoch(model, loader, device):
    """Evaluate ``model`` (in eval mode, no grad) and return the node-weighted mean MSE.

    Each batch's MSE is weighted by its ``num_nodes``; an empty loader yields 0.0.
    """
    model.eval()
    loss_sum, node_count = 0.0, 0
    for batch in loader:
        batch = batch.to(device)
        batch_loss = nn.functional.mse_loss(model(batch.x, batch.edge_index), batch.y)
        n_nodes = batch.num_nodes
        loss_sum += batch_loss.item() * n_nodes
        node_count += n_nodes
    return loss_sum / max(node_count, 1)

def save_loss_plot_csv(history, ckpt_dir):
    """Write the loss history to ``loss_history.csv`` and, if possible, a log-scale PNG.

    Args:
        history: sequence of ``(train_mse, val_mse, lr)`` tuples, one per epoch,
            in the order ``run()`` appends them (its first entry is epoch 1).
        ckpt_dir: output directory; created if missing.

    The CSV ``epoch`` column (and the plot's x-axis) is 1-based so it matches the
    ``[0001]`` epoch numbers printed during training. Plotting is best-effort:
    any matplotlib failure is reported and skipped, and the figure is always closed.
    """
    import csv  # local: the helper should not depend on a cell-level import

    os.makedirs(ckpt_dir, exist_ok=True)
    csv_path = os.path.join(ckpt_dir, "loss_history.csv")
    with open(csv_path, "w", newline="") as f:
        w = csv.writer(f); w.writerow(["epoch","train","val","lr"])
        for e,(tr,va,lr) in enumerate(history, start=1):
            w.writerow([e, tr, va, lr])

    fig = None
    try:
        import matplotlib.pyplot as plt
        epochs = list(range(1, len(history) + 1))
        tr = [h[0] for h in history]; va = [h[1] for h in history]
        fig = plt.figure(figsize=(6,4))
        plt.plot(epochs, tr, label="train"); plt.plot(epochs, va, label="val")
        plt.yscale("log"); plt.xlabel("epoch"); plt.ylabel("MSE (log)")
        plt.grid(True, alpha=0.3); plt.legend()
        png = os.path.join(ckpt_dir, "loss_history.png")
        plt.savefig(png, dpi=140, bbox_inches="tight")
        print(f"[plot] saved {png} and {csv_path}")
    except Exception as e:
        print("[plot] skipping plot:", e)
    finally:
        # fig is only set after pyplot imported successfully, so plt exists here.
        if fig is not None:
            plt.close(fig)

# -------------
# main runner
# -------------
def run(
    root="/content/drive/MyDrive/klein_ckpt_2L_centers",          # <-- 2-LAYER snapshots root
    ckpt_dir="/content/drive/MyDrive/klein_gnn2L_ckpt_8nbr_AL",
    horizon=200,
    hidden=256,
    layers=6,
    dropout=0.05,
    epochs=200,
    batch_size=4,
    lr=2e-3,
    seed=42
):
    """Train the 2-layer GraphSAGE surrogate on (t, t+horizon) snapshot pairs.

    Pairs are split by IC folder (``split_by_ic``). If that leaves no validation
    data, the code falls back to a seeded random 90/10 split over pairs.

    Input normalization statistics come from the training set only. They are
    saved to ``<ckpt_dir>/norm_stats_2L.pth`` (what inference loads) and are
    also applied to the validation set. That way the validation loss measures
    the model on exactly the inputs it will see at inference.

    The best-validation weights are saved to ``<ckpt_dir>/gnn2L_rollout.pt``.
    Training stops early after 30 epochs without a >1e-6 improvement. The loss
    history is written as CSV/PNG at the end.
    """
    torch.manual_seed(seed); np.random.seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("[device]", device)

    pairs = scan_pairs(root, horizon=horizon)
    if not pairs:
        raise RuntimeError("No (src,tgt) pairs found. Are your 2-L snapshots present?")

    train_pairs, val_pairs = split_by_ic(pairs)
    print(f"IC-split pairs: train={len(train_pairs)}  val={len(val_pairs)}")
    if not val_pairs:
        rng = np.random.RandomState(seed)
        idx = rng.permutation(len(pairs))
        k   = max(1, int(0.9*len(pairs)))
        train_pairs = [pairs[i] for i in idx[:k]]
        val_pairs   = [pairs[i] for i in idx[k:] ]
        print(f"[fallback split] train={len(train_pairs)}  val={len(val_pairs)}")

    dtr = Klein2LPairs(train_pairs, standardize=True)
    dvl = Klein2LPairs(val_pairs,   standardize=True)
    # The model is trained on inputs standardized with the *training* stats, and
    # those are the stats saved for inference. Validate with the same stats; the
    # val set's own mean/std would shift its inputs and bias model selection.
    dvl.mean, dvl.std = dtr.mean, dtr.std

    loader_tr = GeoDataLoader(dtr, batch_size=batch_size, shuffle=True)
    loader_va = GeoDataLoader(dvl, batch_size=batch_size, shuffle=False)

    model = ResidGraphSAGE(in_ch=8, out_ch=6, hidden=hidden, layers=layers, dropout=dropout).to(device)
    opt    = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    sched  = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.7, patience=6, min_lr=5e-5)

    os.makedirs(ckpt_dir, exist_ok=True)
    best = float("inf"); bad = 0; history = []
    ckpt_path  = os.path.join(ckpt_dir, "gnn2L_rollout.pt")
    stats_path = os.path.join(ckpt_dir, "norm_stats_2L.pth")
    torch.save({"mean":dtr.mean.numpy(), "std":dtr.std.numpy(), "ny":dtr.ny, "nx":dtr.nx}, stats_path)
    print(f"[norm] saved {stats_path}")

    for ep in range(1, epochs+1):
        tr = train_one_epoch(model, loader_tr, opt, device)
        va = eval_epoch(model, loader_va, device)
        sched.step(va)
        lr_now = opt.param_groups[0]["lr"]
        history.append((tr, va, lr_now))
        if ep % 5 == 0 or ep == 1:
            print(f"[{ep:04d}] train={tr:.6e}  val={va:.6e}  lr={lr_now:.2e}")
        if va < best - 1e-6:
            best = va; bad = 0
            torch.save(model.state_dict(), ckpt_path)
        else:
            bad += 1
        if bad >= 30:
            print("Early stopping."); break

    print("Done. Best val =", best)
    save_loss_plot_csv(history, ckpt_dir)

# ---------------
# Colab-friendly
# ---------------
if __name__=="__main__":
    # Same values as run()'s defaults, spelled out so the notebook cell is easy to edit.
    _train_cfg = dict(
        root="/content/drive/MyDrive/klein_ckpt_2L_centers",
        ckpt_dir="/content/drive/MyDrive/klein_gnn2L_ckpt_8nbr_AL",
        horizon=200,
        hidden=256,
        layers=6,
        dropout=0.05,
        epochs=200,
        batch_size=4,
        lr=2e-3,
        seed=42,
    )
    run(**_train_cfg)


[device] cpu
[scan] IC folders: ['kelvin_m3', 'mixed_RH2_modon', 'mixed_RH3_modon', 'mixed_RH4_2mod', 'modon', 'rh4', 'rh4_plus_kelvin_m3', 'rh5', 'twin_modons', 'twin_modons_plus_kelvin_pkt']
[scan] found 60 pairs across 10 ICs (70 files).
IC-split pairs: train=48  val=12
[norm] saved /content/drive/MyDrive/klein_gnn2L_ckpt_8nbr_AL/norm_stats_2L.pth
[0001] train=2.851215e+03  val=9.207295e+02  lr=2.00e-03
[0005] train=2.288806e+03  val=8.586951e+02  lr=2.00e-03
[0010] train=2.092315e+03  val=9.743330e+02  lr=2.00e-03
[0015] train=1.938643e+03  val=9.621431e+02  lr=1.40e-03
[0020] train=1.803970e+03  val=9.581298e+02  lr=9.80e-04


# Evaluation

In [None]:
# eval_gnn2L.py
# Inference & evaluation for 2-layer Klein-β GNN (GraphSAGE, 8->6 channels).
# - Single pair eval (SRC -> predict -> compare to TGT).
# - Rollout within an IC directory (iterative t -> t+H -> t+2H ...), with optional scoring vs FD.
# Saves predictions/plots/CSV metrics.

import os, glob, csv, math, json
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt

# =========================
# CONFIG (EDIT ME)
# =========================
# Training output dir: must contain norm_stats_2L.pth and gnn2L_rollout.pt.
CKPT_DIR = "/content/drive/MyDrive/klein_gnn2L_ckpt_8nbr_AL"
ROOT_2L  = "/content/drive/MyDrive/klein_ckpt_2L_centers"  # FD outputs (per-IC subfolders)
HORIZON  = 200   # must match training horizon; one model step advances the step index by this much

# --- MODE ---
MODE = "single"   # one of "single", "rollout_ic", "rollout_all"

# SINGLE: one source snapshot -> one prediction (scored only if the target file exists)
SINGLE_SRC = f"{ROOT_2L}/rh4/klein_step_000000.npz"
SINGLE_TGT = f"{ROOT_2L}/rh4/klein_step_{HORIZON:06d}.npz"  # optional; set to None to skip scoring

# ROLLOUT: choose one IC (or MODE="rollout_all" to do all ICs)
ROLLOUT_IC         = "rh4"   # folder name under ROOT_2L
ROLLOUT_START_STEP = 0       # klein_step_<start>.npz must exist in each rolled-out IC folder
ROLLOUT_STEPS      = 6       # number of model steps (each of size HORIZON)

# Output options (shared by all modes)
SAVE_PRED_NPZ      = True    # write pred_gnn2L/pred_step_<step>.npz for every prediction
PLOT_COMPARE       = True    # True/Pred/Error maps for η, η_i at each step that has an FD target
COMPARE_VELS       = False   # also plot the uc1, vc1, uc2, vc2 panels

# =========================
# Graph/model definitions (match training)
# =========================
# torch_geometric is required for SAGEConv. The except clause is deliberately
# broad: a partially installed PyG can fail with errors other than ImportError.
# The original error is chained so the real cause stays visible.
try:
    from torch_geometric.nn import SAGEConv
except Exception as err:
    raise RuntimeError("This script needs torch_geometric. In Colab run:\n"
                       "!pip -q install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cpu.html") from err

def build_edge_index_8(ny:int, nx:int):
    """Build the 8-neighbour (Moore) graph of an ny x nx grid as a (2, E) long tensor.

    Node id is ``j*nx + i`` (row-major). x is periodic (neighbours wrap in i);
    y is bounded (neighbours outside 0..ny-1 are dropped, not wrapped).
    Edges are ordered by source node, then by (dj, di) in lexicographic order
    over {-1, 0, 1}^2 minus (0, 0) -- the same order as a nested j/i/dj/di loop.
    """
    offsets = np.array([(dj, di) for dj in (-1, 0, 1) for di in (-1, 0, 1)
                        if (dj, di) != (0, 0)], dtype=np.int64)            # (8, 2)
    rows, cols = np.meshgrid(np.arange(ny, dtype=np.int64),
                             np.arange(nx, dtype=np.int64), indexing="ij")  # (ny, nx)
    nbr_rows = rows[..., None] + offsets[:, 0]                               # (ny, nx, 8)
    nbr_cols = (cols[..., None] + offsets[:, 1]) % nx                        # periodic x
    inside = (nbr_rows >= 0) & (nbr_rows < ny)                               # clamped y
    node = np.broadcast_to((rows * nx + cols)[..., None], inside.shape)
    edges = np.stack([node[inside], (nbr_rows * nx + nbr_cols)[inside]])
    return torch.from_numpy(edges)

class SAGEBlock(nn.Module):
    """Residual GraphSAGE layer used at inference; must mirror the training definition.

    Computes ``x + Dropout(ReLU(BatchNorm(SAGEConv(x))))``; requires in_dim == out_dim.
    Attribute names conv/bn/drop are checkpoint state_dict keys.
    """
    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.conv = SAGEConv(in_dim, out_dim, normalize=True)
        self.bn   = nn.BatchNorm1d(out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        hidden = self.conv(x, edge_index)
        hidden = self.drop(torch.relu(self.bn(hidden)))
        return x + hidden

class ResidGraphSAGE(nn.Module):
    """Inference copy of the training network: lift -> residual SAGE blocks -> read-out.

    Attribute names inp/blocks/out are checkpoint state_dict keys.
    """
    def __init__(self, in_ch=8, out_ch=6, hidden=256, layers=6, dropout=0.05):
        super().__init__()
        self.inp = nn.Linear(in_ch, hidden)
        self.blocks = nn.ModuleList(
            [SAGEBlock(hidden, hidden, dropout) for _layer in range(layers)]
        )
        self.out = nn.Linear(hidden, out_ch)

    def forward(self, x, edge_index):
        state = self.inp(x).relu()
        for layer in self.blocks:
            state = layer(state, edge_index)
        return self.out(state)

# =========================
# IO utilities
# =========================
def load_npz_2L_centered(path:str):
    """Load one centered 2-layer snapshot and its metadata.

    Returns ``((eta, etai, uc1, vc1, uc2, vc2, f, y2), meta)``.
    - All eight fields are float32. ``y_m`` stored as (ny,) is broadcast along x
      to (ny, nx).
    - ``meta`` holds every other key of the archive except the excluded field
      names below. 0-d arrays are unwrapped to Python scalars.
    - The archive is opened in a ``with`` block so its file handle is released.
    - Each metadata key is read once; npz entries are decompressed on every access.
    """
    excluded = {"eta","etai","uc1","vc1","uc2","vc2","f","y_m",
                "uc_bt","vc_bt","uc_bc","vc_bc","h1","h2"}
    with np.load(path) as d:
        eta  = d["eta" ].astype(np.float32)
        etai = d["etai"].astype(np.float32)
        uc1  = d["uc1" ].astype(np.float32)
        vc1  = d["vc1" ].astype(np.float32)
        uc2  = d["uc2" ].astype(np.float32)
        vc2  = d["vc2" ].astype(np.float32)
        f    = d["f"   ].astype(np.float32)
        y_m  = d["y_m" ].astype(np.float32)  # (ny,) or (ny,nx)
        meta = {}
        for k in d.files:
            if k in excluded:
                continue
            val = d[k]
            meta[k] = val.item() if val.ndim == 0 else val
    if y_m.ndim == 1:
        y2 = np.repeat(y_m[:,None], eta.shape[1], axis=1).astype(np.float32)
    else:
        y2 = y_m
    return (eta,etai,uc1,vc1,uc2,vc2,f,y2), meta

def y_to_ynorm(y2):
    """Min-max normalize ``y2`` over all of its elements (1e-9 guards a constant field)."""
    lo = y2.min()
    return (y2 - lo) / (y2.max() - lo + 1e-9)

def stack_inputs(eta,etai,uc1,vc1,uc2,vc2,f,yn):
    """Stack the 8 model input channels into (8, ny, nx), in the training channel order."""
    channels = (eta, etai, uc1, vc1, uc2, vc2, f, yn)
    return np.stack(channels)

def stack_targets(eta,etai,uc1,vc1,uc2,vc2):
    """Stack the 6 dynamic target channels into (6, ny, nx)."""
    channels = (eta, etai, uc1, vc1, uc2, vc2)
    return np.stack(channels)

# =========================
# Adapter for inference (handles norming & graph mapping)
# =========================
class InferenceAdapter2L:
    """Load the trained 2-layer GNN and run one-step predictions on (ny, nx) fields.

    Reads ``norm_stats_2L.pth`` (input mean/std, ny, nx) and ``gnn2L_rollout.pt``
    (model state_dict) from ``ckpt_dir``. Inputs are standardized with the saved
    stats; outputs are returned as produced by the model (not de-standardized).
    """
    def __init__(self, ckpt_dir, hidden=256, layers=6, dropout=0.05, device=None):
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        stats_path = os.path.join(ckpt_dir, "norm_stats_2L.pth")
        w_path     = os.path.join(ckpt_dir, "gnn2L_rollout.pt")

        stats = self._load_trusted(stats_path, map_location="cpu")
        self.ny, self.nx = int(stats["ny"]), int(stats["nx"])
        # np.asarray accepts the numpy arrays training saves, and tensors too.
        mean = np.asarray(stats["mean"], dtype=np.float32).reshape(-1)
        std  = np.asarray(stats["std"],  dtype=np.float32).reshape(-1)
        self.mean = torch.tensor(mean, dtype=torch.float32, device=self.device).view(1,8)
        self.std  = torch.tensor(std,  dtype=torch.float32, device=self.device).view(1,8)

        self.edge_index = build_edge_index_8(self.ny, self.nx).to(self.device)
        self.model = ResidGraphSAGE(in_ch=8, out_ch=6, hidden=hidden, layers=layers, dropout=dropout).to(self.device)
        sd = torch.load(w_path, map_location=self.device)
        self.model.load_state_dict(sd); self.model.eval()

    @staticmethod
    def _load_trusted(path, map_location):
        """torch.load for files written by our own training run, across torch versions.

        PyTorch >= 2.6 defaults to ``weights_only=True``, which rejects the numpy
        arrays stored in norm_stats_2L.pth. These are locally produced, trusted
        files, so full unpickling is acceptable. Torch builds that do not know the
        ``weights_only`` keyword fall back to the plain call.
        """
        try:
            return torch.load(path, map_location=map_location, weights_only=False)
        except TypeError:
            return torch.load(path, map_location=map_location)

    @torch.no_grad()
    def predict_next(self, eta,etai,uc1,vc1,uc2,vc2, f, y2):
        """Predict (eta, etai, uc1, vc1, uc2, vc2) one model step ahead.

        All arguments are (ny, nx) arrays matching the checkpoint grid. They are
        cast to float32 so float64 inputs do not hit a dtype mismatch in the
        float32 model. Returns six float32 (ny, nx) arrays.
        """
        yn = y_to_ynorm(y2)
        x = stack_inputs(eta,etai,uc1,vc1,uc2,vc2,f,yn).astype(np.float32, copy=False)  # (8,ny,nx)
        x = torch.from_numpy(x).to(self.device).view(8,-1).T      # (N,8)
        x = (x - self.mean) / self.std
        y = self.model(x, self.edge_index).view(self.ny,self.nx,6).permute(2,0,1)  # (6,ny,nx)
        y = y.detach().cpu().numpy().astype(np.float32)
        e,p_i,u1,v1,u2,v2 = y
        return e,p_i,u1,v1,u2,v2

# =========================
# Metrics & plotting
# =========================
def rmse(a,b):
    """Root-mean-square difference between two arrays, as a Python float."""
    diff = a - b
    return float(np.sqrt(np.mean(np.square(diff))))
def corr(a,b):
    """Pearson correlation of two arrays over all elements (0.0 if either is constant)."""
    da = a - a.mean()
    db = b - b.mean()
    denom = np.sqrt(np.mean(da**2)) * np.sqrt(np.mean(db**2)) + 1e-12
    return float(np.mean(da * db) / denom)

def score_fields(tru, pred):
    """Return {rmse_<ch>, corr_<ch>} for the six channels (eta, etai, uc1, vc1, uc2, vc2).

    Keys are inserted as rmse then corr for each channel in that order.
    """
    names = ("eta", "etai", "uc1", "vc1", "uc2", "vc2")
    out = {}
    for name, truth, guess in zip(names, tru, pred):
        out["rmse_" + name] = rmse(truth, guess)
        out["corr_" + name] = corr(truth, guess)
    return out

def quick_plot_compare(ic_dir, step_pred, fields_true, fields_pred, y2, include_vel=False):
    """Save True / Pred / Error panels for one step under ``<ic_dir>/pred_gnn2L/``.

    Args:
        ic_dir: IC folder. PNGs are written as ``pred_gnn2L/<name>_step_<step:06d>.png``.
        step_pred: step number used in the file names.
        fields_true, fields_pred: 6-tuples (eta, etai, uc1, vc1, uc2, vc2) of 2-D arrays.
        y2: unused; kept so existing callers keep working.
        include_vel: also plot the four velocity channels.

    Each figure is closed even if plotting or saving fails.
    """
    eta_t, etai_t, uc1_t, vc1_t, uc2_t, vc2_t = fields_true
    eta_p, etai_p, uc1_p, vc1_p, uc2_p, vc2_p = fields_pred
    outdir = os.path.join(ic_dir, "pred_gnn2L"); os.makedirs(outdir, exist_ok=True)

    def panel(tru, pred, title, fname):
        err = pred - tru
        fig, axs = plt.subplots(1,3, figsize=(12,3.6), constrained_layout=True)
        try:
            for ax, data, label in zip(axs, (tru, pred, err), ("True", "Pred", "Error")):
                im = ax.pcolormesh(data, shading="auto")
                ax.set_title(f"{label} {title}")
                fig.colorbar(im, ax=ax)
            fig.savefig(os.path.join(outdir, f"{fname}_step_{step_pred:06d}.png"), dpi=140)
        finally:
            plt.close(fig)

    panel(eta_t,  eta_p,  "η",   "eta")
    panel(etai_t, etai_p, "η_i", "etai")
    if include_vel:
        panel(uc1_t, uc1_p, "u1_c", "uc1")
        panel(vc1_t, vc1_p, "v1_c", "vc1")
        panel(uc2_t, uc2_p, "u2_c", "uc2")
        panel(vc2_t, vc2_p, "v2_c", "vc2")

# =========================
# Single pair evaluation
# =========================
def evaluate_single(src_path, tgt_path=None):
    """Predict one HORIZON step from ``src_path`` and optionally score it against ``tgt_path``.

    Outputs go to ``<dir of src_path>/pred_gnn2L/``:
    - ``pred_step_<step>.npz`` when SAVE_PRED_NPZ is set;
    - ``single_eval.csv`` (metric, value) and, if PLOT_COMPARE, comparison PNGs
      when ``tgt_path`` exists.

    The source step is parsed from the trailing ``_<digits>.npz`` of the file name.
    """
    (eta,etai,uc1,vc1,uc2,vc2,f,y2), _meta = load_npz_2L_centered(src_path)
    adapter = InferenceAdapter2L(CKPT_DIR, device="cpu")  # force CPU for portability
    e_p, ei_p, u1_p, v1_p, u2_p, v2_p = adapter.predict_next(eta,etai,uc1,vc1,uc2,vc2,f,y2)

    ic_dir = os.path.dirname(src_path)
    out_dir = os.path.join(ic_dir, "pred_gnn2L"); os.makedirs(out_dir, exist_ok=True)
    base_step = int(os.path.basename(src_path).split("_")[-1].split(".")[0])
    step_pred = base_step + HORIZON

    # Save prediction npz (centered) for downstream use
    if SAVE_PRED_NPZ:
        np.savez_compressed(
            os.path.join(out_dir, f"pred_step_{step_pred:06d}.npz"),
            eta=e_p, etai=ei_p, uc1=u1_p, vc1=v1_p, uc2=u2_p, vc2=v2_p, f=f, y_m=y2[:,0].astype(np.float32)
        )

    if tgt_path and os.path.exists(tgt_path):
        (eta_t,etai_t,uc1_t,vc1_t,uc2_t,vc2_t,_,_), _ = load_npz_2L_centered(tgt_path)
        scores = score_fields(
            (eta_t,etai_t,uc1_t,vc1_t,uc2_t,vc2_t),
            (e_p,  ei_p,  u1_p,  v1_p,  u2_p,  v2_p)
        )
        # CSV one-liner; handle named fcsv so the Coriolis field `f` is never shadowed
        csvp = os.path.join(out_dir, "single_eval.csv")
        with open(csvp, "w", newline="") as fcsv:
            w = csv.writer(fcsv)
            w.writerow(["metric","value"])
            for k,v in scores.items(): w.writerow([k,v])
        # Plots
        if PLOT_COMPARE:
            quick_plot_compare(ic_dir, step_pred,
                               (eta_t,etai_t,uc1_t,vc1_t,uc2_t,vc2_t),
                               (e_p,  ei_p,  u1_p,  v1_p,  u2_p,  v2_p),
                               y2, include_vel=COMPARE_VELS)
        print("[single] Scores:", scores)
    elif tgt_path:
        print(f"[single] Target not found: {tgt_path}; skipping scoring.")
    else:
        print("[single] No target provided; saved prediction only.")

# =========================
# Rollout within an IC
# =========================
def collect_ic_files(ic_dir):
    """Map step number -> path for each ``klein_step_*.npz`` in ``ic_dir``, ordered by file name."""
    pattern = os.path.join(ic_dir, "klein_step_*.npz")
    index = {}
    for path in sorted(glob.glob(pattern)):
        step_text = os.path.basename(path).rsplit("_", 1)[-1].partition(".")[0]
        index[int(step_text)] = path
    return index

def rollout_ic(ic_name, start_step, steps_to_run):
    """Iteratively roll the GNN forward from one FD snapshot of an IC.

    Starting from ``klein_step_<start_step>.npz`` in ``ROOT_2L/<ic_name>``, the
    model is applied ``steps_to_run`` times. Each step advances HORIZON, and each
    prediction becomes the next input; f and y are held fixed.

    Outputs go to ``<ic_dir>/pred_gnn2L/``:
    - ``pred_step_<step>.npz`` for every prediction when SAVE_PRED_NPZ is set;
    - ``rollout_from_<start>.csv`` with one row of RMSE/corr metrics for each
      step that has an FD snapshot;
    - comparison plots for those steps when PLOT_COMPARE is set.

    Raises:
        FileNotFoundError: if the start snapshot is missing.
    """
    ic_dir = os.path.join(ROOT_2L, ic_name)
    idx = collect_ic_files(ic_dir)
    if start_step not in idx:
        raise FileNotFoundError(f"Start step {start_step} not found in {ic_dir}")

    # load start
    (eta,etai,uc1,vc1,uc2,vc2,f,y2), _meta = load_npz_2L_centered(idx[start_step])
    adapter = InferenceAdapter2L(CKPT_DIR, device="cpu")
    out_dir = os.path.join(ic_dir, "pred_gnn2L"); os.makedirs(out_dir, exist_ok=True)

    metric_keys = ["rmse_eta","rmse_etai","rmse_uc1","rmse_vc1","rmse_uc2","rmse_vc2",
                   "corr_eta","corr_etai","corr_uc1","corr_vc1","corr_uc2","corr_vc2"]

    # metrics CSV. The handle must NOT be called `f`: that name holds the Coriolis
    # field, which is fed back into predict_next/savez on every step below.
    csvp = os.path.join(out_dir, f"rollout_from_{start_step:06d}.csv")
    with open(csvp, "w", newline="") as fcsv:
        w = csv.writer(fcsv)
        w.writerow(["step"] + metric_keys)

        base = start_step
        for k in range(1, steps_to_run+1):
            # predict next
            e_p, ei_p, u1_p, v1_p, u2_p, v2_p = adapter.predict_next(eta,etai,uc1,vc1,uc2,vc2,f,y2)
            tgt_step = base + k*HORIZON

            # save pred
            if SAVE_PRED_NPZ:
                np.savez_compressed(
                    os.path.join(out_dir, f"pred_step_{tgt_step:06d}.npz"),
                    eta=e_p, etai=ei_p, uc1=u1_p, vc1=v1_p, uc2=u2_p, vc2=v2_p, f=f, y_m=y2[:,0].astype(np.float32)
                )

            # score vs FD if available
            if tgt_step in idx:
                (eta_t,etai_t,uc1_t,vc1_t,uc2_t,vc2_t,_,_), _ = load_npz_2L_centered(idx[tgt_step])
                sc = score_fields(
                    (eta_t,etai_t,uc1_t,vc1_t,uc2_t,vc2_t),
                    (e_p,  ei_p,  u1_p,  v1_p,  u2_p,  v2_p)
                )
                w.writerow([tgt_step] + [sc[key] for key in metric_keys])
                if PLOT_COMPARE:
                    quick_plot_compare(ic_dir, tgt_step,
                                       (eta_t,etai_t,uc1_t,vc1_t,uc2_t,vc2_t),
                                       (e_p,  ei_p,  u1_p,  v1_p,  u2_p,  v2_p),
                                       y2, include_vel=COMPARE_VELS)

            # advance state for next step (iterative rollout)
            eta,etai,uc1,vc1,uc2,vc2 = e_p, ei_p, u1_p, v1_p, u2_p, v2_p

    print(f"[rollout] Done IC={ic_name}; wrote {csvp}")

def rollout_all(start_step, steps_to_run):
    """Run ``rollout_ic`` for every IC folder under ROOT_2L and write a summary CSV.

    The summary (``ROOT_2L/rollout_summary_from_<start>.csv``) gets one row per IC
    with the mean RMSE of each channel over that IC's scored steps. An IC that
    raises anything during rollout or aggregation is reported and skipped.
    """
    channels = ["eta", "etai", "uc1", "vc1", "uc2", "vc2"]
    ic_dirs = sorted(d for d in glob.glob(os.path.join(ROOT_2L, "*")) if os.path.isdir(d))
    summary_csv = os.path.join(ROOT_2L, f"rollout_summary_from_{start_step:06d}.csv")
    with open(summary_csv, "w", newline="") as fsum:
        wsum = csv.writer(fsum)
        wsum.writerow(["IC"] + [f"mean_rmse_{ch}" for ch in channels])
        for ic_path in ic_dirs:
            ic = os.path.basename(ic_path)
            try:
                rollout_ic(ic, start_step, steps_to_run)
                per_ic_csv = os.path.join(ic_path, "pred_gnn2L", f"rollout_from_{start_step:06d}.csv")
                table = np.genfromtxt(per_ic_csv, delimiter=",", names=True)
                wsum.writerow([ic] + [float(np.mean(table[f"rmse_{ch}"])) for ch in channels])
            except Exception as e:
                print(f"[rollout_all] Skipping {ic}: {e}")
    print(f"[rollout_all] Wrote {summary_csv}")

# =========================
# Main
# =========================
if __name__ == "__main__":
    os.makedirs(CKPT_DIR, exist_ok=True)
    # MODE -> action; the lambdas read the config globals at call time.
    _actions = {
        "single":      lambda: evaluate_single(SINGLE_SRC, SINGLE_TGT),
        "rollout_ic":  lambda: rollout_ic(ROLLOUT_IC, ROLLOUT_START_STEP, ROLLOUT_STEPS),
        "rollout_all": lambda: rollout_all(ROLLOUT_START_STEP, ROLLOUT_STEPS),
    }
    if MODE not in _actions:
        raise ValueError("MODE must be one of {'single','rollout_ic','rollout_all'}")
    _actions[MODE]()


# Building an ensemble with the singular-vector (SV) generator for the 2-layer GNN

In [None]:
# sv_ensemble_gnn2L.py
# Singular-vector ensemble generator around a chosen IC/step for your 2-layer GNN.
# - Matrix-free power iteration on J^T J using autograd JVP/VJP (no dense Jacobians).
# - Builds M ensemble members from top-r singular vectors; optional symmetric pairs.
# - Optional "breeding" rescale so ||F(x0+δx)-F(x0)|| hits a target growth norm.
#
# Outputs:
#   ROOT_2L/<IC>/ens_sv/<tag>/
#       member_####_init_step_<s>.npz  (eta,etai,uc1,vc1,uc2,vc2,f,y_m)
#       member_####_tplus_step_<s+H>.npz (same keys; model forecast)
#   plus a CSV with the singular values and per-member growth.

import os, csv, glob
import numpy as np
import torch
from torch import nn

# ======== CONFIG (edit as needed) ========
CKPT_DIR = "/content/drive/MyDrive/klein_gnn2L_ckpt_8nbr_AL"   # trained 2L GNN (norm_stats_2L.pth + gnn2L_rollout.pt)
ROOT_2L  = "/content/drive/MyDrive/klein_ckpt_2L_centers"      # FD snapshots (per-IC)
IC_NAME  = "rh4"                                               # IC folder to work on
START_STEP = 0                                                 # base snapshot step; klein_step_<START_STEP>.npz must exist
HORIZON    = 200                                               # must match training
TAG        = "sv_r6_pow10_breed2"                              # outputs go to ROOT_2L/<IC_NAME>/ens_sv/<TAG>/

# SV / ensemble knobs
RANK        = 6        # how many top singular vectors to compute (also the number of weights per member)
POW_ITERS   = 10       # power iterations per mode (needs >= 1)
BREED_ITERS = 2        # 0 = off; else rescale δx via F to hit target forecast norm
M_MEMBERS   = 10       # total members; if even, the second half use the negated weights of the first half

# Target amplitudes
# (used to scale δx channel-wise *before* breeding; tweak to match your spread desires)
# Keys follow scale_channels_to_targets: "vel" is shared by all four velocity channels.
TARGET_STD = {
    "eta":  0.5,   # meters (free surface)
    "etai": 0.5,   # meters (interface)
    "vel":  0.20   # m/s   (applies to uc1,vc1,uc2,vc2)
}
# Target growth norm at t+H (used by breeding rescale) = RMS over all 6 channels
TARGET_GROWTH_RMS = 0.30  # in native units: m and m/s mixed via plain RMS across channels

# Device
DEVICE = "cpu"   # keep CPU-safe; switch to "cuda" if you have a GPU

# ========================================

# ---- Minimal model defs (must match training) ----
# torch_geometric provides SAGEConv. The except clause is deliberately broad
# because a broken PyG install can fail with errors other than ImportError.
# The original error is chained so its cause is not lost.
try:
    from torch_geometric.nn import SAGEConv
except Exception as err:
    raise RuntimeError("Needs torch_geometric. In Colab:\n"
                       "!pip -q install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.0+cpu.html") from err

def build_edge_index_8(ny:int, nx:int):
    """8-neighbour grid graph (periodic in x, bounded in y) as a (2, E) long tensor.

    Node id ``j*nx + i``. Edges are listed per source node in row-major order,
    and per node in (dj, di) lexicographic order over {-1,0,1}^2 without (0,0).
    This is identical to the nested-loop construction used at training time.
    """
    steps = np.array([(dj, di) for dj in (-1, 0, 1) for di in (-1, 0, 1)
                      if not (dj == 0 and di == 0)], dtype=np.int64)   # (8, 2)
    jj, ii = np.meshgrid(np.arange(ny, dtype=np.int64),
                         np.arange(nx, dtype=np.int64), indexing="ij")
    to_j = jj[..., None] + steps[:, 0]                 # (ny, nx, 8)
    to_i = (ii[..., None] + steps[:, 1]) % nx          # wrap in x
    valid = (to_j >= 0) & (to_j < ny)                  # no wrap in y
    from_id = np.broadcast_to((jj * nx + ii)[..., None], valid.shape)
    pairs = np.stack([from_id[valid], (to_j * nx + to_i)[valid]])
    return torch.from_numpy(pairs)

class SAGEBlock(nn.Module):
    """Square residual SAGE layer: ``x + Dropout(ReLU(BatchNorm(SAGEConv(x))))``.

    Attribute names conv/bn/drop are checkpoint state_dict keys and must match training.
    """
    def __init__(self, dim, dropout):
        super().__init__()
        self.conv = SAGEConv(dim, dim, normalize=True)
        self.bn   = nn.BatchNorm1d(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        delta = self.bn(self.conv(x, edge_index))
        delta = self.drop(torch.relu(delta))
        return x + delta

class ResidGraphSAGE(nn.Module):
    """Same architecture as training: linear lift, residual SAGE stack, linear read-out.

    Attribute names inp/blocks/out are checkpoint state_dict keys.
    """
    def __init__(self, in_ch=8, out_ch=6, hidden=256, layers=6, dropout=0.05):
        super().__init__()
        self.inp = nn.Linear(in_ch, hidden)
        self.blocks = nn.ModuleList(SAGEBlock(hidden, dropout) for _ in range(layers))
        self.out = nn.Linear(hidden, out_ch)

    def forward(self, x, edge_index):
        z = self.inp(x).relu()
        for stage in self.blocks:
            z = stage(z, edge_index)
        return self.out(z)

# ---- IO ----
def load_npz_2L_centered(path:str):
    """Load one centered 2-layer snapshot as eight float32 arrays.

    Returns ``(eta, etai, uc1, vc1, uc2, vc2, f, y2)``. ``y_m`` stored as (ny,)
    is broadcast along x to (ny, nx) so all fields share one shape. The archive
    is opened with ``with`` so its file handle is released before returning.
    """
    with np.load(path) as d:
        eta  = d["eta" ].astype(np.float32)
        etai = d["etai"].astype(np.float32)
        uc1  = d["uc1" ].astype(np.float32)
        vc1  = d["vc1" ].astype(np.float32)
        uc2  = d["uc2" ].astype(np.float32)
        vc2  = d["vc2" ].astype(np.float32)
        f    = d["f"   ].astype(np.float32)
        y_m  = d["y_m" ].astype(np.float32)  # (ny,) or (ny,nx)
    if y_m.ndim == 1:
        y2 = np.repeat(y_m[:,None], eta.shape[1], axis=1).astype(np.float32)
    else:
        y2 = y_m
    return (eta,etai,uc1,vc1,uc2,vc2,f,y2)

def save_npz_centered(path, eta,etai,uc1,vc1,uc2,vc2, f, y2):
    """Write a centered snapshot (float32) in the same key layout the loaders read.

    ``y2`` may be (ny, nx) (its first column is stored as ``y_m``) or already (ny,).
    The parent directory is created when ``path`` has one; a bare file name is
    written to the current directory, since ``os.makedirs("")`` would raise.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    y_col = y2 if np.ndim(y2) == 1 else y2[:,0]
    np.savez_compressed(path,
        eta=eta.astype(np.float32), etai=etai.astype(np.float32),
        uc1=uc1.astype(np.float32), vc1=vc1.astype(np.float32),
        uc2=uc2.astype(np.float32), vc2=vc2.astype(np.float32),
        f=f.astype(np.float32), y_m=np.asarray(y_col).astype(np.float32)
    )

# ---- Adapter with grads enabled ----
class GNN2L_AdapterGrad:
    """Differentiable wrapper around the trained 2-layer GNN for JVP/VJP work.

    Flat state vectors use the *channel-major* layout of ``vec6_from_arrays``:
    six consecutive blocks of ny*nx values (eta, etai, uc1, vc1, uc2, vc2).
    ``forward_flat`` returns its output in that same layout, so
    ``arrays6_from_vec`` can unpack it. The model itself works node-major
    (N, channels); the transposes happen inside ``forward_flat``.
    """
    def __init__(self, ckpt_dir, device="cpu", hidden=256, layers=6, dropout=0.05):
        stats = self._load_trusted(os.path.join(ckpt_dir, "norm_stats_2L.pth"), map_location="cpu")
        self.ny, self.nx = int(stats["ny"]), int(stats["nx"])
        # np.asarray accepts the numpy arrays training saves, and tensors too.
        mean = np.asarray(stats["mean"], dtype=np.float32).reshape(-1)
        std  = np.asarray(stats["std"],  dtype=np.float32).reshape(-1)
        self.mean = torch.tensor(mean, dtype=torch.float32, device=device).view(1,8)
        self.std  = torch.tensor(std,  dtype=torch.float32, device=device).view(1,8)

        self.edge_index = build_edge_index_8(self.ny, self.nx).to(device)
        self.model = ResidGraphSAGE(in_ch=8, out_ch=6, hidden=hidden, layers=layers, dropout=dropout).to(device)
        sd = torch.load(os.path.join(ckpt_dir, "gnn2L_rollout.pt"), map_location=device)
        self.model.load_state_dict(sd)
        self.model.eval()  # BN/Dropout frozen, but gradients OK

        self.device = device

    @staticmethod
    def _load_trusted(path, map_location):
        """torch.load for locally produced checkpoint files, across torch versions.

        PyTorch >= 2.6 defaults to ``weights_only=True``, which rejects the numpy
        arrays stored in norm_stats_2L.pth. Torch builds without the keyword
        fall back to the plain call.
        """
        try:
            return torch.load(path, map_location=map_location, weights_only=False)
        except TypeError:
            return torch.load(path, map_location=map_location)

    def _pack_x(self, eta,etai,uc1,vc1,uc2,vc2, f, y2):
        # Build all 8 input features (N,8) from numpy arrays, without gradients.
        # Kept for callers that want a plain feature matrix; forward_flat no longer uses it.
        yy = (y2 - y2.min())/(y2.max()-y2.min()+1e-9)
        X8 = np.stack([eta,etai,uc1,vc1,uc2,vc2, f, yy], axis=0).astype(np.float32)  # (8,ny,nx)
        X8 = torch.from_numpy(X8).to(self.device).view(8, -1).T  # (N,8)
        X8.requires_grad_(False)
        return X8

    def _const_cols(self, const_feat):
        """(N,2) float32 tensor of the static inputs: f and min-max normalized y."""
        y2 = const_feat["y2"]
        yy = (y2 - y2.min())/(y2.max()-y2.min()+1e-9)
        c2 = np.stack([const_feat["f"], yy], axis=0).astype(np.float32)  # (2,ny,nx)
        return torch.from_numpy(c2).to(self.device).view(2, -1).T       # (N,2)

    def forward_flat(self, x6_flat, const_feat):
        """
        x6_flat: (6*N,) torch float in channel-major order (vec6_from_arrays layout);
                 may require grad.
        const_feat: dict with 'f' (ny,nx) and 'y2' (ny,nx) numpy arrays
        returns y6_flat: (6*N,) torch, same channel-major order (no detach!)
        """
        N = self.ny * self.nx
        # (6*N,) channel-major -> (6, N) -> (N, 6) node rows for the model.
        # Cast matches the float32 feature matrix the model was trained on.
        x6 = x6_flat.reshape(6, N).T.to(torch.float32)
        X8 = torch.cat([x6, self._const_cols(const_feat)], dim=1)  # (N,8); grads flow through x6

        # standardize and forward
        X8n = (X8 - self.mean) / self.std
        y6  = self.model(X8n, self.edge_index)          # (N,6)
        return y6.T.reshape(-1)                         # back to channel-major (6*N,)

# ---- JVP / VJP wrappers ----
def Jv(adapter, x_flat, v_flat, const_feat):
    """Jacobian-vector product J(x) @ v of ``adapter.forward_flat`` at ``x_flat``.

    Uses ``torch.autograd.functional.jvp`` (matrix-free); returns a flat tensor
    shaped like the function output.
    """
    _, tangent = torch.autograd.functional.jvp(
        lambda inp: adapter.forward_flat(inp, const_feat),
        (x_flat,), (v_flat,), create_graph=False,
    )
    return tangent

def JTz(adapter, x_flat, z_flat, const_feat):
    """Vector-Jacobian product J(x)^T @ z of ``adapter.forward_flat`` at ``x_flat``.

    ``torch.autograd.functional.vjp(func, inputs, v)`` takes the cotangent ``v``
    directly and returns ``(func_output, vjp)``. Unlike ``jax.vjp``, it does
    not return a pullback function. Because ``x_flat`` is a single tensor,
    ``vjp`` is a single tensor shaped like ``x_flat``.
    """
    def fun(inp):
        return adapter.forward_flat(inp, const_feat)
    _, vjp = torch.autograd.functional.vjp(fun, x_flat, z_flat, create_graph=False)
    return vjp  # (N*6,)

# ---- Power iteration on J^T J (matrix-free), with Gram-Schmidt deflation ----
def gram_schmidt(v, basis):
    """Remove from ``v`` its components along each vector in ``basis``.

    Sequential (modified) Gram-Schmidt: each projection uses the already-updated
    vector, and the coefficient is ``v @ b``, so basis vectors are assumed unit-norm.
    Returns a new tensor; the caller's ``v`` is not modified in place.
    """
    for b in basis:
        v = v - (v @ b) * b
    return v

def leading_singular_vectors(adapter, x0_flat, const_feat, rank=4, iters=10):
    """Estimate the top-``rank`` right singular vectors of J = dF/dx at ``x0_flat``.

    Matrix-free power iteration on J^T J: one ``Jv`` and one ``JTz`` per
    iteration. Each new mode is deflated by Gram-Schmidt against the modes
    already found.

    Args:
        adapter: object exposing ``forward_flat(x_flat, const_feat)``.
        x0_flat: flat base state (N*6,) tensor.
        const_feat: static features passed through to ``adapter.forward_flat``.
        rank: number of modes to extract.
        iters: power iterations per mode; must be >= 1 when ``rank > 0``.

    Returns:
        (Vs, sigmas): list of unit-norm direction tensors and list of float
        singular-value estimates. Each sigma is ||J v|| for the iterate *before*
        the final update, so it lags the returned vector by one iteration.

    Raises:
        ValueError: if ``rank > 0`` and ``iters < 1``. No singular value could be
            estimated; previously this surfaced as an AttributeError on ``None``.

    Note: reseeds the global torch RNG with 123 so the random starts are reproducible.
    """
    if rank > 0 and iters < 1:
        raise ValueError(f"iters must be >= 1 to estimate singular values (got {iters})")
    N6 = x0_flat.numel()
    Vs = []          # right singular vectors (unit)
    sigmas = []      # singular values
    torch.manual_seed(123)
    for r in range(rank):
        v = torch.randn(N6, device=x0_flat.device, dtype=x0_flat.dtype)
        v = v / (torch.norm(v) + 1e-12)
        # Deflate against previous modes
        if Vs:
            v = gram_schmidt(v, Vs)
            v = v / (torch.norm(v) + 1e-12)
        sigma = None
        for _ in range(iters):
            Jv_vec = Jv(adapter, x0_flat, v, const_feat)
            sigma = torch.norm(Jv_vec) + 1e-12   # singular value estimate for the current v
            w = JTz(adapter, x0_flat, Jv_vec, const_feat)
            # Deflate and renormalize
            if Vs:
                w = gram_schmidt(w, Vs)
            v = w / (torch.norm(w) + 1e-12)
        # finalize one mode
        Vs.append(v.detach() / (torch.norm(v.detach()) + 1e-12))
        sigmas.append(float(sigma.detach().cpu()))
        print(f"[SV] mode {r+1}: sigma ≈ {sigmas[-1]:.4e}")
    return Vs, sigmas

# ---- Utilities: reshape, RMS, channel scaling ----
def vec6_from_arrays(eta,etai,uc1,vc1,uc2,vc2):
    """Flatten six (ny, nx) fields into one channel-major float32 tensor of length 6*ny*nx."""
    stacked = np.stack((eta, etai, uc1, vc1, uc2, vc2)).astype(np.float32)  # (6,ny,nx)
    return torch.from_numpy(stacked.ravel())

def arrays6_from_vec(v, ny, nx):
    """Inverse of vec6_from_arrays: split a channel-major flat tensor into six (ny, nx) numpy arrays."""
    fields = v.detach().cpu().numpy().reshape(6, ny, nx)
    return tuple(fields)

def channelwise_rms(eta,etai,uc1,vc1,uc2,vc2):
    """Return {channel name: RMS over all elements} for the six dynamic fields."""
    names = ("eta", "etai", "uc1", "vc1", "uc2", "vc2")
    return {name: float(np.sqrt(np.mean(arr * arr)))
            for name, arr in zip(names, (eta, etai, uc1, vc1, uc2, vc2))}

def scale_channels_to_targets(eta,etai,uc1,vc1,uc2,vc2, targets):
    """Rescale each channel so its RMS equals its target.

    ``targets`` has keys "eta", "etai" and "vel". "vel" is used for each of
    uc1, vc1, uc2 and vc2, but every velocity channel gets its own scale factor.
    A channel with RMS below 1e-12 is left unscaled. Returns a 6-tuple.
    """
    current = channelwise_rms(eta,etai,uc1,vc1,uc2,vc2)
    names = ("eta", "etai", "uc1", "vc1", "uc2", "vc2")
    fields = (eta, etai, uc1, vc1, uc2, vc2)
    scaled = []
    for name, field in zip(names, fields):
        goal = targets[name] if name in ("eta", "etai") else targets["vel"]
        now = current[name]
        factor = 1.0 if now < 1e-12 else goal / now
        scaled.append(field * factor)
    return tuple(scaled)

def rms6(eta,etai,uc1,vc1,uc2,vc2):
    """Single RMS over all elements of all six fields (units mixed as-is)."""
    stacked = np.stack((eta, etai, uc1, vc1, uc2, vc2))
    return float(np.sqrt(np.mean(np.square(stacked))))

# ---- Breeding-style rescale (optional) ----
def breed_once(adapter, x0_flat, dx_flat, const_feat, target_growth_rms, y0=None):
    """Rescale a perturbation so its one-step forecast difference has a target RMS.

    The model is run on the base state and on the perturbed state. The RMS of
    the forecast difference, ``growth``, is measured, and ``dx_flat`` is
    multiplied by ``target_growth_rms / growth``.

    Args:
        adapter: object exposing ``forward_flat(x_flat, const_feat)``.
        x0_flat: flat base-state tensor.
        dx_flat: flat perturbation tensor, added element-wise to ``x0_flat``.
        const_feat: constant features passed unchanged to ``forward_flat``.
        target_growth_rms: desired RMS of the forecast difference after rescaling.
        y0: optional precomputed ``adapter.forward_flat(x0_flat, const_feat)``.
            Pass it when breeding repeatedly around the same base state to skip
            a redundant forward pass.

    Returns:
        ``(rescaled_dx, growth)``, where ``growth`` is the RMS of the forecast
        difference measured before rescaling. If ``growth < 1e-12``,
        ``dx_flat`` is returned unchanged.
    """
    # The forecast difference is used purely as data (it is detached), so
    # building an autograd graph for these forward passes only wastes memory.
    with torch.no_grad():
        if y0 is None:
            y0 = adapter.forward_flat(x0_flat, const_feat)               # (N*6,)
        y1 = adapter.forward_flat(x0_flat + dx_flat, const_feat)         # (N*6,)
        diff = (y1 - y0).detach()
    growth = torch.sqrt(torch.mean(diff * diff)).item()
    if growth < 1e-12:
        # Degenerate (no measurable growth): avoid dividing by ~0.
        return dx_flat, growth
    return dx_flat * (target_growth_rms / growth), growth

# ---- Main routine ----
def main():
    """Build a singular-vector (SV) ensemble around one saved 2-layer snapshot.

    Steps:
      1. Load ``klein_step_{START_STEP:06d}.npz`` from ``ROOT_2L/IC_NAME``.
      2. Compute the top ``RANK`` singular vectors and singular values with
         ``leading_singular_vectors``. The singular values are written to
         ``singular_values.csv``.
      3. For each of ``M_MEMBERS`` members:
         - form a random linear combination of the SVs (mirrored +/- pairs
           when ``M_MEMBERS`` is even);
         - rescale each channel to ``TARGET_STD``;
         - apply ``BREED_ITERS`` breeding rescales;
         - save the perturbed initial state and the adapter's forecast from it.
      4. Log each member's breeding growths to ``members_growth.csv``. The log
         is rewritten on every run.

    Outputs go to ``ROOT_2L/IC_NAME/ens_sv/TAG``.

    Raises:
        FileNotFoundError: if the base snapshot file does not exist.
        ValueError: if the model grid does not match the snapshot grid.
    """
    device = torch.device(DEVICE)

    # ---- Load base snapshot ----
    ic_dir = os.path.join(ROOT_2L, IC_NAME)
    base_path = os.path.join(ic_dir, f"klein_step_{START_STEP:06d}.npz")
    if not os.path.exists(base_path):
        raise FileNotFoundError(base_path)
    eta0, etai0, uc10, vc10, uc20, vc20, f, y2 = load_npz_2L_centered(base_path)
    ny, nx = eta0.shape
    const_feat = {"f": f, "y2": y2}

    # ---- Adapter + base state tensor ----
    adapter = GNN2L_AdapterGrad(CKPT_DIR, device=device)
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if adapter.ny != ny or adapter.nx != nx:
        raise ValueError("Grid mismatch between model and data.")
    x0_flat = vec6_from_arrays(eta0, etai0, uc10, vc10, uc20, vc20).to(device)

    # ---- Top-RANK singular vectors and singular values ----
    Vs, sigmas = leading_singular_vectors(adapter, x0_flat, const_feat,
                                          rank=RANK, iters=POW_ITERS)

    out_dir = os.path.join(ic_dir, "ens_sv", TAG)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "singular_values.csv"), "w", newline="") as fcsv:
        writer = csv.writer(fcsv)
        writer.writerow(["mode", "sigma"])
        for i, s in enumerate(sigmas, 1):
            writer.writerow([i, s])

    # ---- Member weights: random combinations of the top RANK SVs ----
    rng = np.random.RandomState(1234)
    weights = rng.randn(M_MEMBERS, RANK).astype(np.float32)
    if M_MEMBERS % 2 == 0:
        # Symmetric pairs: member i + M/2 is the negation of member i.
        half = M_MEMBERS // 2
        weights[half:] = -weights[:half]

    # Start a fresh growth log for this run. Without this truncation, the
    # per-member appends below would pile up rows from earlier runs. The
    # header has one column per breeding iteration actually performed.
    growth_path = os.path.join(out_dir, "members_growth.csv")
    with open(growth_path, "w", newline="") as fg:
        csv.writer(fg).writerow(
            ["member"] + [f"rms_growth_iter{i + 1}" for i in range(BREED_ITERS)])

    for m in range(M_MEMBERS):
        # Combine SVs with this member's weights.
        dx = torch.zeros_like(x0_flat)
        for k in range(RANK):
            dx = dx + weights[m, k] * Vs[k]

        # Channel-wise RMS scaling.
        e, ei, u1, v1, u2, v2 = arrays6_from_vec(dx, ny, nx)
        e, ei, u1, v1, u2, v2 = scale_channels_to_targets(e, ei, u1, v1, u2, v2, TARGET_STD)
        dx = vec6_from_arrays(e, ei, u1, v1, u2, v2).to(device)

        # Breeding rescale (skipped entirely when BREED_ITERS == 0).
        growths = []
        for _ in range(BREED_ITERS):
            dx, growth = breed_once(adapter, x0_flat, dx, const_feat, TARGET_GROWTH_RMS)
            growths.append(growth)

        # Initial state for this member, plus the adapter's forecast from it.
        # The forecast is computed on-device, without a cpu round trip.
        x_init_dev = (x0_flat + dx).detach()
        with torch.no_grad():
            y_pred = adapter.forward_flat(x_init_dev, const_feat).detach().cpu()
        x_init = x_init_dev.cpu()

        # Save init.
        e, ei, u1, v1, u2, v2 = arrays6_from_vec(x_init, ny, nx)
        init_path = os.path.join(out_dir, f"member_{m+1:04d}_init_step_{START_STEP:06d}.npz")
        save_npz_centered(init_path, e, ei, u1, v1, u2, v2, f, y2)

        # Save the forecast (useful for quick sanity checks and spread vs. control).
        ep, eip, u1p, v1p, u2p, v2p = arrays6_from_vec(y_pred, ny, nx)
        tplus_path = os.path.join(out_dir, f"member_{m+1:04d}_tplus_step_{START_STEP+HORIZON:06d}.npz")
        save_npz_centered(tplus_path, ep, eip, u1p, v1p, u2p, v2p, f, y2)

        # Append this member's growth row. The file is reopened per member so
        # progress persists even if a later member fails.
        with open(growth_path, "a", newline="") as fg:
            csv.writer(fg).writerow([m + 1] + growths)

        print(f"[ens] member {m+1:04d} saved. growths={['%.3f' % gr for gr in growths]}")

    print(f"Done. Ensemble in {out_dir}")

# Entry point: run main() when this file/cell is executed directly,
# but not when the module is imported.
if __name__ == "__main__":
    main()
