<a href="https://colab.research.google.com/github/Misha-private/Demo-repo/blob/main/ML_PINN_GNN2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os, sys, csv, traceback, random
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ================== I/O ==================
OUTDIR = "outputs_PINN_Cgrid_anchor_colloc_curriculum_multicase"
os.makedirs(OUTDIR, exist_ok=True)
def log(msg):
    print(msg, flush=True)
    with open(os.path.join(OUTDIR,"run_log.txt"),"a") as f: f.write(msg+"\n")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32)
torch.manual_seed(0); np.random.seed(0); random.seed(0)
log(f"Device: {device}")

# ================== Grid / Physics ==================
NX, NY = 64, 64
LX, LY = 1.5e6, 7.5e5
dx, dy = LX/NX, LY/NY
GRAV, H0   = 9.81, 100.0
nu      = 70.0           # slightly higher viscosity for stability
f0      = 1e-4
beta    = 2e-11

HOURS      = 6
DT_SEC     = 2.0
SEC_PER_HR = 3600.0
NT         = int(HOURS*SEC_PER_HR/DT_SEC)

SNAP_T     = [int(k*SEC_PER_HR/DT_SEC) for k in range(HOURS+1)]
ANCHOR_HRS = [float(k) for k in range(HOURS+1)]     # 0..6
MID_HRS    = [k+0.5 for k in range(HOURS)]          # 0.5..5.5
COLL_HRS   = sorted(set(ANCHOR_HRS + MID_HRS))      # 13 times

# coordinates
x  = torch.linspace(0, LX, NX,    device=device)    # centers x
y  = torch.linspace(0, LY, NY,    device=device)    # centers y
xu = torch.linspace(0, LX, NX+1,  device=device)    # u x-faces (periodic)
yv = torch.linspace(0, LY, NY+1,  device=device)    # v y-faces (walls)

# ================== Helpers ==================
def assert_shape(t, shape, name):
    if tuple(t.shape) != tuple(shape):
        raise RuntimeError(f"{name} shape {tuple(t.shape)} != {tuple(shape)}")

# ----- 2D BCs (for operators/truth) -----
def bc_uv_2d(h,u,v):
    if h.shape[0] > 1:
        h[0 , :] = h[1 , :]
        h[-1, :] = h[-2, :]
    h[:, 0]  = h[:, -2]
    h[:, -1] = h[:, 1]

    if u.shape[0] > 1:
        u[0 , :] = u[1 , :]
        u[-1, :] = u[-2, :]
    u[:, 0]  = u[:, -2]
    u[:, -1] = u[:, 1]

    v[:, 0]  = v[:, -1]
    v[0 , :] = 0.0
    v[-1, :] = 0.0
    return h,u,v

# ----- Batched BCs (B,...) used in NN forward -----
def bc_uv_B(h,u,v):
    B, NYh, NXh   = h.shape
    Bu, NYu, NXu1 = u.shape
    Bv, NYp1, NXv = v.shape
    assert B==Bu==Bv, "batch mismatch in bc_uv_B"
    assert NXh==NXv and NXu1==NXh+1 and NYu==NYh and NYp1==NYh+1, "grid mismatch in bc_uv_B"

    if NYh > 1:
        h[:, 0 , :] = h[:, 1 , :]
        h[:, -1, :] = h[:, -2, :]
    h[:, :,  0] = h[:, :, -2]
    h[:, :, -1] = h[:, :,  1]

    if NYu > 1:
        u[:, 0 , :] = u[:, 1 , :]
        u[:, -1, :] = u[:, -2, :]
    u[:, :,  0] = u[:, :, -2]
    u[:, :, -1] = u[:, :,  1]

    v[:, :, 0]  = v[:, :, -1]
    v[:, 0 , :] = 0.0
    v[:, -1, :] = 0.0
    return h,u,v

