# In [None]:
import numpy as np, time
N, nu, dt = 128, 0.005, 0.002
# 2/3-rule dealiasing cutoff. The largest resolved wavenumber on an N-point
# grid is N/2, so the rule keeps modes with |k| <= (2/3)(N/2) = N/3 along each
# axis. Products of two kept modes then alias only into the discarded band.
M = N // 3

# Grid on the periodic box [0, 2π)³ and integer wavenumbers. The last axis
# uses rfftfreq because the fields are real and transformed with rfftn.
# fftfreq/rfftfreq already return float64 arrays.
x = np.linspace(0, 2*np.pi, N, endpoint=False)
Kx, Ky, Kz = np.meshgrid(
    np.fft.fftfreq(N, 1/N),
    np.fft.fftfreq(N, 1/N),
    np.fft.rfftfreq(N, 1/N),
    indexing='ij'
)

# Precompute everything heavy
Ksq = [Kx**2, Ky**2, Kz**2]                    # per-axis k_i² (NOT the Laplacian symbol)
K2 = Kx**2 + Ky**2 + Kz**2                     # |k|²: Laplacian symbol for viscosity and projection
# True for modes outside the 2/3-rule cube; those modes are zeroed each step.
mask = (np.abs(Kx) > M) | (np.abs(Ky) > M) | (np.abs(Kz) > M)

# Taylor-Green initial condition (divergence-free by construction)
X, Y, Z = np.meshgrid(x, x, x, indexing='ij')
u = np.array([
    np.sin(X)*np.cos(Y)*np.cos(Z),
   -np.cos(X)*np.sin(Y)*np.cos(Z),
    np.zeros_like(X)
])

def energy(u):
    """Return the total kinetic energy 0.5 * sum(|u|²) * dV on the periodic grid.

    u: the three velocity components, indexable as u[0], u[1], u[2]
       (a stacked array or a list of arrays). Each component is presumably an
       (N, N, N) field sampled on [0, 2π)³ — the cell volume below assumes that.
    """
    cell_volume = (2*np.pi/N)**3
    speed_sq = sum(comp**2 for comp in (u[0], u[1], u[2]))
    return 0.5*np.sum(speed_sq) * cell_volume

print(f"SPECTRALCORE320 — N={N}³ — 128³ ON YOUR PHONE GOES BRRR")
print(f"Initial energy: {energy(u):.10f} | t=0.00s")

# Warm-up transform so one-time FFT setup cost is not counted in the timings.
np.fft.rfftn(np.random.randn(N, N, N))
# Passed to every irfftn so the real-axis length is stated, not inferred.
grid_shape = (N, N, N)
t0 = time.perf_counter()
s = 0

while True:
    s += 1

    # Velocity → Fourier
    uhat = [np.fft.rfftn(u[i]) for i in range(3)]
    uhat = [np.where(mask, 0j, h) for h in uhat]          # 2/3 dealias

    # Vorticity ω = ∇×u in physical space (spectrally: ω̂ = i k × û)
    w0 = np.fft.irfftn(1j*(Ky*uhat[2] - Kz*uhat[1]), s=grid_shape)
    w1 = np.fft.irfftn(1j*(Kz*uhat[0] - Kx*uhat[2]), s=grid_shape)
    w2 = np.fft.irfftn(1j*(Kx*uhat[1] - Ky*uhat[0]), s=grid_shape)

    # Real-space velocity (only needed for nonlinear term)
    ux, uy, uz = [np.fft.irfftn(h, s=grid_shape) for h in uhat]

    # Nonlinear term in rotational form: u × ω = -(ω × u). The remaining
    # gradient part (pressure + |u|²/2) is removed by the projection below.
    nl = [
        -np.fft.rfftn(w1*uz - w2*uy),
        -np.fft.rfftn(w2*ux - w0*uz),
        -np.fft.rfftn(w0*uy - w1*ux)
    ]

    # Time step + viscosity: ν∇²u ↔ -ν|k|² û, so every component uses the
    # full |k|² (K2), not just its own k_i².
    for i in range(3):
        nl[i] -= nu * K2 * uhat[i]          # viscous term
        uhat[i] += dt * nl[i]               # explicit Euler

    # Incompressibility projection: û ← û - k (k·û)/|k|²
    div = Kx*uhat[0] + Ky*uhat[1] + Kz*uhat[2]
    proj = div / (K2 + (K2 == 0))            # k=0: denominator 1, and div is 0 there anyway
    uhat[0] -= Kx * proj
    uhat[1] -= Ky * proj
    uhat[2] -= Kz * proj

    # Truncate again: the products above alias into the masked band, and the
    # 2/3 rule is only correct if those modes are discarded.
    uhat = [np.where(mask, 0j, h) for h in uhat]

    # Back to physical space
    u = [np.fft.irfftn(h, s=grid_shape) for h in uhat]

    if s % 100 == 0:
        elapsed = time.perf_counter() - t0
        E = energy(u)
        print(f"Step {s:5d} | E={E:.7f} | FPS={s/elapsed:5.2f} | t={elapsed:6.1f}s")
        # Explicit Euler can blow up; stop instead of looping forever on NaN/inf.
        if not np.isfinite(E):
            print(f"Energy is no longer finite at step {s} — stopping (try a smaller dt).")
            break