In [1]:
import torch

print("Torch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("Supported archs:", torch.cuda.get_arch_list())

Torch version: 2.7.1+cu128
CUDA version: 12.8
Supported archs: ['sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120']


In [1]:
import cdsapi

c = cdsapi.Client()

c.retrieve(
    "reanalysis-era5-pressure-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "variable": [
            "temperature", "geopotential", "u_component_of_wind", "v_component_of_wind"
        ],
        "pressure_level": ["500"],
        "year": "2024",
        "month": "09",
        "day": [f"{d:02d}" for d in range(1, 31 + 1)],
        "time": [f"{h:02d}:00" for h in range(24)],
        "area": [90, -180, -90, 180],
    },
    "C:/Users/hars/Documents/era5_data/era5_500hpa_2024_09.nc"
)

2025-09-15 09:45:10,510 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.
2025-09-15 09:45:10,942 INFO Request ID is 41e5a134-d9b6-4c3a-b85a-7cc62d30e7a9
2025-09-15 09:45:11,142 INFO status has been updated to accepted
2025-09-15 09:45:16,585 INFO status has been updated to running
2025-09-15 09:45:20,170 INFO status has been updated to successful
                                                                                          

'C:/Users/hars/Documents/era5_data/era5_500hpa_2024_09.nc'

In [1]:
import cdsapi
import shutil, os

# Step 1: Download ERA5 data to a temp location
c = cdsapi.Client()

tmp_dir = r"C:\Users\hars\AppData\Local\Temp"
tmp_nc = os.path.join(tmp_dir, "era5_500hpa_2024_09.tmp")

c.retrieve(
    "reanalysis-era5-pressure-levels",
    {
        "product_type": "reanalysis",
        "format": "netcdf",
        "variable": [
            "temperature", "geopotential", "u_component_of_wind", "v_component_of_wind"
        ],
        "pressure_level": ["500"],
        "year": "2024",
        "month": "09",
        "day": [f"{d:02d}" for d in range(1, 31 + 1)],
        "time": [f"{h:02d}:00" for h in range(24)],
        "area": [90, -180, -90, 180],
    },
    tmp_nc
)

# Step 2: Atomic copy to D:\era5_cache
dst_dir = r"D:\era5_cache"
os.makedirs(dst_dir, exist_ok=True)

dst_final = os.path.join(dst_dir, "era5_500hpa_2024_09.nc")
shutil.copy2(tmp_nc, dst_final)

print("Copied to D:. Size (MB):", round(os.path.getsize(dst_final)/(1024*1024), 2))