# ----- C-grid ops -----
def periodic_pad_x_center(f):   return torch.cat([f[:,-1:], f, f[:,0:1]], dim=1)
def ddx_u_to_c(u):              return (u[:,1:] - u[:,:-1]) / dx
def ddy_v_to_c(v):              return (v[1:,:] - v[:-1,:]) / dy
def grad_h_to_u(h):             return (periodic_pad_x_center(h)[:,1:] - periodic_pad_x_center(h)[:,:-1]) / dx
def grad_h_to_v(h):
    hp = torch.cat([h[0:1,:], h, h[-1:,:]], dim=0)
    return (hp[1:,:] - hp[:-1,:]) / dy
def avg_c_to_u(f):              return 0.5*(periodic_pad_x_center(f)[:,:-1] + periodic_pad_x_center(f)[:,1:])
def avg_c_to_v(f):
    fp = torch.cat([f[0:1,:], f, f[-1:,:]], dim=0)
    return 0.5*(fp[:-1,:] + fp[1:,:])
def avg_u_to_c(u):              return 0.5*(u[:,:-1] + u[:,1:])
def avg_v_to_c(v):              return 0.5*(v[:-1,:] + v[1:,:])
def ddx_u(u):                   return (torch.roll(u,-1,1) - torch.roll(u,1,1)) / (2*dx)
def ddy_u(u):                   return (torch.roll(u,-1,0) - torch.roll(u,1,0)) / (2*dy)
def ddx_v(v):                   return (torch.roll(v,-1,1) - torch.roll(v,1,1)) / (2*dx)
def ddy_v(v):                   return (torch.roll(v,-1,0) - torch.roll(v,1,0)) / (2*dy)
def lap_u(u):
    uxx = (torch.roll(u,-1,1) - 2*u + torch.roll(u,1,1))/(dx*dx)
    uyy = (torch.roll(u,-1,0) - 2*u + torch.roll(u,1,0))/(dy*dy)
    return uxx + uyy
def lap_v(v):
    vxx = (torch.roll(v,-1,1) - 2*v + torch.roll(v,1,1))/(dx*dx)
    vyy = (torch.roll(v,-1,0) - 2*v + torch.roll(v,1,0))/(dy*dy)
    return vxx + vyy
def coriolis_on_u_from_v(v):
    v_at_c = avg_v_to_c(v); v_at_u = avg_c_to_u(v_at_c)
    fy_u   = (f0 + beta*(y - LY/2)).view(NY,1).expand(NY,NX+1)
    return -fy_u * v_at_u
def coriolis_on_v_from_u(u):
    u_at_c = avg_u_to_c(u); u_at_v = avg_c_to_v(u_at_c)
    fy_v   = (f0 + beta*(yv - LY/2)).view(NY+1,1).expand(NY+1,NX)
    return  fy_v * u_at_v

# ----- Nonlinear SWE RHS -----
def rhs_nl(h,u,v):
    h,u,v = bc_uv_2d(h,u,v)
    H = H0 + h
    hu_u = u * avg_c_to_u(H)
    hv_v = v * avg_c_to_v(H)
    divF = ddx_u_to_c(hu_u) + ddy_v_to_c(hv_v)
    Hx_u = grad_h_to_u(H)
    Hy_v = grad_h_to_v(H)
    v_at_u = avg_c_to_u(avg_v_to_c(v))
    u_at_v = avg_c_to_v(avg_u_to_c(u))
    adv_u = u*ddx_u(u) + v_at_u*ddy_u(u)
    adv_v = u_at_v*ddx_v(v) + v*ddy_v(v)
    visc_u = nu*lap_u(u)
    visc_v = nu*lap_v(v)
    h_t = -divF
    u_t = -GRAV*Hx_u - adv_u + visc_u + coriolis_on_u_from_v(v)
    v_t = -GRAV*Hy_v - adv_v + visc_v + coriolis_on_v_from_u(u)
    return bc_uv_2d(h_t,u_t,v_t)

