In [None]:
# SpectralCore42 — The Final Universal Accelerator
# Compact, zero config. Runs on CPU or CUDA (phone → 4090). Yours forever.
# Drop this file anywhere. Call init() once, then step() every time step. Win.
import torch
import torch.fft as fft

# ────────────────────── ONE-TIME SETUP (call once) ──────────────────────
# Spectral-grid state, populated by init(); None until init() has been called.
Kx = Ky = Kz = K2 = MASK = None
SIZE = None


def init(grid=256, device=None):
    """Build the shared wavenumber grids and the dealiasing mask.

    Must be called once before step(). All grids use the layout produced by
    ``torch.fft.rfftn`` over the three spatial axes, i.e. shape
    ``(grid, grid, grid // 2 + 1)``, so physics callbacks can multiply them
    directly against the spectra that step() hands out.

    Args:
        grid: number of grid points along each of the three axes.
        device: torch device for the grids; defaults to "cuda" when
            available, otherwise "cpu".

    Sets the module globals:
        SIZE: ``grid``.
        Kx, Ky, Kz: angular wavenumbers ``2*pi*m`` for integer mode numbers
            ``m`` (i.e. a periodic box of side length 1).
        K2: ``Kx**2 + Ky**2 + Kz**2 + 1e-12``; the offset keeps ``1/K2``
            finite at the k = 0 mode.
        MASK: boolean 2/3-rule dealiasing mask (True = mode kept).
    """
    global Kx, Ky, Kz, K2, MASK, SIZE
    SIZE = grid
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Integer mode numbers: the full signed range on the first two axes, but
    # only the non-negative half on the last axis, because rfftn drops the
    # redundant conjugate half there. round() strips float noise.
    modes = torch.fft.fftfreq(grid, d=1 / grid, device=device).round()
    modes_z = torch.fft.rfftfreq(grid, d=1 / grid, device=device).round()
    Mx, My, Mz = torch.meshgrid(modes, modes, modes_z, indexing="ij")

    Kx, Ky, Kz = (2 * torch.pi * m for m in (Mx, My, Mz))
    K2 = Kx**2 + Ky**2 + Kz**2 + 1e-12

    # 2/3 dealiasing: keep modes with |m| <= (2/3) * (grid // 2) on every axis.
    # The comparison is done in mode-number units; comparing the 2*pi-scaled
    # Kx against this cutoff would throw away almost every resolved mode.
    kmax = (2 / 3) * (grid // 2)
    MASK = (Mx.abs() <= kmax) & (My.abs() <= kmax) & (Mz.abs() <= kmax)

    print(f"SpectralCore42 → {grid}³ → {device} → READY")


# ────────────────────── THE ONE FUNCTION YOU EVER CALL ──────────────────────
def step(fields, dt=0.005, physics=None):
    """Advance field(s) by one explicit-Euler step in spectral space.

    Each field is transformed with rfftn over its first three (spatial) axes,
    advanced as ``hat + dt * rhs``, dealiased with MASK and transformed back.

    Args:
        fields: a tensor of shape (SIZE, SIZE, SIZE) or (SIZE, SIZE, SIZE, C),
            or a list of such tensors. Axes after the three spatial ones are
            carried through untouched.
        dt: time step.
        physics: optional callable taking the list of spectra (each of shape
            (SIZE, SIZE, SIZE // 2 + 1, ...)) and returning a list of RHS
            spectra of the same length and shapes. Without it the RHS is zero
            and step() only applies the dealiasing filter. The global K grids
            have shape (SIZE, SIZE, SIZE // 2 + 1); index them with [..., None]
            to broadcast against spectra that carry a channel axis.

    Returns:
        A tensor if a tensor was passed, otherwise a list of tensors; each
        output has the same shape as the corresponding input.

    Raises:
        RuntimeError: if init() has not been called.
        ValueError: if a field's leading shape is not (SIZE, SIZE, SIZE), or
            physics returns a different number of spectra than it received.
    """
    if MASK is None:
        raise RuntimeError("SpectralCore42 not initialised: call init() first")

    single = isinstance(fields, torch.Tensor)
    field_list = [fields] if single else list(fields)
    spatial = (SIZE, SIZE, SIZE)
    for f in field_list:
        if tuple(f.shape[:3]) != spatial:
            raise ValueError(
                f"expected leading shape {spatial}, got {tuple(f.shape)}"
            )

    # → spectral space; "ortho" scales both directions by 1/sqrt(SIZE**3)
    hats = [fft.rfftn(f, dim=(0, 1, 2), norm="ortho") for f in field_list]

    # ←←← YOUR PHYSICS GOES HERE (passed in as the `physics` callable) ←←←
    if physics is not None:
        rhs_hats = list(physics(hats))
        if len(rhs_hats) != len(hats):
            raise ValueError(
                f"physics returned {len(rhs_hats)} spectra for {len(hats)} fields"
            )
    else:
        rhs_hats = [torch.zeros_like(h) for h in hats]

    out = []
    for h, r in zip(hats, rhs_hats):
        # Append singleton axes so the mask lines up with the three spatial
        # axes instead of with any trailing channel axes.
        mask = MASK.reshape(MASK.shape + (1,) * (h.ndim - 3))
        h = (h + dt * r) * mask
        # dim must be explicit: given only s=, irfftn would transform the LAST
        # three axes, which is wrong once channel axes are present.
        out.append(fft.irfftn(h, s=spatial, dim=(0, 1, 2), norm="ortho"))

    # return the same format the caller gave
    return out[0] if single else out