2025-09-18 10:25:43,790 INFO [2025-09-03T00:00:00] To improve our C3S service, we need to hear from you! Please complete this very short [survey](https://confluence.ecmwf.int/x/E7uBEQ/). Thank you.
2025-09-18 10:25:43,791 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.
2025-09-18 10:25:44,229 INFO Request ID is 52fb7b30-3da4-4bb7-a2c5-7f0ca991dd46
2025-09-18 10:25:44,411 INFO status has been updated to accepted
2025-09-18 10:26:18,089 INFO status has been updated to successful
                                                                                           

Copied to D:. Size (MB): 4503.66


In [2]:
# Cell 1: Paths and basic config
DATA_PATH = "D:/era5_cache/era5_500hpa_2024_09.nc"  # your ERA5 file
RESULTS_ROOT = "D:/era5_spherical_resume_results"   # where previous run artifacts were saved
# If you know the exact timestamp folder you want to package, set it here (e.g., "20250918_153012").
# Otherwise we'll auto-pick the latest in a later cell.
SPECIFIC_TIMESTAMP = None  # or e.g. "20250918_153012"

In [3]:
# Cell 1: Imports, device, seeds, deterministic/cudnn
import os
import math
import numpy as np
import xarray as xr
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# e3nn primitives (you already installed and imported in earlier cells)
from e3nn.o3 import Irreps
from e3nn.o3 import FullyConnectedTensorProduct

# Repro + speed
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cudnn.benchmark = True   # fixed-size kernels accelerate
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

print("Device:", device)

Device: cuda


In [4]:
# Cell 2a: Force temp/cache dirs to D:
import os
from pathlib import Path

D_DRIVE_ROOT = Path(r"D:/abacus_temp")
D_DRIVE_ROOT.mkdir(parents=True, exist_ok=True)

os.environ["TMPDIR"]  = str(D_DRIVE_ROOT)
os.environ["TEMP"]    = str(D_DRIVE_ROOT)
os.environ["TMP"]     = str(D_DRIVE_ROOT)
os.environ["XDG_CACHE_HOME"] = str(D_DRIVE_ROOT)
os.environ["TORCH_HOME"] = str(D_DRIVE_ROOT / "torch")
os.environ["HF_HOME"]    = str(D_DRIVE_ROOT / "hf")

print("Temp/Cache dirs set to:", D_DRIVE_ROOT)

Temp/Cache dirs set to: D:\abacus_temp


In [7]:
from pathlib import Path
DATA_PATH = Path(r"C:/Users/hars/Documents/era5_data/era5_500hpa_2024_09.nc")
print("Using:", DATA_PATH, "| Exists?", DATA_PATH.exists())

Using: C:\Users\hars\Documents\era5_data\era5_500hpa_2024_09.nc | Exists? True


In [8]:
DATA_PATH = Path(r"D:/era5_cache/era5_500hpa_2024_09.nc")

In [5]:
# Cell 2b: Robust open of ERA5 NetCDF on Windows

from pathlib import Path
import os
import xarray as xr

# Set your data path
DATA_PATH = r"D:\era5_cache\era5_500hpa_2024_09.nc"

p = Path(DATA_PATH)
print("Using ERA5 file:", p)

# Basic checks
if not p.exists():
    raise FileNotFoundError(f"File not found: {p}")
size_bytes = p.stat().st_size
print("File size (bytes):", size_bytes)
if size_bytes < 1024:  # arbitrarily treat <1KB as empty/corrupt
    raise ValueError(f"File seems empty or corrupt (size={size_bytes} bytes).")

# Try netcdf4 first
try:
    ds = xr.open_dataset(str(p), engine="netcdf4")
    print("Opened with netcdf4 engine.")
except Exception as e1:
    print("netcdf4 engine failed:", repr(e1))
    # Try h5netcdf engine for NetCDF4/HDF5 files if available
    try:
        ds = xr.open_dataset(str(p), engine="h5netcdf")
        print("Opened with h5netcdf engine.")
    except Exception as e2:
        print("h5netcdf engine failed:", repr(e2))
        # Try scipy only for classic NetCDF3
        try:
            ds = xr.open_dataset(str(p), engine="scipy")
            print("Opened with scipy engine (NetCDF3).")
        except Exception as e3:
            print("scipy engine failed:", repr(e3))
            raise RuntimeError(
                "Could not open the NetCDF file with netcdf4, h5netcdf, or scipy. "
                "Check that the file is a valid NetCDF4/HDF5 file and not locked/corrupted."
            )

# Select 500 hPa if present
if "pressure_level" in ds.dims or "pressure_level" in ds.coords:
    try:
        if 500.0 in ds["pressure_level"].values:
            ds = ds.sel(pressure_level=500.0)
        else:
            ds = ds.isel(pressure_level=0)
        print("Selected single pressure level.")
    except Exception as e:
        print("Pressure level selection warning:", e)

# Normalize time coordinate name if needed
if "valid_time" in ds.dims or "valid_time" in ds.coords:
    ds = ds.rename({"valid_time": "time"})

# Print a quick summary
print("Opened Dataset summary:")
print(" - dims:", dict(ds.sizes))
print(" - vars:", list(ds.data_vars)[:10])

Using ERA5 file: D:\era5_cache\era5_500hpa_2024_09.nc
File size (bytes): 4722431171
Opened with netcdf4 engine.
Selected single pressure level.
Opened Dataset summary:
 - dims: {'time': 720, 'latitude': 721, 'longitude': 1440}
 - vars: ['t', 'z', 'u', 'v']


In [6]:
# Cell 2c: Streaming stats over time (robust dims, low memory)

import json
from pathlib import Path
import numpy as np
import xarray as xr

assert 'DATA_PATH' in globals(), "Set DATA_PATH before running Cell 2c."

# Map dataset variables to canonical names
VAR_MAP = {
    "t": "temperature",
    "z": "geopotential",
    "u": "u",
    "v": "v",
    # fallbacks
    "ta": "temperature",
    "zg": "geopotential",
    "ua": "u",
    "va": "v",
}
required = ["temperature", "geopotential", "u", "v"]

# Helper: ensure we have a proper time dimension called "time"
def ensure_time_dim(ds_in: xr.Dataset) -> xr.Dataset:
    ds_out = ds_in
    # Normalize potential alternative names
    if "valid_time" in ds_out.dims or "valid_time" in ds_out.coords:
        ds_out = ds_out.rename({"valid_time": "time"})
    # If time exists as a coord but not a dim, try to move it to a dim
    if "time" not in ds_out.dims:
        if "time" in ds_out.coords:
            # If variables lack time dimension, add it with length 1 (unlikely for ERA5)
            needs_expand = []
            for v in ds_out.data_vars:
                if "time" not in ds_out[v].dims:
                    needs_expand.append(v)
            if needs_expand:
                # Expand each variable to include a time=1 dimension
                ds_out = ds_out.expand_dims(time=[ds_out["time"].values]).copy()
        else:
            # No time at all: create a synthetic single time index
            ds_out = ds_out.assign_coords(time=np.array([0])).expand_dims("time")
    # Sort by time if present
    if "time" in ds_out.coords:
        ds_out = ds_out.sortby("time")
    return ds_out

# Open lazily with explicit engine
ds = xr.open_dataset(DATA_PATH, engine="netcdf4")

# Select 500 hPa level (lazy)
if "pressure_level" in ds.dims or "pressure_level" in ds.coords:
    if 500.0 in ds["pressure_level"].values:
        ds_500 = ds.sel(pressure_level=500.0)
    else:
        ds_500 = ds.isel(pressure_level=0)
else:
    ds_500 = ds

# Ensure time dimension exists and is named 'time'
ds_500 = ensure_time_dim(ds_500)

# Build mapping raw->canonical present in ds
present = set(ds_500.data_vars)
mapped = {}
for raw, canon in VAR_MAP.items():
    if raw in present and canon not in mapped.values():
        mapped[raw] = canon

if set(mapped.values()) != set(required):
    missing = [v for v in required if v not in mapped.values()]
    raise KeyError(f"Missing required variables: {missing}. Dataset has: {list(ds_500.data_vars)}. "
                   f"Found mapping: {mapped}")

# Determine dims robustly
def first_available(name_options, sizes):
    for n in name_options:
        if n in sizes:
            return n
    return None

time_dim = first_available(["time"], ds_500.sizes)
lat_dim = first_available(["latitude", "lat"], ds_500.sizes)
lon_dim = first_available(["longitude", "lon"], ds_500.sizes)

if time_dim is None or lat_dim is None or lon_dim is None:
    raise KeyError(f"Could not find required dims. sizes={ds_500.sizes}")

T = ds_500.sizes[time_dim]
H = ds_500.sizes[lat_dim]
W = ds_500.sizes[lon_dim]
print(f"Streaming stats over {time_dim}={T}, {lat_dim}={H}, {lon_dim}={W}")

# Streaming accumulators per canonical variable
acc_sum = {c: 0.0 for c in required}
acc_sum2 = {c: 0.0 for c in required}
acc_count = {c: 0 for c in required}

# Chunk over time to keep memory low
chunk_t = 16  # adjust if needed
for t0 in range(0, T, chunk_t):
    t1 = min(T, t0 + chunk_t)
    sl = slice(t0, t1)

    for canon in required:
        raw = [k for k, v in mapped.items() if v == canon][0]
        # Reorder dims to (time, lat, lon), dropping any singleton leftover
        da = ds_500[raw].isel({time_dim: sl}).transpose(time_dim, lat_dim, lon_dim)

        # Read only this small slice
        arr = da.values  # shape ~ [chunk_t, H, W] as float32
        # Convert to float64 for stable sums, but no copy if already float64
        arr = arr.astype(np.float64, copy=False)

        n = arr.size
        s = float(arr.sum())
        s2 = float((arr * arr).sum())

        acc_sum[canon] += s
        acc_sum2[canon] += s2
        acc_count[canon] += n

    print(f"Processed {time_dim} slice [{t0}:{t1})")

# Compute mean/std
stats = {}
for canon in required:
    n = acc_count[canon]
    mean = acc_sum[canon] / max(n, 1)
    var = max(acc_sum2[canon] / max(n, 1) - mean * mean, 0.0)
    std = float(np.sqrt(var) + 1e-6)
    stats[canon] = (float(mean), std)

VARS = required

print("Canonical variable order:", VARS)
print("Stats:")
for k in VARS:
    m, s = stats[k]
    print(f" - {k:13s}: mean={m:.4f}, std={s:.4f}")

# Optional: save stats to D:
try:
    out_dir = Path(r"D:/era5_cache")
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / "stats_500hpa_2024_09.json"
    with open(out_file, "w") as f:
        json.dump({"VARS": VARS, "stats": stats}, f, indent=2)
    print("Saved stats to:", out_file)
except Exception as e:
    print("Could not save stats to D:, continuing without saving. Reason:", e)

Streaming stats over time=720, latitude=721, longitude=1440
Processed time slice [0:16)
Processed time slice [16:32)
Processed time slice [32:48)
Processed time slice [48:64)
Processed time slice [64:80)
Processed time slice [80:96)
Processed time slice [96:112)
Processed time slice [112:128)
Processed time slice [128:144)
Processed time slice [144:160)
Processed time slice [160:176)
Processed time slice [176:192)
Processed time slice [192:208)
Processed time slice [208:224)
Processed time slice [224:240)
Processed time slice [240:256)
Processed time slice [256:272)
Processed time slice [272:288)
Processed time slice [288:304)
Processed time slice [304:320)
Processed time slice [320:336)
Processed time slice [336:352)
Processed time slice [352:368)
Processed time slice [368:384)
Processed time slice [384:400)
Processed time slice [400:416)
Processed time slice [416:432)
Processed time slice [432:448)
Processed time slice [448:464)
Processed time slice [464:480)
Processed time slice [48

In [7]:
# Cell 3 (updated): Spherical vertices + robust chunked KNN

import numpy as np
import torch

def build_spherical_vertices(H=64, W=128):
    lats = np.linspace(-90.0, 90.0, H, endpoint=True)
    lons = np.linspace(0.0, 360.0, W, endpoint=False)
    lat_grid, lon_grid = np.meshgrid(lats, lons, indexing="ij")
    lat_rad = np.deg2rad(lat_grid)
    lon_rad = np.deg2rad(lon_grid)

    x = np.cos(lat_rad) * np.cos(lon_rad)
    y = np.cos(lat_rad) * np.sin(lon_rad)
    z = np.sin(lat_rad)
    xyz = np.stack([x, y, z], axis=-1).reshape(-1, 3).astype(np.float32)  # [N,3]
    latlon = np.stack([lat_grid, lon_grid], axis=-1).reshape(-1, 2).astype(np.float32)  # [N,2]
    idx_flat = np.arange(H * W, dtype=np.int64)
    return {"xyz": xyz, "latlon": latlon, "idx_flat": idx_flat, "H": H, "W": W}

def knn_on_sphere_chunked(xyz_np: np.ndarray, K: int = 16, block: int = 2048) -> np.ndarray:
    """
    Memory-safe KNN by processing query rows in blocks.
    Avoids forming the full N x N distance matrix.
    """
    xyz = torch.from_numpy(xyz_np)  # [N,3], CPU
    N = xyz.shape[0]
    K = int(min(K, max(1, N - 1)))  # clamp
    knn_idx = torch.empty(N, K, dtype=torch.long)

    with torch.no_grad():
        for i0 in range(0, N, block):
            i1 = min(N, i0 + block)
            # Compute distances for the current block to all points
            d = torch.cdist(xyz[i0:i1], xyz, p=2)  # [B,N]
            # Make self-distance large
            ar = torch.arange(i0, i1)
            d[torch.arange(i1 - i0), ar] = 1e9
            topk = torch.topk(-d, k=K, dim=1).indices  # nearest -> largest negative
            knn_idx[i0:i1] = topk.cpu()
    return knn_idx.numpy().astype(np.int64)

# Initialize small default verts; we’ll rebuild in Cell 3b
verts = build_spherical_vertices(H=64, W=128)
# Use chunked KNN to be safe
verts["knn_idx"] = knn_on_sphere_chunked(verts["xyz"], K=16, block=1024)
print("Initial verts:", {"H": verts["H"], "W": verts["W"], "N": verts["H"] * verts["W"]})
print("KNN shape:", verts["knn_idx"].shape)

Initial verts: {'H': 64, 'W': 128, 'N': 8192}
KNN shape: (8192, 16)


In [12]:
# Cell X: Clean up cache files before rebuilding
from pathlib import Path

CACHE_DIR = Path(r"D:/era5_cache/runtime")
for name in [
    "baseline_ctx_src_64x128.npy",
    "baseline_tgt_64x128.npy",
    "spherical_ctx_src_TCN.npy",
    "spherical_tgt_CN.npy",
    "baseline_ctx_src_64x128.tmpmm",
    "baseline_tgt_64x128.tmpmm",
    "spherical_ctx_src_TCN.tmpmm",
    "spherical_tgt_CN.tmpmm",
    "baseline_ctx_src_64x128.npy.tmp",
    "baseline_tgt_64x128.npy.tmp",
    "spherical_ctx_src_TCN.npy.tmp",
    "spherical_tgt_CN.npy.tmp",
]:
    p = CACHE_DIR / name
    if p.exists():
        try:
            p.unlink()
            print("Deleted:", p)
        except Exception as e:
            print("Could not delete", p, "->", e)

Deleted: D:\era5_cache\runtime\baseline_ctx_src_64x128.tmpmm
Deleted: D:\era5_cache\runtime\baseline_tgt_64x128.tmpmm


In [8]:
# Cell 3.5 (Windows-robust): Rebuild caches and save directly with np.save

import os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F

CACHE_DIR = Path(r"D:/era5_cache/runtime")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

BASELINE_CACHE = CACHE_DIR / "baseline_ctx_src_64x128.npy"
TARGET_CACHE   = CACHE_DIR / "baseline_tgt_64x128.npy"
SPHERICAL_CACHE = CACHE_DIR / "spherical_ctx_src_TCN.npy"
SPHERICAL_TGT   = CACHE_DIR / "spherical_tgt_CN.npy"

USE_TORCH_RESIZE = False
CHUNK = 48
torch.set_num_threads(min(8, max(1, os.cpu_count() or 4)))

# Ascending latitude view
T = ds.sizes["time"]
lats = ds["latitude"].values
if lats[0] > lats[-1]:
    ds_cache = ds.isel(latitude=slice(None, None, -1))
    lats = ds_cache["latitude"].values
else:
    ds_cache = ds

VARS = ["temperature", "geopotential", "u_component_of_wind", "v_component_of_wind"]
C = len(VARS)

# Baseline caches: build full array in chunks, then save once
if not (BASELINE_CACHE.exists() and TARGET_CACHE.exists()):
    print("Rebuilding baseline 64x128 caches...")
    Xb = np.empty((T, C, 64, 128), dtype=np.float32)
    Yb = np.empty((T, C, 64, 128), dtype=np.float32)

    for t0 in range(0, T, CHUNK):
        t1 = min(t0 + CHUNK, T)
        chunk_list = []
        for v in VARS:
            arr = ds_cache[v].isel(time=slice(t0, t1)).values.astype(np.float32)  # [Tc, H, W]
            chunk_list.append(arr[:, None, ...])  # [Tc,1,H,W]
        block = np.concatenate(chunk_list, axis=1)  # [Tc,C,H,W]

        if USE_TORCH_RESIZE:
            ten = torch.from_numpy(block)  # CPU
            ten_rs = F.interpolate(ten, size=(64, 128), mode="bilinear", align_corners=False).numpy()
        else:
            Tc, Cc, H0, W0 = block.shape
            step_h = max(1, H0 // 64)
            step_w = max(1, W0 // 128)
            ten_rs = block[:, :, ::step_h, ::step_w]
            ten_rs = ten_rs[:, :, :64, :128]
            if ten_rs.shape[2] < 64 or ten_rs.shape[3] < 128:
                pad_h = 64 - ten_rs.shape[2]
                pad_w = 128 - ten_rs.shape[3]
                ten_rs = np.pad(ten_rs, ((0,0),(0,0),(0,pad_h),(0,pad_w)), mode='edge')

        Xb[t0:t1] = ten_rs
        Yb[t0:t1] = ten_rs

        # Progress
        if (t0 // CHUNK) % 4 == 0:
            print(f"  baseline cached t={t0}/{T}")

    # Save directly
    np.save(BASELINE_CACHE, Xb)
    np.save(TARGET_CACHE,   Yb)
    print("Baseline caches saved:", BASELINE_CACHE, TARGET_CACHE)
else:
    print("Baseline caches already present.")

# Spherical caches: build and save once
if not (SPHERICAL_CACHE.exists() and SPHERICAL_TGT.exists()):
    print("Rebuilding spherical vertex caches...")
    if 'verts' not in globals():
        raise RuntimeError("verts not found. Run Cell 3 (spherical vertices) before this cell.")
    latlon_targets = verts["latlon"]
    lons = ds_cache["longitude"].values
    lat_idx = np.searchsorted(lats, latlon_targets[:, 0], side="left")
    lat_idx = np.clip(lat_idx, 0, len(lats) - 1)
    lon_targets = ((latlon_targets[:, 1] - lons[0] + 360) % 360) + lons[0]
    lon_idx = np.searchsorted(lons, lon_targets, side="left")
    lon_idx = np.clip(lon_idx, 0, len(lons) - 1)
    N = lat_idx.shape[0]

    Xs = np.empty((T, C, N), dtype=np.float32)
    Ys = np.empty((T, C, N), dtype=np.float32)

    for t0 in range(0, T, CHUNK):
        t1 = min(t0 + CHUNK, T)
        chunk_list = []
        for v in VARS:
            arr = ds_cache[v].isel(time=slice(t0, t1)).values.astype(np.float32)  # [Tc, H, W]
            chunk_list.append(arr[:, None, ...])
        block = np.concatenate(chunk_list, axis=1)  # [Tc,C,H,W]
        gather = block[:, :, lat_idx, lon_idx]      # [Tc,C,N]
        Xs[t0:t1] = gather
        Ys[t0:t1] = gather

        if (t0 // CHUNK) % 4 == 0:
            print(f"  spherical cached t={t0}/{T}")

    np.save(SPHERICAL_CACHE, Xs)
    np.save(SPHERICAL_TGT,   Ys)
    print("Spherical caches saved:", SPHERICAL_CACHE, SPHERICAL_TGT)
else:
    print("Spherical caches already present.")

print("Cache setup complete (Windows-robust).")

KeyError: 'time'

In [12]:
# Cell 4: Winds-only datasets + TinyLRU (no Coriolis extra channel)

import torch
from torch.utils.data import Dataset
from collections import OrderedDict
import numpy as np

WINDS_ONLY = True  # we focus on u, v only

def get_raw_key_for(canon_name, ds):
    candidates = {
        "temperature": ["t", "ta", "temperature"],
        "geopotential": ["z", "zg", "geopotential"],
        "u": ["u", "ua"],
        "v": ["v", "va"],
    }
    for raw in candidates[canon_name]:
        if raw in ds.data_vars:
            return raw
    raise KeyError(f"Could not find raw key for {canon_name} in dataset vars: {list(ds.data_vars)}")

class TinyLRU:
    def __init__(self, capacity=64):
        self.cap = capacity
        self.d = OrderedDict()
    def get(self, k):
        if k in self.d:
            v = self.d.pop(k); self.d[k] = v; return v
        return None
    def put(self, k, v):
        if k in self.d: self.d.pop(k)
        self.d[k] = v
        if len(self.d) > self.cap: self.d.popitem(last=False)

def build_var_list(ds):
    base = ["u", "v"] if WINDS_ONLY else ["t", "z", "u", "v"]
    return [v for v in base if get_raw_key_for(v, ds)]

class ERA5GridDataset(Dataset):
    def __init__(self, ds, stats, VARS, time_indices, T_ctx=2, cache_capacity=64):
        super().__init__()
        self.ds = ds
        self.stats = stats
        self.VARS = VARS
        self.T_ctx = T_ctx

        self.time_dim = "time" if "time" in ds.sizes else list(ds.sizes.keys())[0]
        self.lat_dim = "latitude" if "latitude" in ds.sizes else "lat"
        self.lon_dim = "longitude" if "longitude" in ds.sizes else "lon"

        self.time_indices = time_indices
        self.raw_keys = {c: get_raw_key_for(c, ds) for c in VARS}
        self.cache = TinyLRU(capacity=cache_capacity)

    def _read_grid_norm(self, canon_name, t_index):
        key = (canon_name, int(t_index))
        cached = self.cache.get(key)
        if cached is not None: return cached

        raw = self.raw_keys[canon_name]
        da = self.ds[raw].isel({self.time_dim: t_index})

        for d in list(da.dims):
            if d not in (self.lat_dim, self.lon_dim, self.time_dim) and da.sizes.get(d, 0) == 1:
                da = da.isel({d: 0})

        if self.lat_dim not in da.dims or self.lon_dim not in da.dims:
            keep_dims = [d for d in da.dims if d in (self.lat_dim, self.lon_dim)]
            if len(keep_dims) != 2:
                raise ValueError(f"{raw} dims after squeeze: {da.dims}")
            da = da.transpose(*keep_dims)
            if da.dims != (self.lat_dim, self.lon_dim):
                da = da.transpose(self.lat_dim, self.lon_dim)
        else:
            da = da.transpose(self.lat_dim, self.lon_dim)

        arr = torch.from_numpy(da.values.astype("float32"))
        mean, std = self.stats[canon_name]
        out = (arr - mean) / std
        self.cache.put(key, out)
        return out

    def __len__(self): return len(self.time_indices)

    def __getitem__(self, idx):
        t = self.time_indices[idx]
        ctx_indices = list(range(max(0, t - self.T_ctx), t))

        x_list = []
        for ti in ctx_indices:
            c_list = [self._read_grid_norm(c, ti) for c in self.VARS]
            x_t = torch.stack(c_list, dim=0)
            x_list.append(x_t)
        if len(x_list) == 0:
            Hds = self.ds.sizes[self.lat_dim]; Wds = self.ds.sizes[self.lon_dim]
            x = torch.zeros(self.T_ctx, len(self.VARS), Hds, Wds, dtype=torch.float32)
        else:
            while len(x_list) < self.T_ctx: x_list.insert(0, x_list[0])
            x = torch.stack(x_list, dim=0)

        y_list = [self._read_grid_norm(c, t) for c in self.VARS]
        y = torch.stack(y_list, dim=0)
        return x, y

class ERA5SphericalDataset(Dataset):
    def __init__(self, ds, stats, VARS, time_indices, verts, T_ctx=2, cache_capacity=64):
        super().__init__()
        self.ds = ds
        self.stats = stats
        self.VARS = VARS
        self.T_ctx = T_ctx
        self.verts = verts

        self.time_dim = "time" if "time" in ds.sizes else list(ds.sizes.keys())[0]
        self.lat_dim = "latitude" if "latitude" in ds.sizes else "lat"
        self.lon_dim = "longitude" if "longitude" in ds.sizes else "lon"

        self.time_indices = time_indices
        self.raw_keys = {c: get_raw_key_for(c, ds) for c in VARS}

        self.H = ds.sizes[self.lat_dim]
        self.W = ds.sizes[self.lon_dim]

        self.cache = TinyLRU(capacity=cache_capacity)

    def _read_grid_norm(self, canon_name, t_index):
        key = (canon_name, int(t_index))
        cached = self.cache.get(key)
        if cached is not None: return cached

        raw = self.raw_keys[canon_name]
        da = self.ds[raw].isel({self.time_dim: t_index})
        for d in list(da.dims):
            if d not in (self.lat_dim, self.lon_dim, self.time_dim) and da.sizes.get(d, 0) == 1:
                da = da.isel({d: 0})
        if self.lat_dim not in da.dims or self.lon_dim not in da.dims:
            keep_dims = [d for d in da.dims if d in (self.lat_dim, self.lon_dim)]
            if len(keep_dims) != 2:
                raise ValueError(f"{raw} dims after squeeze: {da.dims}")
            da = da.transpose(*keep_dims)
            if da.dims != (self.lat_dim, self.lon_dim):
                da = da.transpose(self.lat_dim, self.lon_dim)
        else:
            da = da.transpose(self.lat_dim, self.lon_dim)

        arr = torch.from_numpy(da.values.astype("float32"))
        mean, std = self.stats[canon_name]
        out = (arr - mean) / std
        self.cache.put(key, out)
        return out

    def __len__(self): return len(self.time_indices)

    def __getitem__(self, idx):
        t = self.time_indices[idx]
        ctx_indices = list(range(max(0, t - self.T_ctx), t))

        x_list = []
        for ti in ctx_indices:
            c_list = [self._read_grid_norm(c, ti) for c in self.VARS]
            x_t = torch.stack(c_list, dim=0)  # [C,H,W] with C=len(VARS)=2
            x_list.append(x_t)
        if len(x_list) == 0:
            x = torch.zeros(self.T_ctx, len(self.VARS), self.H, self.W, dtype=torch.float32)
        else:
            while len(x_list) < self.T_ctx: x_list.insert(0, x_list[0])
            x = torch.stack(x_list, dim=0)

        y_list = [self._read_grid_norm(c, t) for c in self.VARS]
        y = torch.stack(y_list, dim=0)
        return x, y

In [13]:
# Cell 5: Equivariant point encoder (Irreps + FCTP), version-agnostic Gate handling
def knn_indices(xyz: torch.Tensor, k: int = 16) -> torch.Tensor:
    # xyz: [B,N,3] -> idx: [B,N,K]
    with torch.no_grad():
        d = torch.cdist(xyz, xyz)  # [B,N,N]
        vals, idx = torch.topk(d, k=k+1, largest=False)  # includes self
        return idx[:, :, 1:]  # drop self

class EquivariantPointBlock(nn.Module):
    """
    Message passing from relative vectors (1o) and scalar node features (0e).
    Uses a tensor product to mix inputs into a hidden representation and then
    applies a simple gated nonlinearity constructed explicitly to avoid e3nn Gate API differences.
    """
    def __init__(self, irreps_in: Irreps, irreps_hidden: Irreps, irreps_out: Irreps, scalars_gate: int = 16):
        super().__init__()
        # Tensor product: (features x 1o) -> hidden
        self.tp1 = FullyConnectedTensorProduct(irreps_in, Irreps("1x1o"), irreps_hidden)

        # We implement a gate manually:
        # - produce gate scalars g in 0e via an MLP
        # - apply sigmoid to g to gate the hidden features projected to out
        self.lin_hidden = nn.Linear(irreps_hidden.dim, irreps_hidden.dim)
        self.lin_out = nn.Linear(irreps_hidden.dim, irreps_out.dim)
        self.gate_mlp = nn.Sequential(
            nn.Linear(irreps_hidden.dim, scalars_gate),
            nn.SiLU(),
            nn.Linear(scalars_gate, irreps_out.dim),  # gate dimension matches out scalars
        )

        # Lift scalar inputs
        self.lift = nn.Sequential(
            nn.Linear(irreps_in.dim, irreps_in.dim),
            nn.SiLU(),
            nn.Linear(irreps_in.dim, irreps_in.dim),
        )

    def forward(self, feats_in: torch.Tensor, rel_vec: torch.Tensor):
        # feats_in: [B,N,dim(irreps_in)] packed scalars
        # rel_vec: [B,N,K,3] (treated as 1o)
        B, N, K, _ = rel_vec.shape

        x = self.lift(feats_in)                             # [B,N,dim(ir_in)]
        x = x.unsqueeze(2).expand(B, N, K, x.shape[-1])     # [B,N,K,dim(ir_in)]
        x = x.reshape(B, N*K, -1)                           # [B,NK,dim(ir_in)]
        v = rel_vec.reshape(B, N*K, 3)                      # [B,NK,3]

        h = self.tp1(x, v)                                  # [B,NK,dim(ir_hidden)]
        h = F.silu(self.lin_hidden(h))                      # [B,NK,dim(ir_hidden)]

        gates = torch.sigmoid(self.gate_mlp(h))             # [B,NK,dim(ir_out)]
        h = self.lin_out(h)                                  # [B,NK,dim(ir_out)]
        h = gates * h                                        # gated output

        h = h.view(B, N, K, -1).mean(dim=2)                 # [B,N,dim(ir_out)]
        return h

class E3NNPointEncoder(nn.Module):
    def __init__(self, in_channels: int, hidden_irreps: str = "16x0e + 16x1o", out_channels: int = 32, k: int = 16):
        super().__init__()
        self.k = k
        self.ir_in = Irreps(f"{in_channels}x0e")
        self.ir_hidden = Irreps(hidden_irreps)
        self.ir_out = Irreps(f"{out_channels}x0e")
        self.lin_in = nn.Linear(in_channels, self.ir_in.dim)
        self.block1 = EquivariantPointBlock(self.ir_in, self.ir_hidden, self.ir_out, scalars_gate=16)
        self.lin_out = nn.Linear(self.ir_out.dim, out_channels)

    def forward(self, feats: torch.Tensor, xyz: torch.Tensor):
        # feats: [B, N, Cin] (scalars), xyz: [B, N, 3]
        B, N, _ = feats.shape
        idx = knn_indices(xyz, k=self.k)  # [B,N,K]
        nbrs = torch.gather(xyz, 1, idx.unsqueeze(-1).expand(-1, -1, -1, 3))  # [B,N,K,3]
        rel = nbrs - xyz.unsqueeze(2)  # [B,N,K,3]
        x = self.lin_in(feats)         # [B,N,dim(ir_in)]
        h = self.block1(x, rel)        # [B,N,dim(ir_out)]
        h = self.lin_out(h)            # [B,N,Cout] (0e scalars)
        return h

In [9]:
# Cell 6: Stronger Spherical model (edge-aware KNN + 2-layer Conv3dLSTM), grid-safe input

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv3dLSTMCell(nn.Module):
    def __init__(self, in_ch, hidden_ch, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv3d(in_ch + hidden_ch, 4 * hidden_ch, kernel_size, padding=pad)

    def forward(self, x, h, c):
        gates = self.conv(torch.cat([x, h], dim=1))
        ci, cf, co, cg = torch.chunk(gates, 4, dim=1)
        i = torch.sigmoid(ci)
        f = torch.sigmoid(cf)
        o = torch.sigmoid(co)
        g = torch.tanh(cg)
        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

class SphericalSpatioTemporal(nn.Module):
    def __init__(self, n_vars, verts, hidden_ch=64, edge_ch=16, T_ctx=4):
        super().__init__()
        self.C = n_vars
        self.T_ctx = T_ctx

        xyz = torch.from_numpy(verts["xyz"]).float()        # [N,3]
        idx = torch.from_numpy(verts["knn_idx"]).long()     # [N,K]
        idx_flat = torch.from_numpy(verts["idx_flat"]).long()
        self.H = int(verts["H"])
        self.W = int(verts["W"])
        self.N = self.H * self.W

        self.register_buffer("xyz", xyz)
        self.register_buffer("idx", idx)
        self.register_buffer("idx_flat", idx_flat)

        self.K = idx.shape[1]

        self.lin_node = nn.Linear(n_vars, hidden_ch)
        self.edge_mlp = nn.Sequential(
            nn.Linear(4, edge_ch),
            nn.ReLU(inplace=True),
            nn.Linear(edge_ch, 1),
            nn.Sigmoid(),
        )
        self.core1 = Conv3dLSTMCell(in_ch=hidden_ch, hidden_ch=hidden_ch, kernel_size=3)
        self.core2 = Conv3dLSTMCell(in_ch=hidden_ch, hidden_ch=hidden_ch, kernel_size=3)
        self.lin_out = nn.Linear(hidden_ch, n_vars)

    def _message_pass(self, feats, device):
        # feats: [B, T, N, F]
        B, T, N, F = feats.shape
        K = self.K
        xyz = self.xyz.to(device)  # [N,3]
        idx = self.idx.to(device)  # [N,K]

        xyz_b = xyz.unsqueeze(0).expand(B, -1, -1)          # [B,N,3]
        idx_b = idx.unsqueeze(0).expand(B, -1, -1)          # [B,N,K]
        b_ix = torch.arange(B, device=device).view(B,1,1).expand(B, N, K)
        nbrs = xyz_b[b_ix, idx_b, :]                        # [B,N,K,3]
        rel = nbrs - xyz_b.unsqueeze(2)                     # [B,N,K,3]
        dist = torch.linalg.norm(rel, dim=-1, keepdim=True) # [B,N,K,1]
        rel_feat = torch.cat([rel, dist], dim=-1)           # [B,N,K,4]
        w = self.edge_mlp(rel_feat)                         # [B,N,K,1]

        idx_bt = idx_b.unsqueeze(1).expand(B, T, N, K)      # [B,T,N,K]
        b_ix_bt = torch.arange(B, device=device).view(B,1,1,1).expand(B, T, N, K)
        t_ix_bt = torch.arange(T, device=device).view(1,T,1,1).expand(B, T, N, K)
        nbr_feats = feats[b_ix_bt, t_ix_bt, idx_bt, :]      # [B,T,N,K,F]

        agg = (w.unsqueeze(1) * nbr_feats).sum(dim=3) / (w.unsqueeze(1).sum(dim=3) + 1e-6)  # [B,T,N,F]
        return 0.5 * feats + 0.5 * agg

    def forward(self, x):
        # Accept either [B, T, C, Hds, Wds] or [B, T, C, N]
        B = x.shape[0]
        device = x.device

        if x.dim() == 5:
            # x: [B,T,C,Hds,Wds] -> resample to target verts grid HxW then flatten
            _, T, C, Hds, Wds = x.shape
            if (Hds, Wds) != (self.H, self.W):
                x = x.view(B*T, C, Hds, Wds)
                x = F.interpolate(x, size=(self.H, self.W), mode="bilinear", align_corners=False)
                x = x.view(B, T, C, self.H, self.W)
            x = x.view(B, T, C, -1)  # [B,T,C,N]
        else:
            # x: [B,T,C,N]
            T = x.shape[1]
            C = x.shape[2]
            assert x.shape[-1] == self.N, f"N mismatch: got {x.shape[-1]}, expected {self.N}"

        # Node encode: [B,T,N,C] -> [B,T,N,Hid]
        feats = torch.tanh(self.lin_node(x.permute(0,1,3,2)))  # [B,T,N,Hid]

        # Message passing on sphere
        feats = self._message_pass(feats, device)              # [B,T,N,Hid]

        # Gridify for Conv3dLSTM: [B,Hid,T,H,W]
        hid = feats.permute(0,3,1,2).contiguous()              # [B,Hid,T,N]
        grid = torch.zeros(B, hid.shape[1], T, self.H, self.W, device=device, dtype=hid.dtype)
        grid.view(B, hid.shape[1], T, -1)[:, :, :, self.idx_flat] = hid
        x3d = grid

        # Temporal core
        h1 = torch.zeros_like(x3d); c1 = torch.zeros_like(x3d)
        h1, c1 = self.core1(x3d, h1, c1)
        h2 = torch.zeros_like(h1); c2 = torch.zeros_like(h1)
        h2, c2 = self.core2(h1, h2, c2)                        # [B,Hid,T,H,W]

        h_last = h2[:, :, -1]                                  # [B,Hid,H,W]

        # Back to nodes
        feat_nodes = h_last.flatten(2)[:, :, self.idx_flat]    # [B,Hid,N]
        out_nodes = self.lin_out(feat_nodes.permute(0,2,1))    # [B,N,C]

        # To grid [B,C,H,W]
        out_grid = torch.zeros(B, self.C, self.H, self.W, device=device, dtype=out_nodes.dtype)
        out_grid.view(B, self.C, -1)[:, :, self.idx_flat] = out_nodes.permute(0,2,1)
        return out_grid

In [8]:
# Cell 7 (grad-safe minimal baseline): no @torch.no_grad, fixed 64x128 internal grid

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, padding=1, bias=True)
        self.act = nn.SiLU()
        nn.init.kaiming_uniform_(self.conv.weight, a=0.2, nonlinearity='leaky_relu')
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)
    def forward(self, x):
        return self.act(self.conv(x))

class BaselineLatLon(nn.Module):
    """
    Winds-only baseline that:
      - downsamples inputs to 64x128,
      - runs a tiny CNN,
      - upsamples back to match input HxW.
    No no_grad used anywhere; fully differentiable.
    """
    def __init__(self, n_vars: int, hidden_ch: int = 16, T_ctx: int = 4, H: int = 64, W: int = 128):
        super().__init__()
        self.C = n_vars
        self.T_ctx = T_ctx
        self.H0, self.W0 = 64, 128
        in_ch = n_vars * T_ctx
        self.enc1 = ConvBlock(in_ch, hidden_ch)
        self.enc2 = ConvBlock(hidden_ch, hidden_ch)
        self.head = nn.Conv2d(hidden_ch, n_vars, 1, bias=True)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):  # x: [B,T,C,H,W]
        assert x.ndim == 5, f"Expected [B,T,C,H,W], got {tuple(x.shape)}"
        B, T, C, H, W = x.shape
        assert T == self.T_ctx and C == self.C

        # Downsample to internal grid (differentiable)
        if (H, W) != (self.H0, self.W0):
            x_small = F.interpolate(
                x.reshape(B*T, C, H, W),
                size=(self.H0, self.W0),
                mode="bilinear", align_corners=False
            ).view(B, T, C, self.H0, self.W0)
        else:
            x_small = x

        # CNN at 64x128
        x2d = x_small.reshape(B, T*C, self.H0, self.W0)
        h = self.enc2(self.enc1(x2d))
        y_small = self.head(h)  # [B, C, H0, W0]

        # Upsample back to original grid (differentiable)
        if (H, W) != (self.H0, self.W0):
            y = F.interpolate(y_small, size=(H, W), mode="bilinear", align_corners=False)
        else:
            y = y_small

        # Clamp (still differentiable where in-range; saturates outside)
        y = torch.clamp(y, -50, 50)
        return y

In [14]:
# Cell 8: Splits + loaders (winds-only, small caps)

from torch.utils.data import DataLoader

VARS = build_var_list(ds)  # should be ["u","v"]
assert set(VARS) >= {"u", "v"}, f"VARS={VARS} must include u,v; got {VARS}"

T_CTX = 2
BATCH = 1

T_total = ds.sizes["time"] if "time" in ds.sizes else list(ds.sizes.values())[0]
idx_all = [t for t in range(T_total) if t - T_CTX >= 0]

CAP_TRAIN, CAP_VAL, CAP_TEST = 32, 8, 8

n_total = len(idx_all)
n_train = int(0.8 * n_total)
n_val   = int(0.1 * n_total)
train_idx = idx_all[:n_train][:CAP_TRAIN]
val_idx   = idx_all[n_train:n_train+n_val][:CAP_VAL]
test_idx  = idx_all[n_train+n_val:][:CAP_TEST]

print(f"Splits -> train {len(train_idx)} | val {len(val_idx)} | test {len(test_idx)} | VARS={VARS}")

train_b = ERA5GridDataset(ds, stats, VARS, time_indices=train_idx, T_ctx=T_CTX, cache_capacity=64)
val_b   = ERA5GridDataset(ds, stats, VARS, time_indices=val_idx,   T_ctx=T_CTX, cache_capacity=64)
test_b  = ERA5GridDataset(ds, stats, VARS, time_indices=test_idx,  T_ctx=T_CTX, cache_capacity=64)

train_s = ERA5SphericalDataset(ds, stats, VARS, time_indices=train_idx, verts=verts, T_ctx=T_CTX, cache_capacity=64)
val_s   = ERA5SphericalDataset(ds, stats, VARS, time_indices=val_idx,   verts=verts, T_ctx=T_CTX, cache_capacity=64)
test_s  = ERA5SphericalDataset(ds, stats, VARS, time_indices=test_idx,  verts=verts, T_ctx=T_CTX, cache_capacity=64)

train_b_loader = DataLoader(train_b, batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)
val_b_loader   = DataLoader(val_b,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)
test_b_loader  = DataLoader(test_b,  batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)

train_s_loader = DataLoader(train_s, batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)
val_s_loader   = DataLoader(val_s,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)
test_s_loader  = DataLoader(test_s,  batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=False)

Splits -> train 32 | val 8 | test 8 | VARS=['u', 'v']


In [15]:
# Cell 9: sanity shapes
xb, yb = next(iter(train_b_loader))
xs, ys = next(iter(train_s_loader))
print("Baseline:", xb.shape, yb.shape)   # [B, T_ctx, C, H, W], [B, C, H, W]
print("Spherical:", xs.shape, ys.shape)  # same C=2

Baseline: torch.Size([1, 2, 2, 721, 1440]) torch.Size([1, 2, 721, 1440])
Spherical: torch.Size([1, 2, 2, 721, 1440]) torch.Size([1, 2, 721, 1440])


In [9]:
# Cell 3b (updated): Rebuild verts at capped resolution with chunked KNN

import numpy as np

def infer_hw_from_N(N: int):
    candidates = [(64,128), (90,180), (128,256), (64,64), (32,128), (32,64), (45,90)]
    for H, W in candidates:
        if H * W == N:
            return H, W
    if N % 128 == 0:
        return N // 128, 128
    H = int(np.sqrt(N))
    while H > 1 and N % H != 0:
        H -= 1
    W = N // H
    return H, W

# Detect dataset grid from spherical loader
xs_chk, ys_chk = next(iter(test_s_loader))  # [B,T,C,Hds,Wds], [B,C,Hds,Wds]
Hds, Wds = ys_chk.shape[-2], ys_chk.shape[-1]
N_detect = Hds * Wds
H_det, W_det = infer_hw_from_N(N_detect)
print(f"Detected spherical dataset grid: {Hds}x{Wds} (N={N_detect}) -> naive verts HxW={H_det}x{W_det}")

# Cap verts to a safe resolution (adjust if you want denser)
H_cap, W_cap = 64, 128   # 64x128 is very safe; you can try 90x180 after this is stable
print(f"Capping verts to HxW={H_cap}x{W_cap} (N={H_cap*W_cap}) for memory-safe KNN...")

verts = build_spherical_vertices(H=H_cap, W=W_cap)
verts["knn_idx"] = knn_on_sphere_chunked(verts["xyz"], K=16, block=1024)
print("Rebuilt verts:", {"H": verts["H"], "W": verts["W"], "N": verts["H"] * verts["W"]})

NameError: name 'test_s_loader' is not defined

In [35]:
# Cell 10 (updated): robust baseline instantiation + spherical (winds-only) + AMP-safe

import torch, inspect

device = "cuda" if torch.cuda.is_available() else "cpu"
use_bf16 = (device == "cuda") and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
print("Device:", device, "| AMP dtype:", amp_dtype)

C = len(VARS)  # should be 2

# Force baseline grid to training/cache grid (avoid inferring 721x1440)
H_target, W_target = 64, 128
print(f"Baseline grid set to: {H_target}x{W_target}")

def make_baseline():
    sig = inspect.signature(BaselineLatLon.__init__)
    params = list(sig.parameters.keys())[1:]
    mapped = {}
    for p in params:
        lp = p.lower()
        if lp in ("c_in","in_channels","channels_in","input_channels","cin","n_vars","num_vars","nvars"):
            mapped[p] = C
        elif lp in ("c_out","out_channels","channels_out","output_channels","cout"):
            mapped[p] = C
        elif lp in ("c_hidden","hidden_channels","mid_channels","width","hidden_dim","hidden_size","hidden_ch"):
            mapped[p] = 16
        elif lp in ("t_ctx","tcontext","context","context_steps","t"):
            mapped[p] = T_CTX
        elif lp in ("h","height"):
            mapped[p] = H_target
        elif lp in ("w","width"):
            mapped[p] = W_target
        elif lp in ("grid_size","shape","hw"):
            mapped[p] = (H_target, W_target)
        else:
            pass
    try:
        return BaselineLatLon(**mapped).to(device)
    except TypeError as e_kw:
        patterns = [
            ("C_in","C_out"),
            ("C_in","hidden","C_out"),
            ("C_in","hidden","C_out","T_ctx"),
            ("C_in","hidden","C_out","T_ctx","H"),
            ("C_in","hidden","C_out","T_ctx","H","W"),
        ]
        vals = {"C_in":C, "C_out":C, "hidden":16, "T_ctx":T_CTX, "H":H_target, "W":W_target}
        for pat in patterns:
            args = [vals[k] for k in pat]
            try:
                return BaselineLatLon(*args).to(device)
            except TypeError:
                continue
        print("BaselineLatLon signature:", sig)
        raise e_kw

baseline = make_baseline()

# Spherical model (winds-only). Ensure 'verts' is defined before this cell runs.
spherical = SphericalSpatioTemporal(
    n_vars=C,
    verts=verts,
    hidden_ch=48,   # modest capacity
    edge_ch=16,
    T_ctx=T_CTX
).to(device)

# Optimizers
opt_b = torch.optim.AdamW(baseline.parameters(), lr=1e-3, weight_decay=1e-4)
opt_s = torch.optim.AdamW(spherical.parameters(), lr=1e-3, weight_decay=1e-4)

# AMP scaler: enable only for fp16 (not needed for bf16); we start disabled for stability
use_scaler = (device == "cuda") and (amp_dtype == torch.float16) and False
scaler_b = torch.cuda.amp.GradScaler(enabled=use_scaler)
scaler_s = torch.cuda.amp.GradScaler(enabled=use_scaler)

# Param counts
p_b = sum(p.numel() for p in baseline.parameters())
p_s = sum(p.numel() for p in spherical.parameters())
print(f"Params (M): baseline {p_b/1e6:.3f} | spherical {p_s/1e6:.3f}")

Device: cuda | AMP dtype: torch.bfloat16
Baseline grid set to: 64x128
Params (M): baseline 0.003 | spherical 0.996


  scaler_b = torch.cuda.amp.GradScaler(enabled=use_scaler)
  scaler_s = torch.cuda.amp.GradScaler(enabled=use_scaler)


In [36]:
# Cell 11: angle+MSE training, AMP-safe, inline backward

import torch
import torch.nn.functional as F
import time
from contextlib import nullcontext

device = "cuda" if torch.cuda.is_available() else "cpu"
use_bf16 = (device == "cuda") and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

baseline.train(); spherical.train()
for p in baseline.parameters(): p.requires_grad_(True)
for p in spherical.parameters(): p.requires_grad_(True)

for g in opt_b.param_groups: g["lr"] = min(g["lr"], 1e-3)
for g in opt_s.param_groups: g["lr"] = min(g["lr"], 1e-3)

steps_cap = 24
print_every = 6
eps = 1e-6

if device == "cuda":
    ac = torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
    scaler_b = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
    scaler_s = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
else:
    ac = nullcontext()
    scaler_b = torch.cuda.amp.GradScaler(enabled=False)
    scaler_s = torch.cuda.amp.GradScaler(enabled=False)

def angle_loss(pred_uv, true_uv, eps=1e-6, mask=None):
    pu, pv = pred_uv[:, 0], pred_uv[:, 1]
    tu, tv = true_uv[:, 0], true_uv[:, 1]
    pn = torch.sqrt(pu * pu + pv * pv + eps)
    tn = torch.sqrt(tu * tu + tv * tv + eps)
    cos = (pu * tu + pv * tv) / (pn * tn + eps)
    cos = torch.clamp(cos, -1.0, 1.0)
    ang = torch.acos(cos)
    if mask is not None:
        ang = (ang * mask).sum() / (mask.sum() + eps)
    else:
        ang = ang.mean()
    return ang

def speed_mask(true_uv, thresh=0.15):
    sp = torch.sqrt(true_uv[:,0]**2 + true_uv[:,1]**2)
    return (sp >= thresh).float()

def finite_or_zero(t):
    return torch.where(torch.isfinite(t), t, torch.zeros_like(t))

it_b = iter(train_b_loader)
it_s = iter(train_s_loader)
t0 = time.time()

w_mse, w_ang = 1.0, 0.1

for step in range(steps_cap):
    try: xb, yb = next(it_b)
    except StopIteration:
        it_b = iter(train_b_loader); xb, yb = next(it_b)
    try: xs, ys = next(it_s)
    except StopIteration:
        it_s = iter(train_s_loader); xs, ys = next(it_s)

    xb, yb, xs, ys = xb.to(device), yb.to(device), xs.to(device), ys.to(device)

    opt_b.zero_grad(set_to_none=True)
    opt_s.zero_grad(set_to_none=True)

    with ac:
        yhb = finite_or_zero(baseline(xb))
        mse_b = F.mse_loss(yhb, yb)
        ang_b = angle_loss(yhb, yb, eps=eps, mask=speed_mask(yb))
        loss_b = w_mse * mse_b + w_ang * ang_b

        yhs = finite_or_zero(spherical(xs))
        ys_r = ys if yhs.shape[-2:] == ys.shape[-2:] else F.interpolate(ys, size=yhs.shape[-2:], mode="bilinear", align_corners=False)
        mse_s = F.mse_loss(yhs, ys_r)
        ang_s = angle_loss(yhs, ys_r, eps=eps, mask=speed_mask(ys_r))
        loss_s = w_mse * mse_s + w_ang * ang_s

    if scaler_b.is_enabled():
        scaler_b.scale(loss_b).backward()
        torch.nn.utils.clip_grad_norm_(baseline.parameters(), 1.0)
        scaler_b.step(opt_b); scaler_b.update()
    else:
        loss_b.backward()
        torch.nn.utils.clip_grad_norm_(baseline.parameters(), 1.0)
        opt_b.step()

    if scaler_s.is_enabled():
        scaler_s.scale(loss_s).backward()
        torch.nn.utils.clip_grad_norm_(spherical.parameters(), 1.0)
        scaler_s.step(opt_s); scaler_s.update()
    else:
        loss_s.backward()
        torch.nn.utils.clip_grad_norm_(spherical.parameters(), 1.0)
        opt_s.step()

    if (step % print_every) == 0 or step == steps_cap - 1:
        dt = time.time() - t0
        print(f"Step {step+1}/{steps_cap} | dt={dt:.1f}s | baseline {loss_b.detach().float().cpu().item():.4f} | spherical {loss_s.detach().float().cpu().item():.4f}")

baseline.eval(); spherical.eval()
print("Training pass complete (angle+MSE, AMP-safe).")

  scaler_b = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))
  scaler_s = torch.cuda.amp.GradScaler(enabled=(amp_dtype == torch.float16))


Step 1/24 | dt=0.0s | baseline 2.1518 | spherical 1.4511
Step 7/24 | dt=0.2s | baseline 1.0912 | spherical 0.8042
Step 13/24 | dt=0.3s | baseline 0.5835 | spherical 0.4227
Step 19/24 | dt=13.3s | baseline 1.4003 | spherical 0.3121
Step 24/24 | dt=64.3s | baseline 1.4090 | spherical 0.2377
Training pass complete (angle+MSE, AMP-safe).


In [9]:
# Cell 7s: Robust spherical model scaffold ensuring gradient flow (Windows/AMP-safe)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SphericalModel(nn.Module):
    def __init__(self, in_ch=4, hidden_ch=32, out_ch=2, H=64, W=128, use_e3nn=True):
        super().__init__()
        self.in_ch = in_ch
        self.hidden_ch = hidden_ch
        self.out_ch = out_ch
        self.H = H
        self.W = W
        self.use_e3nn = use_e3nn

        # Head: parameterized
        self.head = nn.Sequential(
            nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1, bias=True),
            nn.GELU(),
        )

        # Spherical/e3nn block placeholder:
        # Replace the below with your actual e3nn pipeline if available.
        # Important: keep it as nn.Modules; no no_grad/detach; run in float32 for stability under AMP.
        if self.use_e3nn:
            self.sph_block = nn.Sequential(
                nn.Conv2d(hidden_ch, hidden_ch, 3, padding=1, bias=True),
                nn.GELU(),
                nn.Conv2d(hidden_ch, hidden_ch, 3, padding=1, bias=True),
                nn.GELU(),
            )
        else:
            self.sph_block = nn.Sequential(
                nn.Conv2d(hidden_ch, hidden_ch, 3, padding=1, bias=True),
                nn.GELU(),
                nn.Conv2d(hidden_ch, hidden_ch, 3, padding=1, bias=True),
                nn.GELU(),
            )

        # Tail: parameterized
        self.tail = nn.Conv2d(hidden_ch, out_ch, kernel_size=1, bias=True)

    def forward(self, x):
        # x: [B,C,H,W]
        # Never use no_grad() or detach() here.
        # Enforce internal grid to a known size
        if x.shape[-2:] != (self.H, self.W):
            x = F.interpolate(x, size=(self.H, self.W), mode="bilinear", align_corners=False)

        h = self.head(x)  # parameterized

        # If using e3nn, perform core ops in float32 for stability, then cast back
        if self.use_e3nn:
            h32 = h.float()
            h32 = self.sph_block(h32)  # Replace with true e3nn graph ops as needed
            h = h32.to(h.dtype)
        else:
            h = self.sph_block(h)

        y = self.tail(h)  # parameterized

        # Sanity: ensure graph connectivity
        assert any(p.requires_grad for p in self.parameters()), "All params frozen in spherical model."
        assert y.requires_grad, "Output is not connected to parameters; check for inadvertent detach/no_grad."

        return y

# Note:
# - If you already have a SphericalLatLon or SphericalE3NN class, either:
#   a) Replace it entirely with SphericalModel above, OR
#   b) Copy the forward() pattern into your class and ensure your e3nn modules are used and parameterized.
# - After defining this, re-instantiate your 'spherical' model in Cell 10:
#   spherical = SphericalModel(in_ch=IN_CH, hidden_ch=HID, out_ch=2, H=64, W=128, use_e3nn=True).to(device)

In [10]:
# Cell 7.5 (updated): constants + verts + robust SphericalSpatioTemporal adapter

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1) Constants
if "VARS" not in globals():
    VARS = ["u", "v"]