# ----- RK2 truth generator -----
def rk2_step(h,u,v,dt):
    k1h,k1u,k1v = rhs_nl(h,u,v)
    h1 = h + dt*k1h; u1 = u + dt*k1u; v1 = v + dt*k1v
    k2h,k2u,k2v = rhs_nl(h1,u1,v1)
    h_new = h + 0.5*dt*(k1h+k2h)
    u_new = u + 0.5*dt*(k1u+k2u)
    v_new = v + 0.5*dt*(k1v+k2v)
    return bc_uv_2d(h_new,u_new,v_new)

# ================== Multi-case ICs ==================
def rand_ic():
    # random sum of Gaussian bumps (2–4), amplitudes varied
    Yc, Xc = torch.meshgrid(y, x, indexing='ij')
    nb = np.random.randint(2,5)
    h = torch.zeros_like(Yc, device=device)
    for _ in range(nb):
        cx = float(np.random.uniform(0.2,0.8) * LX)
        cy = float(np.random.uniform(0.2,0.8) * LY)
        sx = float(np.random.uniform(0.08,0.18) * LX)
        sy = float(np.random.uniform(0.08,0.20) * LY)
        amp= float(np.random.uniform(0.3,0.8))
        h += amp*torch.exp(-(((Xc-cx)**2)/(2*sx**2) + ((Yc-cy)**2)/(2*sy**2)))
    # small random geostrophic-like swirl (kept tiny)
    u = torch.zeros(NY, NX+1, device=device)
    v = torch.zeros(NY+1, NX, device=device)
    return bc_uv_2d(h,u,v)

@torch.no_grad()
def gen_truth_from_ic(h0,u0,v0):
    h,u,v = h0.clone(),u0.clone(),v0.clone()
    outs=[(h.clone(),u.clone(),v.clone())]
    tnext=1
    for n in range(1,NT+1):
        h,u,v = rk2_step(h,u,v,DT_SEC)
        if tnext < len(SNAP_T) and n == SNAP_T[tnext]:
            outs.append((h.clone(),u.clone(),v.clone()))
            tnext += 1
            if tnext >= len(SNAP_T): break
    return outs

N_CASES = 6   # increase this for a larger training sample
log(f"Generating {N_CASES} truth cases...")
cases_truth = []
for c in range(N_CASES):
    h0,u0,v0 = rand_ic()
    traj = gen_truth_from_ic(h0,u0,v0)
    cases_truth.append({
        "h": [t[0].unsqueeze(0) for t in traj],  # (1,NY,NX)
        "u": [t[1].unsqueeze(0) for t in traj],  # (1,NY,NX+1)
        "v": [t[2].unsqueeze(0) for t in traj],  # (1,NY+1,NX)
        "ic": (h0.unsqueeze(0), u0.unsqueeze(0), v0.unsqueeze(0))
    })
log("Truth dataset generated.")

# ================== Normalization over whole dataset ==================
with torch.no_grad():
    h_cat = torch.cat([torch.cat([c["h"][k].reshape(1,-1) for k in range(HOURS+1)], dim=1) for c in cases_truth], dim=1)
    u_cat = torch.cat([torch.cat([c["u"][k].reshape(1,-1) for k in range(HOURS+1)], dim=1) for c in cases_truth], dim=1)
    v_cat = torch.cat([torch.cat([c["v"][k].reshape(1,-1) for k in range(HOURS+1)], dim=1) for c in cases_truth], dim=1)
    mh, sh = float(h_cat.mean()), float(h_cat.std()+1e-6)
    mu, su = float(u_cat.mean()), float(u_cat.std()+1e-6)
    mv, sv = float(v_cat.mean()), float(v_cat.std()+1e-6)

def znorm(x, m, s): return (x - m) / s
def zdenorm(x, m, s): return x * s + m

# ================== Model ==================
class TimeFourier(nn.Module):
    def __init__(self, nf=8, scale=1.0):
        super().__init__()
        self.register_buffer("k", torch.arange(1,nf+1).float())
        self.scale=scale; self.dim=2*nf
    def forward(self, t_hr):
        t=t_hr*self.scale
        s=torch.sin(self.k[None,:]*t); c=torch.cos(self.k[None,:]*t)
        return torch.cat([s,c], dim=-1)  # (1,2nf)