if "T_CTX" not in globals():
    T_CTX = 2  # matches your log

# 2) Verts placeholder (diagnostic expects it)
if "verts" not in globals() or not hasattr(verts, "shape") or (hasattr(verts, "shape") and (len(tuple(verts.shape)) != 2 or verts.shape[1] != 3)):
    N = 1024
    verts = torch.randn(N, 3)
    verts = verts / (verts.norm(dim=1, keepdim=True) + 1e-9)

# 3) Base spherical backbone reference
if "SphericalModel" in globals():
    BaseSphericalClass = SphericalModel
else:
    class BaseSphericalClass(nn.Module):
        def __init__(self, in_ch=32, hidden_ch=32, out_ch=32, H=64, W=128, **kwargs):
            super().__init__()
            self.H, self.W = H, W
            self.net = nn.Sequential(
                nn.Conv2d(in_ch, hidden_ch, 3, padding=1, bias=True),
                nn.GELU(),
                nn.Conv2d(hidden_ch, out_ch, 3, padding=1, bias=True),
                nn.GELU(),
            )
        def forward(self, x):
            if x.shape[-2:] != (self.H, self.W):
                x = F.interpolate(x, size=(self.H, self.W), mode="bilinear", align_corners=False)
            return self.net(x)

class SphericalSpatioTemporal(nn.Module):
    """
    Adapter to accept [B,T,C,H,W], fuse T and C, run a spherical backbone, and project to C.
    Adds a residual around the backbone to ensure gradient connectivity even if backbone is identity-like.
    """
    def __init__(self, n_vars=2, T_ctx=2, hidden_ch=48, H=64, W=128, verts=None, edge_ch=16, **kwargs):
        super().__init__()
        self.C = n_vars
        self.T = T_ctx
        self.H, self.W = H, W

        # Fuse T*C -> hidden
        self.fuse = nn.Conv2d(self.C * self.T, hidden_ch, kernel_size=1, bias=True)

        # Backbone operates on hidden_ch -> hidden_ch
        self.backbone = BaseSphericalClass(
            in_ch=hidden_ch, hidden_ch=hidden_ch, out_ch=hidden_ch, H=H, W=W
        )

        # Post and head
        self.post = nn.GELU()
        self.head = nn.Conv2d(hidden_ch, self.C, kernel_size=1, bias=True)
        nn.init.zeros_(self.head.bias)

        # Optional residual projection to guarantee param path
        self.res_proj = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=1, bias=True)

    def forward(self, x):  # x: [B,T,C,H,W]
        assert x.ndim == 5, f"Expected [B,T,C,H,W], got {tuple(x.shape)}"
        B,T,C,H,W = x.shape
        assert T == self.T and C == self.C, f"Expected T={self.T}, C={self.C}; got {T},{C}"

        # Resize to internal grid
        if (H,W) != (self.H,self.W):
            x = F.interpolate(x.reshape(B*T, C, H, W), size=(self.H,self.W), mode="bilinear", align_corners=False).view(B,T,C,self.H,self.W)
            H, W = self.H, self.W

        # Fuse time and channels
        x2d = x.reshape(B, T*C, H, W)
        h = self.fuse(x2d)           # param path

        # Backbone in float32 for stability, then cast back
        h32 = h.float()
        b32 = self.backbone(h32)     # param path
        # Residual to ensure connectivity even if backbone outputs zeros
        h32 = h32 + self.res_proj(h32)
        h = (b32 + h32).to(h.dtype)

        y = self.head(self.post(h))  # param path
        y = torch.clamp(y, -50, 50)
        # No assert here; Cell 10 will check finiteness
        return y

print("Cell 7.5 updated: VARS/T_CTX set, verts ensured, SphericalSpatioTemporal adapter ready.")

Cell 7.5 updated: VARS/T_CTX set, verts ensured, SphericalSpatioTemporal adapter ready.


In [15]:
# Cell 10 (diagnostic, tolerant): robust checks + instantiation for Baseline and Spherical

import torch, inspect, math

def die(msg):
    raise RuntimeError(msg)

print("=== Cell 10: Diagnostics and Model Instantiation ===")

# 1) Required globals (classes, constants)
required = ["BaselineLatLon", "SphericalSpatioTemporal", "T_CTX", "VARS"]
missing = [name for name in required if name not in globals()]
if missing:
    die(
        "Missing definitions before Cell 10: "
        + ", ".join(missing)
        + ".\nRun cells that define model classes (Cell 7/7s + 7.5) and constants (VARS, T_CTX)."
    )

# 2) Sanity on constants
if not isinstance(VARS, (list, tuple)) or len(VARS) < 2 or not all(isinstance(v, str) for v in VARS):
    die(f"VARS must be a list/tuple of strings with at least 2 items (winds). Got: {VARS}")
if not isinstance(T_CTX, int) or T_CTX <= 0:
    die(f"T_CTX must be a positive int. Got: {T_CTX}")