def make_xy_feats():
    Yc,Xc = torch.meshgrid(y, x, indexing='ij')            # centers
    Yu,Xu = torch.meshgrid(y, xu, indexing='ij')           # u grid
    Yv,Xv = torch.meshgrid(yv, x, indexing='ij')           # v grid
    xc = (Xc/LX*2-1).unsqueeze(0).unsqueeze(0)             # (1,1,NY,NX)
    yc = (Yc/LY*2-1).unsqueeze(0).unsqueeze(0)
    xu_ = (Xu/LX*2-1).unsqueeze(0).unsqueeze(0)            # (1,1,NY,NX+1)
    yu_ = (Yu/LY*2-1).unsqueeze(0).unsqueeze(0)
    xv_ = (Xv/LX*2-1).unsqueeze(0).unsqueeze(0)            # (1,1,NY+1,NX)
    yv_ = (Yv/LY*2-1).unsqueeze(0).unsqueeze(0)
    return xc,yc,xu_,yu_,xv_,yv_
XC,YC,XU,YU,XV,YV = make_xy_feats()

class Block(nn.Module):
    def __init__(self, c_in, c_out, hidden=160, depth=5):
        super().__init__()
        L=[nn.Conv2d(c_in, hidden, 3, padding=1), nn.GELU()]
        for _ in range(depth-1): L += [nn.Conv2d(hidden, hidden, 3, padding=1), nn.GELU()]
        L += [nn.Conv2d(hidden, c_out, 3, padding=1)]
        self.net = nn.Sequential(*L)
    def forward(self, x): return self.net(x)

class STUNet(nn.Module):
    """ Predict absolute (h,u,v) at time t from IC and (x,y,t) features """
    def __init__(self, nf_t=8):
        super().__init__()
        self.tfe = TimeFourier(nf=nf_t, scale=1.0)
        tdim = self.tfe.dim
        self.h_trunk = Block(c_in=1+1+tdim, c_out=1)  # [h0_norm, pos-sum, t_fourier]
        self.u_trunk = Block(c_in=1+1+tdim, c_out=1)
        self.v_trunk = Block(c_in=1+1+tdim, c_out=1)
        self.logsig_h = nn.Parameter(torch.tensor(0.0))
        self.logsig_u = nn.Parameter(torch.tensor(0.0))
        self.logsig_v = nn.Parameter(torch.tensor(0.0))

    def forward(self, h0,u0,v0, t_hr):
        tf = self.tfe(t_hr)  # (1,tdim)
        B=1
        tf_c = tf.view(B,-1,1,1).expand(B,-1,NY,NX)
        tf_u = tf.view(B,-1,1,1).expand(B,-1,NY,NX+1)
        tf_v = tf.view(B,-1,1,1).expand(B,-1,NY+1,NX)

        h0n = znorm(h0, mh, sh)                 # (1,NY,NX)
        u0n = znorm(u0, mu, su)                 # (1,NY,NX+1)
        v0n = znorm(v0, mv, sv)                 # (1,NY+1,NX)

        hc_in = torch.cat([h0n.unsqueeze(1), (XC+YC), tf_c], dim=1)
        hu_in = torch.cat([u0n.unsqueeze(1), (XU+YU), tf_u], dim=1)
        hv_in = torch.cat([v0n.unsqueeze(1), (XV+YV), tf_v], dim=1)

        hn = self.h_trunk(hc_in).squeeze(1)     # (1,NY,NX)
        un = self.u_trunk(hu_in).squeeze(1)     # (1,NY,NX+1)
        vn = self.v_trunk(hv_in).squeeze(1)     # (1,NY+1,NX)

        h = zdenorm(hn, mh, sh)
        u = zdenorm(un, mu, su)
        v = zdenorm(vn, mv, sv)

        # Soft clamps to avoid JVP blow-ups
        h = torch.clamp(h, -2.0, 2.0)
        u = torch.clamp(u, -25.0, 25.0)
        v = torch.clamp(v, -25.0, 25.0)

        h,u,v = bc_uv_B(h,u,v)
        assert_shape(h[0], (NY,NX),   "model h")
        assert_shape(u[0], (NY,NX+1), "model u")
        assert_shape(v[0], (NY+1,NX), "model v")
        return h,u,v