# 3) verts: auto-create placeholder if missing/invalid (adapter may ignore it, but Cell 10 expects it)
try:
    import numpy as np
except Exception:
    np = None

def make_placeholder_verts(n=1024):
    v = torch.randn(n, 3)
    v = v / (v.norm(dim=1, keepdim=True) + 1e-9)
    return v

if "verts" not in globals():
    verts = make_placeholder_verts()
    print("Info: verts not found — created placeholder verts:", tuple(verts.shape))
elif not hasattr(verts, "shape"):
    print("Info: verts had no shape — replaced with placeholder.")
    verts = make_placeholder_verts()
else:
    vshape = tuple(verts.shape)
    if len(vshape) != 2 or vshape[1] != 3:
        print(f"Info: verts shape {vshape} invalid — replaced with placeholder.")
        verts = make_placeholder_verts()
    elif not torch.is_tensor(verts):
        if np is not None and isinstance(verts, np.ndarray):
            verts = torch.from_numpy(verts)
        else:
            print(f"Info: verts type {type(verts)} unsupported — replaced with placeholder.")
            verts = make_placeholder_verts()

verts = verts.contiguous()
if verts.dtype not in (torch.float32, torch.float64, torch.bfloat16, torch.float16):
    verts = verts.float()

print(f"OK: Inputs present. VARS={VARS} (C={len(VARS)}), T_CTX={T_CTX}, verts.shape={tuple(verts.shape)}")

# 4) Device and AMP
device = "cuda" if torch.cuda.is_available() else "cpu"
use_bf16 = (device == "cuda") and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
print(f"Device: {device} | AMP dtype: {amp_dtype}")

# 5) Channels and baseline grid
C = len(VARS)  # winds-only = 2
H_target, W_target = 64, 128
print(f"Baseline grid forced to: {H_target}x{W_target}")

# 6) Instantiate BaselineLatLon with flexible signature mapping
def make_baseline():
    sig = inspect.signature(BaselineLatLon.__init__)
    params = list(sig.parameters.keys())[1:]  # skip self
    mapped = {}
    for p in params:
        lp = p.lower()
        if lp in ("c_in","in_channels","channels_in","input_channels","cin","n_vars","num_vars","nvars"):
            mapped[p] = C
        elif lp in ("c_out","out_channels","channels_out","output_channels","cout"):
            mapped[p] = C
        elif lp in ("c_hidden","hidden_channels","mid_channels","width","hidden_dim","hidden_size","hidden_ch"):
            mapped[p] = 16
        elif lp in ("t_ctx","tcontext","context","context_steps","t"):
            mapped[p] = T_CTX
        elif lp in ("h","height"):
            mapped[p] = H_target
        elif lp in ("w","width"):
            mapped[p] = W_target
        elif lp in ("grid_size","shape","hw"):
            mapped[p] = (H_target, W_target)
        # else: ignore unknowns; they should have defaults
    try:
        model = BaselineLatLon(**mapped)
        print("BaselineLatLon init via kwargs:", mapped)
        return model.to(device)
    except TypeError as e_kw:
        print("BaselineLatLon kwargs init failed, trying positional patterns...")
        patterns = [
            ("C_in","C_out"),
            ("C_in","hidden","C_out"),
            ("C_in","hidden","C_out","T_ctx"),
            ("C_in","hidden","C_out","T_ctx","H"),
            ("C_in","hidden","C_out","T_ctx","H","W"),
        ]
        vals = {"C_in":C, "C_out":C, "hidden":16, "T_ctx":T_CTX, "H":H_target, "W":W_target}
        for pat in patterns:
            args = [vals[k] for k in pat]
            try:
                model = BaselineLatLon(*args).to(device)
                print("BaselineLatLon init via args pattern:", pat, "->", args)
                return model
            except TypeError:
                continue
        print("BaselineLatLon signature:", sig)
        raise RuntimeError(f"Could not instantiate BaselineLatLon. Last error: {e_kw}")

baseline = make_baseline()

# 7) Instantiate Spherical model (adapter from Cell 7.5)
try:
    spherical = SphericalSpatioTemporal(
        n_vars=C,
        verts=verts.to(device) if hasattr(verts, "to") else verts,
        hidden_ch=48,
        edge_ch=16,
        T_ctx=T_CTX,
        H=H_target, W=W_target
    ).to(device)
    print("SphericalSpatioTemporal instantiated.")
except TypeError as te:
    raise RuntimeError(f"Failed to instantiate SphericalSpatioTemporal: {te}")
except Exception as e:
    raise

# 8) Optimizers (fresh)
opt_b = torch.optim.AdamW(baseline.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))
opt_s = torch.optim.AdamW(spherical.parameters(), lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999))

# 9) AMP scaler (disabled by default for stability; enable if you want fp16)
use_scaler = (device == "cuda") and (amp_dtype == torch.float16) and False
try:
    scaler_b = torch.cuda.amp.GradScaler(enabled=use_scaler)
    scaler_s = torch.cuda.amp.GradScaler(enabled=use_scaler)
except Exception:
    scaler_b = torch.amp.GradScaler(enabled=use_scaler)
    scaler_s = torch.amp.GradScaler(enabled=use_scaler)

# 10) Param counts
p_b = sum(p.numel() for p in baseline.parameters())
p_s = sum(p.numel() for p in spherical.parameters())
print(f"Params (M): baseline {p_b/1e6:.3f} | spherical {p_s/1e6:.3f}")

# 11) Quick forward smoke tests to catch shape/NaN issues early
try:
    baseline.eval(); spherical.eval()
    with torch.no_grad():
        B_test = 1
        H_test, W_test = H_target, W_target
        # Dummy batch with expected shape [B,T,C,H,W]
        x_dummy = torch.zeros(B_test, T_CTX, C, H_test, W_test, device=device)
        yb = baseline(x_dummy)        # [B, C, H, W]
        ys = spherical(x_dummy)       # [B, C, H, W]
        if yb.ndim != 4 or ys.ndim != 4:
            die(f"Unexpected output shapes. Baseline: {tuple(yb.shape)}, Spherical: {tuple(ys.shape)}")
        if not torch.isfinite(yb).all():
            die("Baseline forward produced non-finite values on zero input.")
        if not torch.isfinite(ys).all():
            die("Spherical forward produced non-finite values on zero input.")
        print("Forward smoke test passed: outputs finite.")
except Exception as e:
    print("Warning: Forward smoke test failed. This may be due to input shape expectations.")
    print("Detail:", e)

print("=== Cell 10 complete. Models and optimizers are ready. ===")

=== Cell 10: Diagnostics and Model Instantiation ===
OK: Inputs present. VARS=['u', 'v'] (C=2), T_CTX=2, verts.shape=(1024, 3)
Device: cuda | AMP dtype: torch.bfloat16
Baseline grid forced to: 64x128
BaselineLatLon init via kwargs: {'n_vars': 2, 'hidden_ch': 16, 'T_ctx': 2, 'H': 64, 'W': 128}
SphericalSpatioTemporal instantiated.
Params (M): baseline 0.003 | spherical 0.067


  scaler_b = torch.cuda.amp.GradScaler(enabled=use_scaler)
  scaler_s = torch.cuda.amp.GradScaler(enabled=use_scaler)


Detail: Output is not connected to parameters; check for inadvertent detach/no_grad.
=== Cell 10 complete. Models and optimizers are ready. ===


In [16]:
# Cell 10: Re-instantiate optimizers (keep models as already instantiated)

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

baseline = baseline.to(device).train()
spherical = spherical.to(device).train()

# Fresh optimizers
opt_b = torch.optim.AdamW(baseline.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)
opt_s = torch.optim.AdamW(spherical.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)

# Optional schedulers (plateau safety, same for both; not strictly necessary for 100 steps)
sch_b = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_b, mode="min", factor=0.5, patience=20)
sch_s = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_s, mode="min", factor=0.5, patience=20)

print("Optimizers (and schedulers) reset.")

Optimizers (and schedulers) reset.


In [31]:
# Cell 11: Minimal, fair training for 100 steps — rotation augmentation + stable objective
# Identical LR, capacity, AMP; single forward per model per batch; shared normalization.

import time, math, random
import torch
import torch.nn.functional as F
from contextlib import nullcontext

device = "cuda" if torch.cuda.is_available() else "cpu"
use_bf16 = (device == "cuda") and torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

baseline.train(); spherical.train()
for p in baseline.parameters(): p.requires_grad_(True)
for p in spherical.parameters(): p.requires_grad_(True)

# Hyperparameters (identical)
total_steps = 100
print_every = 10
time_print_guard = 10.0
base_lr = 1.0e-3
for g in opt_b.param_groups: g["lr"] = base_lr
for g in opt_s.param_groups: g["lr"] = base_lr
max_grad_norm = 1.0
grad_noise_std = 5e-5  # tiny noise to stabilize