model = STUNet().to(device)
opt   = torch.optim.Adam(model.parameters(), lr=3e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=800, eta_min=1e-5)

# ================== AD-PINN residuals ==================
T_scale = HOURS*SEC_PER_HR
eta_scale = max(float(h_cat.std().cpu()), 1e-6)
u_scale   = max(float(u_cat.std().cpu()), 1e-6)
v_scale   = max(float(v_cat.std().cpu()), 1e-6)

def pack_flat(h,u,v):
    return torch.cat([h.reshape(-1), u.reshape(-1), v.reshape(-1)], dim=0)
def unpack_flat(yflat):
    n_h = NY*NX
    n_u = NY*(NX+1)
    n_v = (NY+1)*NX
    h = yflat[:n_h].view(1,NY,NX)
    u = yflat[n_h:n_h+n_u].view(1,NY,NX+1)
    v = yflat[n_h+n_u:].view(1,NY+1,NX)
    return h,u,v
def f_time_only(t_hr_scalar, h0,u0,v0):
    h,u,v = model(h0,u0,v0, t_hr_scalar)
    return pack_flat(h,u,v)

def residuals_AD(h0,u0,v0, t_hr_scalar):
    t = torch.tensor(float(t_hr_scalar), device=device, dtype=torch.float32, requires_grad=True)
    y, dy_dt_hr = torch.autograd.functional.jvp(
        lambda tt: f_time_only(tt, h0,u0,v0),
        (t,), (torch.ones_like(t),), create_graph=True
    )
    h, u, v = unpack_flat(y)               # (1,...)
    dhdt, dudt, dvdt = unpack_flat(dy_dt_hr / SEC_PER_HR)
    # drop batch for FD ops
    h, u, v = h[0], u[0], v[0]
    dhdt, dudt, dvdt = dhdt[0], dudt[0], dvdt[0]

    H = H0 + h
    hu_u = u * avg_c_to_u(H)
    hv_v = v * avg_c_to_v(H)
    cont = dhdt + ddx_u_to_c(hu_u) + ddy_v_to_c(hv_v)

    Hx_u = grad_h_to_u(H)
    Hy_v = grad_h_to_v(H)
    v_at_u = avg_c_to_u(avg_v_to_c(v))
    u_at_v = avg_c_to_v(avg_u_to_c(u))
    adv_u = u*ddx_u(u) + v_at_u*ddy_u(u)
    adv_v = u_at_v*ddx_v(v) + v*ddy_v(v)
    R_u = dudt + GRAV*Hx_u + adv_u - nu*lap_u(u) - coriolis_on_u_from_v(v)
    R_v = dvdt + GRAV*Hy_v + adv_v - nu*lap_v(v) + coriolis_on_v_from_u(u)

    Rh = cont / (eta_scale / T_scale)
    Ru = R_u  / (u_scale   / T_scale)
    Rv = R_v  / (v_scale   / T_scale)
    return Rh, Ru, Rv

# ================== Training (curriculum + multi-case) ==================
EPOCHS_DATA  = 250   # longer pure-data pretrain (stability)
EPOCHS_JOINT = 400
TOTAL_EPOCHS = EPOCHS_DATA + EPOCHS_JOINT

def lambda_phys(ep):
    if ep <= EPOCHS_DATA: return 0.0
    k = ep - EPOCHS_DATA
    return min(0.02, 1e-3 + 5e-5*k)  # gentle ramp up to 0.02

# midpoints weighting slightly lower
def colloc_weight(t_hr): return 1.0 if t_hr in ANCHOR_HRS else 0.7

# loss balancing across anchors (later anchors slightly heavier)
anchor_w = {k: 1.0 + 0.10*k for k in range(HOURS+1)}
w_sum = sum(anchor_w.values())

def safe_mean(x):
    x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
    return x.mean()

logs=[]
for ep in range(1, TOTAL_EPOCHS+1):
    opt.zero_grad()
    L_data = torch.tensor(0.0, device=device)
    L_phys = torch.tensor(0.0, device=device)
    lam    = lambda_phys(ep)

    # === DATA: average over all cases and anchors ===
    for ci in range(N_CASES):
        h0,u0,v0 = cases_truth[ci]["ic"]
        for t_hr in ANCHOR_HRS:
            hp,up,vp = model(h0,u0,v0, torch.tensor(t_hr,device=device))
            k = int(round(t_hr))
            w = anchor_w.get(k,1.0)
            Ld = (F.mse_loss(znorm(hp,mh,sh), znorm(cases_truth[ci]["h"][k],mh,sh)) +
                  F.mse_loss(znorm(up,mu,su), znorm(cases_truth[ci]["u"][k],mu,su)) +
                  F.mse_loss(znorm(vp,mv,sv), znorm(cases_truth[ci]["v"][k],mv,sv)))
            L_data += w*Ld
    L_data = L_data / (w_sum * N_CASES)

    # === PHYSICS: over all cases and collocation times ===
    if lam > 0:
        for ci in range(N_CASES):
            h0,u0,v0 = cases_truth[ci]["ic"]
            for t_hr in COLL_HRS:
                Rh, Ru, Rv = residuals_AD(h0,u0,v0, t_hr)
                # clamp logsig to avoid exp overflow/underflow
                sh = model.logsig_h.clamp(-2.0, 2.0)
                su = model.logsig_u.clamp(-2.0, 2.0)
                sv = model.logsig_v.clamp(-2.0, 2.0)
                Lp  = 0.5*(torch.exp(-2*sh)*safe_mean(Rh.pow(2)) + 2*sh)
                Lp += 0.5*(torch.exp(-2*su)*safe_mean(Ru.pow(2)) + 2*su)
                Lp += 0.5*(torch.exp(-2*sv)*safe_mean(Rv.pow(2)) + 2*sv)
                L_phys += colloc_weight(t_hr)*Lp
        L_phys = L_phys / (N_CASES * (len(ANCHOR_HRS) + 0.7*len(MID_HRS)))

    L = L_data + lam*L_phys

    if not torch.isfinite(L):
        log(f"[NaN GUARD] Non-finite loss at epoch {ep}: L={L} data={L_data} phys={L_phys} λ={lam}")
        opt.zero_grad(set_to_none=True)
        # reduce LR proactively
        for g in opt.param_groups: g['lr'] = max(g['lr']*0.5, 1e-5)
        continue

    L.backward()
    # tighter clip when physics > 0
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5 if lam>0 else 1.0)
    # optional LR tweak when physics starts
    if lam>0 and sched.get_last_lr()[0] > 2e-4:
        for g in opt.param_groups: g['lr'] = 2e-4
    opt.step(); sched.step()

    logs.append([ep, float(L.item()), float(L_data.item()), float(L_phys.item()), float(lam),
                 float(model.logsig_h.item()), float(model.logsig_u.item()), float(model.logsig_v.item()),
                 float(sched.get_last_lr()[0])])
    if ep%50==0 or ep in (1,EPOCHS_DATA, TOTAL_EPOCHS):
        log(f"[{ep:04d}] L={L:.3e} data={L_data:.3e} phys={L_phys:.3e} λ={lam:.3f} "
            f"logsig=({model.logsig_h.item():.2f},{model.logsig_u.item():.2f},{model.logsig_v.item():.2f}) "
            f"lr={sched.get_last_lr()[0]:.2e}")