# AMP (new API)
if device == "cuda":
    ac = torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
    scaler_b = torch.amp.GradScaler("cuda", enabled=(amp_dtype == torch.float16))
    scaler_s = torch.amp.GradScaler("cuda", enabled=(amp_dtype == torch.float16))
else:
    ac = nullcontext()
    scaler_b = torch.amp.GradScaler(enabled=False)
    scaler_s = torch.amp.GradScaler(enabled=False)

# Loss weights (identical; slight warmup on angle term for first 30 steps)
w_mse = 1.0
w_ang_main = 0.20
w_ang_warm = 0.10  # used for steps <= 30
w_mag = 0.10

# Common grid
Hc, Wc = 64, 128
common_size = (Hc, Wc)

# Latitude weights (cos φ), normalized to mean 1
phi = torch.linspace(-math.pi/2, math.pi/2, steps=Hc, device=device).view(1, 1, Hc, 1)
cos_phi = torch.cos(phi).clamp(min=1e-3)
w_lat = cos_phi / cos_phi.mean()

def weight_reduce(x, w):
    if x.dim() == 4 and x.size(1) != 1:
        x = x.mean(dim=1, keepdim=True)
    return (x * w).mean()

def smooth_angle_loss(pred_uv, true_uv, eps=1e-6, w=None):
    pu, pv = pred_uv[:, 0], pred_uv[:, 1]
    tu, tv = true_uv[:, 0], true_uv[:, 1]
    pn = torch.sqrt(pu * pu + pv * pv + eps)
    tn = torch.sqrt(tu * tu + tv * tv + eps)
    cosv = (pu * tu + pv * tv) / (pn * tn + eps)
    cosv = torch.clamp(cosv, -1.0, 1.0)
    sin_half = torch.sqrt((1 - cosv) * 0.5)
    loss_map = sin_half.unsqueeze(1)
    return weight_reduce(loss_map, w) if w is not None else loss_map.mean()

def normalize(y, mean, var):
    return (y - mean) / torch.sqrt(var + 1e-8)

def finite_or_zero(t):
    return torch.where(torch.isfinite(t), t, torch.zeros_like(t))

# Shared running stats
running_mean = None
running_var = None
momentum = 0.1
warmup_stats_steps = 20

def update_running_stats(y_batch, running_mean, running_var):
    mean = y_batch.mean(dim=[0,2,3], keepdim=True)
    var = y_batch.var(dim=[0,2,3], unbiased=False, keepdim=True)
    if running_mean is None:
        return mean.detach(), (var + 1e-8).detach()
    rm = (1 - momentum) * running_mean + momentum * mean.detach()
    rv = (1 - momentum) * running_var + momentum * (var + 1e-8).detach()
    return rm, rv

# Rotation augmentation (lat/lon rolls)
def random_rot_params(H, W):
    k_lon = random.randint(0, max(0, W//18))     # up to ~20°
    k_lat = random.randint(-max(1, H//36), max(1, H//36))  # up to ~5°
    return k_lat, k_lon

def apply_rot(x, k_lat, k_lon):
    if k_lat != 0:
        x = torch.roll(x, shifts=k_lat, dims=2)
    if k_lon != 0:
        x = torch.roll(x, shifts=k_lon, dims=3)
    return x

# Connectivity smoke test for spherical
def sanity_connectivity(model, x_sample):
    model.train()
    for p in model.parameters():
        if p.grad is not None: p.grad.zero_()
    y = model(x_sample)
    loss = (y**2).mean()
    loss.backward()
    nz = sum(1 for _,p in model.named_parameters()
             if (p.grad is not None and torch.isfinite(p.grad).all() and p.grad.abs().sum() > 0))
    assert nz > 0, "Model produced no parameter gradients."
    model.zero_grad(set_to_none=True)

with torch.no_grad():
    xs_chk, _ = next(iter(train_s_loader))
xs_chk = xs_chk.to(device)[:1]
sanity_connectivity(spherical, xs_chk)

# Data iters
it_b = iter(train_b_loader)
it_s = iter(train_s_loader)

# EMA logging
ema_b = None
ema_s = None
ema_beta = 0.9

t0 = time.time()
last_print = t0

for step in range(1, total_steps + 1):
    try:
        xb, yb = next(it_b)
    except StopIteration:
        it_b = iter(train_b_loader); xb, yb = next(it_b)
    try:
        xs, ys = next(it_s)
    except StopIteration:
        it_s = iter(train_s_loader); xs, ys = next(it_s)

    xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
    xs, ys = xs.to(device, non_blocking=True), ys.to(device, non_blocking=True)

    # Rotation augmentation: same transform for input and target (applied per branch)
    k_lat_b, k_lon_b = random_rot_params(Hc, Wc)
    k_lat_s, k_lon_s = random_rot_params(Hc, Wc)

    # Targets on common grid + aug
    yb_c = F.interpolate(yb, size=common_size, mode="bilinear", align_corners=False)
    ys_c = F.interpolate(ys, size=common_size, mode="bilinear", align_corners=False)
    yb_c = apply_rot(yb_c, k_lat_b, k_lon_b)
    ys_c = apply_rot(ys_c, k_lat_s, k_lon_s)

    opt_b.zero_grad(set_to_none=True)
    opt_s.zero_grad(set_to_none=True)

    with ac:
        # Forward on rotated inputs
        yhb = finite_or_zero(baseline(apply_rot(xb, k_lat_b, k_lon_b)))
        yhs = finite_or_zero(spherical(apply_rot(xs, k_lat_s, k_lon_s)))

        # Align preds to common grid
        yhb_c = F.interpolate(yhb, size=common_size, mode="bilinear", align_corners=False)
        yhs_c = F.interpolate(yhs, size=common_size, mode="bilinear", align_corners=False)

        # Update shared normalization
        if step <= warmup_stats_steps:
            cat_targets = torch.cat([yb_c.detach(), ys_c.detach()], dim=0)
            running_mean, running_var = update_running_stats(cat_targets, running_mean, running_var)

        if running_mean is not None:
            yhb_n = normalize(yhb_c, running_mean, running_var)
            yhs_n = normalize(yhs_c, running_mean, running_var)
            yb_n  = normalize(yb_c,  running_mean, running_var)
            ys_n  = normalize(ys_c,  running_mean, running_var)
        else:
            yhb_n, yhs_n, yb_n, ys_n = yhb_c, yhs_c, yb_c, ys_c

        # Losses (identical)
        w_ang = w_ang_warm if step <= 30 else w_ang_main

        mse_b = weight_reduce((yhb_n - yb_n).pow(2), w_lat)
        mse_s = weight_reduce((yhs_n - ys_n).pow(2), w_lat)

        ang_b = smooth_angle_loss(yhb_n, yb_n, w=w_lat)
        ang_s = smooth_angle_loss(yhs_n, ys_n, w=w_lat)

        spd_b = torch.sqrt((yhb_n[:,0:1]**2 + yhb_n[:,1:2]**2) + 1e-8)
        spd_t_b = torch.sqrt((yb_n[:,0:1]**2 + yb_n[:,1:2]**2) + 1e-8)
        mag_b = weight_reduce((spd_b - spd_t_b).abs(), w_lat)

        spd_s = torch.sqrt((yhs_n[:,0:1]**2 + yhs_n[:,1:2]**2) + 1e-8)
        spd_t_s = torch.sqrt((ys_n[:,0:1]**2 + ys_n[:,1:2]**2) + 1e-8)
        mag_s = weight_reduce((spd_s - spd_t_s).abs(), w_lat)

        loss_b = w_mse*mse_b + w_ang*ang_b + w_mag*mag_b
        loss_s = w_mse*mse_s + w_ang*ang_s + w_mag*mag_s

    # Backward & step (spherical first)
    if scaler_s.is_enabled():
        scaler_s.scale(loss_s).backward()
        torch.nn.utils.clip_grad_norm_(spherical.parameters(), max_grad_norm)
        with torch.no_grad():
            for p in spherical.parameters():
                if p.grad is not None and torch.isfinite(p.grad).all():
                    p.grad.add_(torch.randn_like(p.grad) * grad_noise_std)
        scaler_s.step(opt_s); scaler_s.update()
    else:
        loss_s.backward()
        torch.nn.utils.clip_grad_norm_(spherical.parameters(), max_grad_norm)
        with torch.no_grad():
            for p in spherical.parameters():
                if p.grad is not None and torch.isfinite(p.grad).all():
                    p.grad.add_(torch.randn_like(p.grad) * grad_noise_std)
        opt_s.step()

    if scaler_b.is_enabled():
        scaler_b.scale(loss_b).backward()
        torch.nn.utils.clip_grad_norm_(baseline.parameters(), max_grad_norm)
        with torch.no_grad():
            for p in baseline.parameters():
                if p.grad is not None and torch.isfinite(p.grad).all():
                    p.grad.add_(torch.randn_like(p.grad) * grad_noise_std)
        scaler_b.step(opt_b); scaler_b.update()
    else:
        loss_b.backward()
        torch.nn.utils.clip_grad_norm_(baseline.parameters(), max_grad_norm)
        with torch.no_grad():
            for p in baseline.parameters():
                if p.grad is not None and torch.isfinite(p.grad).all():
                    p.grad.add_(torch.randn_like(p.grad) * grad_noise_std)
        opt_b.step()

    # Optional schedulers (same trigger)
    if 'sch_b' in globals(): sch_b.step(float(loss_b.detach()))
    if 'sch_s' in globals(): sch_s.step(float(loss_s.detach()))

    # Logging with EMA
    lb = float(loss_b.detach().cpu()); ls = float(loss_s.detach().cpu())
    ema_b = lb if ema_b is None else (0.9*ema_b + 0.1*lb)
    ema_s = ls if ema_s is None else (0.9*ema_s + 0.1*ls)

    now = time.time()
    if (step % print_every) == 0 or (now - last_print) > time_print_guard or step == 1 or step == total_steps:
        dt = now - t0
        print(f"Step {step}/{total_steps} | dt={dt:.1f}s | baseline {lb:.4f} (EMA {ema_b:.4f}) | spherical {ls:.4f} (EMA {ema_s:.4f}) | LR_b={opt_b.param_groups[0]['lr']:.2e} LR_s={opt_s.param_groups[0]['lr']:.2e}")
        last_print = now

baseline.eval(); spherical.eval()
print("Training complete (100 steps; minimal fair + rotation augmentation).")

Step 1/100 | dt=16.6s | baseline 0.9930 (EMA 0.9930) | spherical 1.1841 (EMA 1.1841) | LR_b=1.00e-03 LR_s=1.00e-03
Step 2/100 | dt=27.5s | baseline 1.2877 (EMA 1.0225) | spherical 1.1238 (EMA 1.1781) | LR_b=1.00e-03 LR_s=1.00e-03
Step 3/100 | dt=38.3s | baseline 1.3499 (EMA 1.0552) | spherical 1.0671 (EMA 1.1670) | LR_b=1.00e-03 LR_s=1.00e-03
Step 4/100 | dt=49.0s | baseline 1.2504 (EMA 1.0748) | spherical 0.9725 (EMA 1.1475) | LR_b=1.00e-03 LR_s=1.00e-03
Step 5/100 | dt=59.9s | baseline 0.9065 (EMA 1.0579) | spherical 0.9172 (EMA 1.1245) | LR_b=1.00e-03 LR_s=1.00e-03
Step 6/100 | dt=70.7s | baseline 1.1578 (EMA 1.0679) | spherical 0.8758 (EMA 1.0996) | LR_b=1.00e-03 LR_s=1.00e-03
Step 7/100 | dt=81.6s | baseline 1.0990 (EMA 1.0710) | spherical 0.7433 (EMA 1.0640) | LR_b=1.00e-03 LR_s=1.00e-03
Step 8/100 | dt=92.5s | baseline 1.0643 (EMA 1.0704) | spherical 0.6151 (EMA 1.0191) | LR_b=1.00e-03 LR_s=1.00e-03
Step 9/100 | dt=103.3s | baseline 1.1553 (EMA 1.0789) | spherical 1.0364 (EMA 1.

In [35]:
# Patch: make SphericalModel assertions training-only with proper instance binding

import torch.nn.functional as F

assert "SphericalModel" in globals(), "SphericalModel must be defined before this patch."

def sphericalmodel_forward_training_only_asserts(self, x):
    # x: [B,C,H,W]
    if x.shape[-2:] != (self.H, self.W):
        x = F.interpolate(x, size=(self.H, self.W), mode="bilinear", align_corners=False)

    h = self.head(x)  # parameterized

    if self.use_e3nn:
        h32 = h.float()
        h32 = self.sph_block(h32)
        h = h32.to(h.dtype)
    else:
        h = self.sph_block(h)

    y = self.tail(h)  # parameterized

    # Only assert connectivity in training mode
    if self.training:
        # if every param is frozen, that's suspicious
        if not any(p.requires_grad for p in self.parameters()):
            raise AssertionError("All params frozen in spherical model.")
        # In training, output should require grad if graph is connected
        if not getattr(y, "requires_grad", False):
            raise AssertionError("Output is not connected to parameters; check for inadvertent detach/no_grad.")

    return y

# Properly assign the function as an instance method on the class
SphericalModel.forward = sphericalmodel_forward_training_only_asserts

print("SphericalModel.forward patched successfully (training-only asserts).")

SphericalModel.forward patched successfully (training-only asserts).


In [37]:
# Cell 12: Safe minimal eval with extra metrics (CPU scalars only, no plots)

import math, gc, traceback
import torch
import torch.nn.functional as F

def safe_print(msg):
    try:
        print(msg, flush=True)
    except Exception:
        pass

device = "cuda" if torch.cuda.is_available() else "cpu"
C = len(VARS) if "VARS" in globals() else 2
Hc, Wc = 64, 128

def weight_reduce(x, w):
    # x: [B,C,H,W] or [B,1,H,W] -> scalar
    if x.dim() == 4 and x.size(1) != 1:
        x = x.mean(dim=1, keepdim=True)
    return (x * w).mean()

def smooth_angle_loss(pred_uv, true_uv, eps=1e-6, w=None):
    pu, pv = pred_uv[:, 0], pred_uv[:, 1]
    tu, tv = true_uv[:, 0], true_uv[:, 1]
    pn = torch.sqrt(pu * pu + pv * pv + eps)
    tn = torch.sqrt(tu * tu + tv * tv + eps)
    cosv = (pu * tu + pv * tv) / (pn * tn + eps)
    cosv = torch.clamp(cosv, -1.0, 1.0)
    sin_half = torch.sqrt((1 - cosv) * 0.5)
    loss_map = sin_half.unsqueeze(1)  # [B,1,H,W]
    return weight_reduce(loss_map, w) if w is not None else loss_map.mean()

def l1(x):  # mean absolute
    return x.abs().mean()

def rmse(x):  # root mean squared
    return torch.sqrt(torch.clamp(x.pow(2).mean(), min=1e-12))

def div2d(u, v):
    # very light divergence estimate via Sobel-like finite diff; periodic in lon
    # inputs: [B,1,H,W], returns [B,1,H,W]
    dudx = u[..., :, [1, *range(2, u.size(-1)), 0]] - u  # forward diff with wrap on lon
    dvdy = v[..., [1, *range(2, v.size(-2)), 0], :] - v  # forward diff in lat (no wrap but tiny batch)
    return dudx + dvdy

# Latitude weights (on device used for compute)
phi = torch.linspace(-math.pi/2, math.pi/2, steps=Hc, device=device).view(1, 1, Hc, 1)
cos_phi = torch.cos(phi).clamp(min=1e-3)
w_lat = cos_phi / cos_phi.mean()

# Tiny batch from val (fallback to train). Expect loaders that yield [B,T,C,H,W].
def get_small_batch():
    try:
        xb, yb = next(iter(val_b_loader))
        xs, ys = next(iter(val_s_loader))
        return xb, yb, xs, ys
    except Exception:
        xb, yb = next(iter(train_b_loader))
        xs, ys = next(iter(train_s_loader))
        return xb, yb, xs, ys

try:
    xb, yb, xs, ys = get_small_batch()
except Exception:
    safe_print("FATAL: Could not obtain a batch from loaders. Ensure loaders exist and yield [B,T,C,H,W].")
    raise

xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
xs = xs.to(device, non_blocking=True)
ys = ys.to(device, non_blocking=True)

baseline.eval(); spherical.eval()

# If running stats missing, compute from this batch only (safe and tiny)
def compute_running_stats(y_list):
    with torch.no_grad():
        cat = torch.cat([F.interpolate(y, size=(Hc, Wc), mode="bilinear", align_corners=False) for y in y_list], dim=0)
        mean = cat.mean(dim=[0,2,3], keepdim=True)
        var  = cat.var(dim=[0,2,3], unbiased=False, keepdim=True).clamp_min(1e-8)
    return mean, var

if 'running_mean' not in globals() or running_mean is None or 'running_var' not in globals() or running_var is None:
    running_mean, running_var = compute_running_stats([yb, ys])

with torch.no_grad():
    # Forward
    yhb = baseline(xb)      # [B,C,H,W]
    yhs = spherical(xs)     # [B,C,H,W]

    # To common grid
    yhb_c = F.interpolate(yhb, size=(Hc, Wc), mode="bilinear", align_corners=False)
    yhs_c = F.interpolate(yhs, size=(Hc, Wc), mode="bilinear", align_corners=False)
    yb_c  = F.interpolate(yb,  size=(Hc, Wc), mode="bilinear", align_corners=False)
    ys_c  = F.interpolate(ys,  size=(Hc, Wc), mode="bilinear", align_corners=False)

    # Normalize
    def norm(y): return (y - running_mean) / torch.sqrt(running_var)
    yhb_n, yhs_n, yb_n, ys_n = norm(yhb_c), norm(yhs_c), norm(yb_c), norm(ys_c)

    # Basic metrics
    diff_b = yhb_n - yb_n
    diff_s = yhs_n - ys_n

    mse_b = weight_reduce(diff_b.pow(2), w_lat)
    mse_s = weight_reduce(diff_s.pow(2), w_lat)

    mae_b = weight_reduce(diff_b.abs(), w_lat)
    mae_s = weight_reduce(diff_s.abs(), w_lat)

    ang_b = smooth_angle_loss(yhb_n, yb_n, w=w_lat)
    ang_s = smooth_angle_loss(yhs_n, ys_n, w=w_lat)

    # Speed magnitude error
    spd_b = torch.sqrt((yhb_n[:,0:1]**2 + yhb_n[:,1:2]**2) + 1e-8)
    spd_t_b = torch.sqrt((yb_n[:,0:1]**2 + yb_n[:,1:2]**2) + 1e-8)
    mag_b = weight_reduce((spd_b - spd_t_b).abs(), w_lat)

    spd_s = torch.sqrt((yhs_n[:,0:1]**2 + yhs_n[:,1:2]**2) + 1e-8)
    spd_t_s = torch.sqrt((ys_n[:,0:1]**2 + ys_n[:,1:2]**2) + 1e-8)
    mag_s = weight_reduce((spd_s - spd_t_s).abs(), w_lat)

    # Divergence penalty proxy (smaller is better)
    div_b = weight_reduce(div2d(yhb_n[:,0:1], yhb_n[:,1:2]).abs(), w_lat)
    div_s = weight_reduce(div2d(yhs_n[:,0:1], yhs_n[:,1:2]).abs(), w_lat)

    # Hemispheric metrics (optional, small)
    mid = Hc // 2
    wN = w_lat[:, :, mid:, :]
    wS = w_lat[:, :, :mid, :]

    def hemi_reduce(diff, w, hemi):
        if hemi == "N":
            d = diff[..., mid:, :]
        else:
            d = diff[..., :mid, :]
        return weight_reduce(d.pow(2), w)

    mse_b_N = hemi_reduce(diff_b, wN, "N")
    mse_b_S = hemi_reduce(diff_b, wS, "S")
    mse_s_N = hemi_reduce(diff_s, wN, "N")
    mse_s_S = hemi_reduce(diff_s, wS, "S")

    # Composite (same philosophy as before; tweak weights if desired)
    w_mse, w_ang, w_mag, w_div = 1.0, 0.20, 0.10, 0.05
    score_b = w_mse*mse_b + w_ang*ang_b + w_mag*mag_b + w_div*div_b
    score_s = w_mse*mse_s + w_ang*ang_s + w_mag*mag_s + w_div*div_s

# Convert to CPU floats
def fcpu(t): return float(t.detach().cpu())

eval_metrics_fast = {
    "baseline": {
        "mse": fcpu(mse_b), "mae": fcpu(mae_b), "rmse": float(math.sqrt(max(fcpu(mse_b), 0.0))),
        "angle": fcpu(ang_b), "mag_mae": fcpu(mag_b), "div_abs": fcpu(div_b),
        "mse_N": fcpu(mse_b_N), "mse_S": fcpu(mse_b_S),
        "score": fcpu(score_b),
    },
    "spherical": {
        "mse": fcpu(mse_s), "mae": fcpu(mae_s), "rmse": float(math.sqrt(max(fcpu(mse_s), 0.0))),
        "angle": fcpu(ang_s), "mag_mae": fcpu(mag_s), "div_abs": fcpu(div_s),
        "mse_N": fcpu(mse_s_N), "mse_S": fcpu(mse_s_S),
        "score": fcpu(score_s),
    },
    "winner": "spherical" if fcpu(score_s) < fcpu(score_b) else "baseline",
    "weights": {"w_mse": 1.0, "w_ang": 0.20, "w_mag": 0.10, "w_div": 0.05},
    "common_grid": {"H": Hc, "W": Wc},
}

safe_print("Fast eval (Cell 12) complete.")
safe_print(f"Baseline score: {eval_metrics_fast['baseline']['score']:.6f} | Spherical score: {eval_metrics_fast['spherical']['score']:.6f} | Winner: {eval_metrics_fast['winner']}")

Fast eval (Cell 12) complete.
Baseline score: 0.679794 | Spherical score: 0.667570 | Winner: spherical


In [16]:
# Cell 12b: Crash-proof save — weights + metrics JSON only (no plots)

import os, json, datetime, gc
from pathlib import Path
import torch

def safe_print(msg):
    try:
        print(msg, flush=True)
    except Exception:
        pass

# Optional: clear CUDA cache to avoid transient memory pressure
try:
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
except Exception:
    pass

timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
try:
    RUN_DIR = Path(f"D:/era5_runs/spherical_vs_baseline_{timestamp}")
    RUN_DIR.mkdir(parents=True, exist_ok=True)
except Exception:
    safe_print("Warn: Could not create dir on D:. Using current directory.")
    RUN_DIR = Path(f"./spherical_vs_baseline_{timestamp}")
    RUN_DIR.mkdir(parents=True, exist_ok=True)

# Weights
try:
    torch.save({k: v.detach().cpu() for k, v in baseline.state_dict().items()}, RUN_DIR / "baseline_weights.pt")
    torch.save({k: v.detach().cpu() for k, v in spherical.state_dict().items()}, RUN_DIR / "spherical_weights.pt")
    safe_print("Saved weights.")
except Exception as e:
    safe_print(f"Warn(weights): {e}")

# Metrics
metrics = {}
if 'eval_metrics_fast' in globals() and isinstance(eval_metrics_fast, dict) and len(eval_metrics_fast) > 0:
    metrics = eval_metrics_fast
else:
    metrics = {"note": "No eval metrics found. Run Cell 12 before 12b in this session."}

try:
    with open(RUN_DIR / "fast_eval_metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)
    safe_print("Saved fast_eval_metrics.json.")
except Exception as e:
    safe_print(f"Warn(metrics): {e}")

# Final cleanup
try:
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
except Exception:
    pass
gc.collect()

safe_print(f"Artifacts saved to: {RUN_DIR}")

Warn(weights): name 'baseline' is not defined
Saved fast_eval_metrics.json.
Artifacts saved to: D:\era5_runs\spherical_vs_baseline_20250920_091042


In [17]:
# Load previously saved weights into current models (run AFTER Cell 10, BEFORE Cell 13)

from pathlib import Path
import torch

SAVED = Path(r"D:\era5_runs\spherical_vs_baseline_20250919_214736")

# Sanity checks
assert 'baseline' in globals() and 'spherical' in globals(), "Models not instantiated. Run Cell 10 first."
assert (SAVED / "baseline_weights.pt").exists(), f"Missing: {SAVED / 'baseline_weights.pt'}"
assert (SAVED / "spherical_weights.pt").exists(), f"Missing: {SAVED / 'spherical_weights.pt'}"

# Load to CPU (safe regardless of device); then move to whatever device you used in Cell 10 if needed
state_b = torch.load(SAVED / "baseline_weights.pt", map_location="cpu")
state_s = torch.load(SAVED / "spherical_weights.pt", map_location="cpu")
baseline.load_state_dict(state_b)
spherical.load_state_dict(state_s)

# Ensure eval mode
baseline.eval()
spherical.eval()

print(f"Loaded weights from: {SAVED}")

Loaded weights from: D:\era5_runs\spherical_vs_baseline_20250919_214736


In [1]:
# Resume-ready HTML report from saved metrics (no plotting, no GPU)
# Target run folder:
RUN_DIR = r"D:\era5_runs\spherical_vs_baseline_20250919_214736"

import json, datetime, platform, sys, os
from pathlib import Path
from html import escape

RUN = Path(RUN_DIR)
metrics_path = RUN / "fast_eval_metrics.json"
assert RUN.exists(), f"Run folder not found: {RUN}"
assert metrics_path.exists(), f"Missing metrics JSON: {metrics_path}"

# Load metrics
with open(metrics_path, "r", encoding="utf-8") as f:
    metrics = json.load(f)

# Helpers
def fmt_val(v):
    if isinstance(v, float):
        # use 6 significant digits, strip trailing zeros
        s = f"{v:.6g}"
        return s
    return str(v)

def flatten_metrics(d, prefix=""):
    rows = []
    for k, v in d.items():
        key = f"{prefix}{k}" if not prefix else f"{prefix}.{k}"
        if isinstance(v, dict):
            rows.extend(flatten_metrics(v, key))
        else:
            rows.append((key, fmt_val(v)))
    return rows

rows = flatten_metrics(metrics)

# Attempt to pick out common fields if present
highlights = []
for k, v in rows:
    kl = k.lower()
    if any(term in kl for term in ["mse_u", "mse-v", "mse.u", "mse.v"]):
        highlights.append((k, v))
    if any(term in kl for term in ["angle", "vector_angle", "dir_angle"]):
        highlights.append((k, v))
    if any(term in kl for term in ["mae", "rmse"]):
        highlights.append((k, v))

# Environment info (for reproducibility)
try:
    import torch
    torch_ver = torch.__version__
    cuda_avail = torch.cuda.is_available()
except Exception:
    torch_ver = "unknown"
    cuda_avail = False

timestamp_now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sys_info = {
    "Generated": timestamp_now,
    "OS": f"{platform.system()} {platform.release()}",
    "Python": sys.version.split()[0],
    "PyTorch": torch_ver,
    "CUDA available": str(cuda_avail),
    "Run folder": str(RUN.resolve()),
    "Artifacts": ", ".join([p.name for p in RUN.glob("*")]),
}

# You can customize these if you wish (static descriptors for clarity)
project_info = {
    "Project": "ERA5 500 hPa Winds — Spherical vs Baseline",
    "Data": "ERA5 (500 hPa), vars: u, v",
    "Grid": "64 x 128 (lat-lon)",
    "Context Length (T)": "2",
    "OS/Env": "Windows-native PyTorch (no WSL/Docker/Dask)",
}

# Build HTML
css = """
body { font-family: Arial, Helvetica, sans-serif; color: #222; padding: 24px; }
h1, h2, h3 { margin: 0.2em 0 0.4em 0; }
h1 { font-size: 22px; }
h2 { font-size: 18px; color: #444; }
h3 { font-size: 16px; color: #555; }
.section { margin: 18px 0 22px 0; }
.kv { display: grid; grid-template-columns: 220px 1fr; gap: 6px 14px; }
.kv div.key { font-weight: bold; color: #333; }
table { border-collapse: collapse; width: 100%; max-width: 900px; }
th, td { border: 1px solid #ddd; padding: 8px 10px; }
th { background: #f4f4f4; text-align: left; }
.badges { display: flex; flex-wrap: wrap; gap: 8px; margin: 8px 0 0 0; }
.badge { background: #eef3ff; color: #2a4b8d; padding: 4px 8px; border-radius: 10px; font-size: 12px; border: 1px solid #d0ddff; }
.note { color: #666; font-size: 12px; }
hr { border: none; height: 1px; background: #eee; margin: 18px 0; }
"""

html = []
html.append("<!DOCTYPE html>")
html.append("<html><head><meta charset='utf-8'>")
html.append("<meta name='viewport' content='width=device-width, initial-scale=1' />")
html.append("<title>ERA5 Spherical vs Baseline — Fast Metrics Report</title>")
html.append(f"<style>{css}</style>")
html.append("</head><body>")

# Title
html.append("<h1>ERA5 500 hPa — Spherical vs Baseline</h1>")
html.append("<div class='badges'>")
for k, v in project_info.items():
    html.append(f"<span class='badge'>{escape(k)}: {escape(v)}</span>")
html.append("</div>")

# Executive summary
html.append("<div class='section'>")
html.append("<h2>Executive Summary</h2>")
html.append("<p>This report summarizes fast evaluation metrics comparing a Spherical model against a Baseline model for wind prediction (u, v) on a 64×128 grid with context length T=2. Models were trained and evaluated in a Windows-native PyTorch setup, avoiding WSL/Docker/Dask. The metrics below can be cited directly in a portfolio or resume.</p>")
html.append("</div>")

# Highlights (if any extracted)
if highlights:
    html.append("<div class='section'>")
    html.append("<h2>Key Highlights</h2>")
    html.append("<table><tr><th>Metric</th><th>Value</th></tr>")
    for k, v in highlights:
        html.append(f"<tr><td>{escape(k)}</td><td>{escape(v)}</td></tr>")
    html.append("</table>")
    html.append("</div>")

# Full metrics table
html.append("<div class='section'>")
html.append("<h2>Full Metrics</h2>")
html.append("<table><tr><th>Metric</th><th>Value</th></tr>")
for k, v in rows:
    html.append(f"<tr><td>{escape(k)}</td><td>{escape(v)}</td></tr>")
html.append("</table>")
html.append("</div>")

# Reproducibility / Environment
html.append("<div class='section'>")
html.append("<h2>Reproducibility & Environment</h2>")
html.append("<div class='kv'>")
for k, v in sys_info.items():
    html.append(f"<div class='key'>{escape(k)}</div><div>{escape(v)}</div>")
html.append("</div>")
html.append("<p class='note'>Note: This is a fast report derived solely from saved metrics (no inference, no plotting). For image panels or additional analyses, see the artifacts folder or rerun visualization cells in a stable environment.</p>")
html.append("</div>")

html.append("<hr>")
html.append("<div class='section note'>Generated automatically from fast_eval_metrics.json. Safe-mode report: no GPU, no plotting libraries, no dataset access.</div>")
html.append("</body></html>")

# Write HTML
out_html = RUN / "fast_metrics_report.html"
out_html.write_text("\n".join(html), encoding="utf-8")
print(f"Report written to: {out_html}")

Report written to: D:\era5_runs\spherical_vs_baseline_20250919_214736\fast_metrics_report.html


In [18]:
# Export a single tiny batch and weights for external visualization (no inference)
import json
from pathlib import Path
import torch
import numpy as np

# Where to write the export
RUN_DIR = Path(r"D:\era5_runs\spherical_vs_baseline_20250919_214736")
EXPORT_DIR = RUN_DIR / "export_for_external_viz"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# 1) Get one tiny batch from val if available, else train
def get_one_batch():
    if 'val_b_loader' in globals() and 'val_s_loader' in globals():
        xb, yb = next(iter(val_b_loader))
        xs, ys = next(iter(val_s_loader))
        split = "val"
    elif 'train_b_loader' in globals() and 'train_s_loader' in globals():
        xb, yb = next(iter(train_b_loader))
        xs, ys = next(iter(train_s_loader))
        split = "train"
    else:
        raise RuntimeError("No loaders found. Run your Cell 8 (splits + loaders) first.")
    # Keep only first item (B=1)
    xb = xb[:1].detach().cpu()
    yb = yb[:1].detach().cpu()
    xs = xs[:1].detach().cpu()
    return xb, yb, xs, split

xb, yb, xs, split = get_one_batch()

# 2) Capture minimal metadata
meta = {
    "split": split,
    "xb_shape": list(xb.shape),  # [1, C, H, W] or [1, C, ...]
    "yb_shape": list(yb.shape),
    "xs_shape": list(xs.shape),
    "vars": None,
    "T_ctx": None,
    "grid": {"H": None, "W": None},
    "notes": [
        "Inputs are normalized if your Dataset applies normalization.",
        "External script should use the same normalization reversal if needed.",
        "Targets yb correspond to baseline target; spherical expects xs."
    ],
}

# Try to infer VARS/T_ctx/H/W if available
try:
    if 'VARS' in globals():
        meta["vars"] = list(VARS)
except Exception:
    pass
try:
    if 'T_CTX' in globals():
        meta["T_ctx"] = int(T_CTX)
except Exception:
    pass
# Guess H,W from xb last two dims
if xb.ndim == 4:
    meta["grid"]["H"] = int(xb.shape[-2])
    meta["grid"]["W"] = int(xb.shape[-1])

# If your datasets have a .stats attribute, export u,v mean/std snapshot
stats_snapshot = None
try:
    ds_obj = val_b_loader.dataset if 'val_b_loader' in globals() else train_b_loader.dataset
    if hasattr(ds_obj, "stats"):
        # Ensure JSON serializable simple dict
        stats_snapshot = ds_obj.stats
except Exception:
    pass

# 3) Save arrays as .npz
npz_path = EXPORT_DIR / "one_batch_uv_64x128.npz"
np.savez(
    npz_path,
    xb=xb.numpy().astype(np.float32),
    yb=yb.numpy().astype(np.float32),
    xs=xs.numpy().astype(np.float32),
    meta=json.dumps(meta),
    stats=json.dumps(stats_snapshot) if stats_snapshot is not None else json.dumps({})
)
print(f"Saved one tiny batch to: {npz_path}")

# 4) Save model weights copies for convenience (no re-serialization if already .pt)
# If you already have weight files in RUN_DIR, just copy paths below into a manifest.
weight_manifest = {
    "baseline_weight_path": None,
    "spherical_weight_path": None,
}

# Try common names in the run folder
candidates = list(RUN_DIR.glob("*.pt")) + list(RUN_DIR.glob("**/*.pt"))
for p in candidates:
    name = p.name.lower()
    if weight_manifest["baseline_weight_path"] is None and ("baseline" in name or "b_" in name):
        weight_manifest["baseline_weight_path"] = str(p.resolve())
    if weight_manifest["spherical_weight_path"] is None and ("spherical" in name or "s_" in name):
        weight_manifest["spherical_weight_path"] = str(p.resolve())

# If models are in memory but no files were found, save minimal state_dicts
if weight_manifest["baseline_weight_path"] is None or weight_manifest["spherical_weight_path"] is None:
    assert 'baseline' in globals() and 'spherical' in globals(), "Models not in memory; cannot save state_dict."
    # Save minimal state_dicts to EXPORT_DIR
    b_path = EXPORT_DIR / "baseline_state_dict.pt"
    s_path = EXPORT_DIR / "spherical_state_dict.pt"
    torch.save(baseline.state_dict(), b_path)
    torch.save(spherical.state_dict(), s_path)
    weight_manifest["baseline_weight_path"] = str(b_path.resolve())
    weight_manifest["spherical_weight_path"] = str(s_path.resolve())

# 5) Write a small manifest JSON for external script
manifest = {
    "npz_path": str(npz_path.resolve()),
    "weights": weight_manifest,
    "vars": meta["vars"],
    "T_ctx": meta["T_ctx"],
    "grid": meta["grid"],
    "run_dir": str(RUN_DIR.resolve()),
}
manifest_path = EXPORT_DIR / "export_manifest.json"
with open(manifest_path, "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2)
print(f"Wrote manifest: {manifest_path}")

print("Done. You can now share the NPZ and the two weight files for external visualization.")

Saved one tiny batch to: D:\era5_runs\spherical_vs_baseline_20250919_214736\export_for_external_viz\one_batch_uv_64x128.npz
Wrote manifest: D:\era5_runs\spherical_vs_baseline_20250919_214736\export_for_external_viz\export_manifest.json
Done. You can now share the NPZ and the two weight files for external visualization.


In [20]:
# Single self-contained cell: Safe metrics + PNG visuals if/when panel_arrays.npz is present
# - No CUDA, no matplotlib. Uses only numpy + Pillow.
# - If the NPZ is missing, it prints instructions and exits without crashing.
# - When the NPZ exists, it computes MSE/MAE/angle metrics and saves PNGs.

import os, json
import numpy as np
from pathlib import Path

# 1) Configure the expected NPZ path (returned from external inference)
NPZ_PATH = Path(r"D:\era5_runs\spherical_vs_baseline_20250919_214736\external_results\panel_arrays.npz")
OUT_DIR = NPZ_PATH.parent / "png"
OUT_DIR.mkdir(parents=True, exist_ok=True)

if not NPZ_PATH.exists():
    print("panel_arrays.npz not found yet.")
    print(f"Expected here: {NPZ_PATH}")
    print("Once you have it (with keys u_t, v_t, u_b, v_b, u_s, v_s), place it there and rerun this cell.")
else:
    # 2) Load arrays
    d = np.load(NPZ_PATH)
    required = ["u_t", "v_t", "u_b", "v_b", "u_s", "v_s"]
    missing = [k for k in required if k not in d.files]
    if missing:
        print("NPZ is present but missing keys:", missing)
        print("It must contain:", required)
    else:
        u_t = d["u_t"].astype(np.float32); v_t = d["v_t"].astype(np.float32)
        u_b = d["u_b"].astype(np.float32); v_b = d["v_b"].astype(np.float32)
        u_s = d["u_s"].astype(np.float32); v_s = d["v_s"].astype(np.float32)

        # 3) Metrics (MSE, MAE, angle error, optional directional F1-like)
        def mse(a, b): return float(np.nanmean((a - b) ** 2))
        def mae(a, b): return float(np.nanmean(np.abs(a - b)))
        def vec_angle_deg(u1, v1, u2, v2, eps=1e-8):
            dot = u1*u2 + v1*v2
            n1 = np.sqrt(u1*u1 + v1*v1) + eps
            n2 = np.sqrt(u2*u2 + v2*v2) + eps
            cos = np.clip(dot / (n1*n2), -1.0, 1.0)
            return np.degrees(np.arccos(cos))

        def directional_f1(u_hat, v_hat, u_true, v_true, theta_deg=20.0):
            ang = vec_angle_deg(u_hat, v_hat, u_true, v_true)
            correct = (ang <= theta_deg).astype(np.uint8)
            mag_t = np.sqrt(u_true*u_true + v_true*v_true)
            strong = (mag_t >= np.nanmedian(mag_t)).astype(np.uint8)
            tp = int(np.sum((correct == 1) & (strong == 1)))
            fp = int(np.sum((correct == 1) & (strong == 0)))
            fn = int(np.sum((correct == 0) & (strong == 1)))
            precision = tp / (tp + fp + 1e-9)
            recall = tp / (tp + fn + 1e-9)
            f1 = 2 * precision * recall / (precision + recall + 1e-9)
            return {"theta_deg": theta_deg, "acc": float(np.mean(correct)),
                    "precision": float(precision), "recall": float(recall), "f1": float(f1)}

        metrics = {
            "baseline": {
                "mse_u": mse(u_b, u_t),
                "mse_v": mse(v_b, v_t),
                "mae_u": mae(u_b, u_t),
                "mae_v": mae(v_b, v_t),
                "angle_deg_mean": float(np.nanmean(vec_angle_deg(u_b, v_b, u_t, v_t))),
                "angle_deg_median": float(np.nanmedian(vec_angle_deg(u_b, v_b, u_t, v_t))),
                "dir_f1@20deg": directional_f1(u_b, v_b, u_t, v_t, 20.0),
            },
            "spherical": {
                "mse_u": mse(u_s, u_t),
                "mse_v": mse(v_s, v_t),
                "mae_u": mae(u_s, u_t),
                "mae_v": mae(v_s, v_t),
                "angle_deg_mean": float(np.nanmean(vec_angle_deg(u_s, v_s, u_t, v_t))),
                "angle_deg_median": float(np.nanmedian(vec_angle_deg(u_s, v_s, u_t, v_t))),
                "dir_f1@20deg": directional_f1(u_s, v_s, u_t, v_t, 20.0),
            }
        }

        # Save metrics JSON
        metrics_path = OUT_DIR / "panel_metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print("Wrote metrics JSON:", metrics_path)

        # 4) Safe PNG visuals with Pillow (no matplotlib, no GPU)
        from PIL import Image

        def to_img(arr, mid="median", scale=2.5):
            a = arr.astype(np.float32).copy()
            m = float(np.nanmedian(a)) if mid == "median" else float(np.nanmean(a))
            s = float(np.nanstd(a) + 1e-6)
            a = (a - m) / (scale * s)
            a = np.clip(a, -1.0, 1.0)
            r = ((a + 1.0) / 2.0) * 255.0
            b = ((1.0 - a) / 2.0) * 255.0
            g = 255.0 - np.abs(a) * 255.0 * 0.7
            img = np.stack([r, g, b], axis=-1).astype(np.uint8)
            return Image.fromarray(img)

        def save_panel(img, name):
            img = img.resize((512, 256), Image.NEAREST)
            img.save(OUT_DIR / name)

        # Fields
        save_panel(to_img(u_t), "truth_u.png")
        save_panel(to_img(v_t), "truth_v.png")
        save_panel(to_img(u_b), "baseline_u_hat.png")
        save_panel(to_img(v_b), "baseline_v_hat.png")
        save_panel(to_img(u_s), "spherical_u_hat.png")
        save_panel(to_img(v_s), "spherical_v_hat.png")

        # Absolute error maps
        eu_b = np.abs(u_b - u_t); ev_b = np.abs(v_b - v_t)
        eu_s = np.abs(u_s - u_t); ev_s = np.abs(v_s - v_t)

        save_panel(to_img(eu_b, mid="mean", scale=4.0), "abs_err_u_baseline.png")
        save_panel(to_img(ev_b, mid="mean", scale=4.0), "abs_err_v_baseline.png")
        save_panel(to_img(eu_s, mid="mean", scale=4.0), "abs_err_u_spherical.png")
        save_panel(to_img(ev_s, mid="mean", scale=4.0), "abs_err_v_spherical.png")

        print(f"Saved PNGs to: {OUT_DIR}")
        print("Done.")

panel_arrays.npz not found yet.
Expected here: D:\era5_runs\spherical_vs_baseline_20250919_214736\external_results\panel_arrays.npz
Once you have it (with keys u_t, v_t, u_b, v_b, u_s, v_s), place it there and rerun this cell.