with open(os.path.join(OUTDIR,"losses_train.csv"),"w",newline="") as f:
    w=csv.writer(f); w.writerow(["epoch","total","data","phys","lambda","logsig_h","logsig_u","logsig_v","lr"]); w.writerows(logs)

# ================== Evaluation & Plots (guarded) ==================
def guard_plot(path, fn):
    try:
        fn(); plt.savefig(path,dpi=140); plt.close(); log(f"Saved: {path}")
    except Exception as e:
        log(f"[PLOT ERROR] {path}: {e}")
        with open(os.path.join(OUTDIR,"plot_errors.txt"),"a") as f:
            f.write(f"{path}: {e}\n{traceback.format_exc()}\n")

try:
    @torch.no_grad()
    def predict_case_hour(ci, th):
        h0,u0,v0 = cases_truth[ci]["ic"]
        return model(h0,u0,v0, torch.tensor(float(th),device=device))

    def rmse(a,b): return float(torch.sqrt(((a-b)**2).mean()).item())

    # RMSE by hour averaged across cases
    hrs = np.arange(HOURS+1, dtype=float)
    RM_h = np.zeros_like(hrs, dtype=float)
    RM_u = np.zeros_like(hrs, dtype=float)
    RM_v = np.zeros_like(hrs, dtype=float)

    # collect predictions for case 0 to plot maps
    pred0 = []
    for k in range(HOURS+1):
        hp0,up0,vp0 = predict_case_hour(0, k)
        pred0.append((hp0,up0,vp0))

    for k in range(HOURS+1):
        rh,ru,rv = 0.0,0.0,0.0
        for ci in range(N_CASES):
            hp,up,vp = predict_case_hour(ci, k)
            rh += rmse(hp, cases_truth[ci]["h"][k])
            ru += rmse(up, cases_truth[ci]["u"][k])
            rv += rmse(vp, cases_truth[ci]["v"][k])
        RM_h[k] = rh / N_CASES
        RM_u[k] = ru / N_CASES
        RM_v[k] = rv / N_CASES

    # energy on case 0
    def energy(hb, ub, vb):
        hb = hb[0]; ub = ub[0]; vb = vb[0]
        assert hb.shape == (NY, NX)
        assert ub.shape == (NY, NX+1)
        assert vb.shape == (NY+1, NX)
        uc = 0.5*(ub[:, :-1] + ub[:, 1:])
        vc = 0.5*(vb[:-1, :] + vb[1:, :])
        KE = 0.5*H0*(uc*uc + vc*vc)
        PE = 0.5*GRAV*(hb*hb)
        TE = KE + PE
        return float(KE.mean()), float(PE.mean()), float(TE.mean())

    KEs, PEs, TEs = [],[],[]
    for k in range(HOURS+1):
        ke,pe,te = energy(*pred0[k]); KEs.append(ke); PEs.append(pe); TEs.append(te)

    # Save CSVs
    with open(os.path.join(OUTDIR,"rmse_vs_hour_avg.csv"),"w",newline="") as f:
        w=csv.writer(f); w.writerow(["hour","RMSE_h_mean","RMSE_u_mean","RMSE_v_mean"])
        for h,a,b,c in zip(hrs,RM_h,RM_u,RM_v): w.writerow([float(h),float(a),float(b),float(c)])
    with open(os.path.join(OUTDIR,"energy_vs_hour_case0.csv"),"w",newline="") as f:
        w=csv.writer(f); w.writerow(["hour","KE","PE","TE"])
        for h,ke,pe,te in zip(hrs,KEs,PEs,TEs): w.writerow([float(h),float(ke),float(pe),float(te)])

    # Plots
    def imshow(ax, field, title):
        im=ax.imshow(field, origin='lower', interpolation='nearest', cmap='viridis')
        ax.set_title(title, fontsize=10); ax.set_xticks([]); ax.set_yticks([])
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    hT_true = cases_truth[0]["h"][-1][0].cpu().numpy()
    uT_true = cases_truth[0]["u"][-1][0].cpu().numpy()
    vT_true = cases_truth[0]["v"][-1][0].cpu().numpy()
    hT_pred = pred0[-1][0][0].cpu().numpy()
    uT_pred = pred0[-1][1][0].cpu().numpy()
    vT_pred = pred0[-1][2][0].cpu().numpy()

    guard_plot(os.path.join(OUTDIR,"maps_h_u_v_T6_case0.png"), lambda: (
        plt.figure(figsize=(12,8)),
        imshow(plt.subplot(3,3,1), hT_true, "h True (case0)"),
        imshow(plt.subplot(3,3,2), hT_pred, "h Pred (case0)"),
        imshow(plt.subplot(3,3,3), hT_pred-hT_true, "h Diff"),
        imshow(plt.subplot(3,3,4), uT_true, "u True (case0)"),
        imshow(plt.subplot(3,3,5), uT_pred, "u Pred (case0)"),
        imshow(plt.subplot(3,3,6), uT_pred-uT_true, "u Diff"),
        imshow(plt.subplot(3,3,7), vT_true, "v True (case0)"),
        imshow(plt.subplot(3,3,8), vT_pred, "v Pred (case0)"),
        imshow(plt.subplot(3,3,9), vT_pred-vT_true, "v Diff"),
        plt.tight_layout()
    ))

    guard_plot(os.path.join(OUTDIR,"rmse_vs_hour_avg.png"), lambda: (
        plt.figure(figsize=(7,5)),
        plt.plot(hrs, RM_h, label='h (avg over cases)'),
        plt.plot(hrs, RM_u, label='u (avg)'),
        plt.plot(hrs, RM_v, label='v (avg)'),
        plt.xlabel("hour"), plt.ylabel("RMSE"), plt.title("RMSE vs hour (mean over cases)"),
        plt.legend(), plt.tight_layout()
    ))

    guard_plot(os.path.join(OUTDIR,"energy_vs_hour_case0.png"), lambda: (
        plt.figure(figsize=(7,5)),
        plt.plot(hrs,KEs,label="KE"), plt.plot(hrs,PEs,label="PE"), plt.plot(hrs,TEs,label="TE"),
        plt.xlabel("hour"), plt.ylabel("Energy"), plt.title("Energy vs hour (case 0)"),
        plt.legend(), plt.tight_layout()
    ))

    log("SUCCESS")
    sys.exit(0)

except Exception as e:
    with open(os.path.join(OUTDIR,"post_training_error.txt"),"w") as f:
        f.write(str(e) + "\n" + traceback.format_exc())
    log(f"POST_TRAINING_FAILURE: {e}")
    sys.exit(1)


Device: cuda
Generating 6 truth cases...
Truth dataset generated.
[0001] L=3.036e+00 data=3.036e+00 phys=0.000e+00 λ=0.000 logsig=(0.00,0.00,0.00) lr=3.00e-04
[0050] L=2.405e+00 data=2.405e+00 phys=0.000e+00 λ=0.000 logsig=(0.00,0.00,0.00) lr=2.97e-04
[0100] L=2.226e+00 data=2.226e+00 phys=0.000e+00 λ=0.000 logsig=(0.00,0.00,0.00) lr=2.89e-04
[0150] L=2.129e+00 data=2.129e+00 phys=0.000e+00 λ=0.000 logsig=(0.00,0.00,0.00) lr=2.76e-04
[0200] L=2.062e+00 data=2.062e+00 phys=0.000e+00 λ=0.000 logsig=(0.00,0.00,0.00) lr=2.58e-04
[0250] L=2.035e+00 data=2.035e+00 phys=0.000e+00 λ=0.000 logsig=(0.00,0.00,0.00) lr=2.36e-04


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 18.12 MiB is free. Process 47453 has 14.72 GiB memory in use. Of the allocated memory 13.30 GiB is allocated by PyTorch, and 1.30 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)