In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from time import time

# Grid resolution: N points per axis on a periodic [0, 2*pi)^3 box.
N = 320
torch.set_default_dtype(torch.float32)
device = 'cpu'

L = 2 * np.pi
dx = L / N
z_dim = N // 2 + 1  # length of the last axis after rfftn (half spectrum)

# With d = dx/(2*pi) = 1/N the frequency helpers return integer wavenumbers:
# kx in [0, N/2] (half spectrum), ky and kz in [-N/2, N/2).
kx = torch.fft.rfftfreq(N, d=dx/(2*np.pi))[:z_dim].to(device)
ky = torch.fft.fftfreq(N, d=dx/(2*np.pi)).to(device)
kz = torch.fft.fftfreq(N, d=dx/(2*np.pi)).to(device)

# Broadcastable views matching the rfftn output layout (N, N, N//2 + 1):
# axis 0 <-> kz, axis 1 <-> ky, axis 2 (the halved axis) <-> kx.
KX = kx[None, None, :]
KY = ky[None, :, None]
KZ = kz[:, None, None]

# 1e-12 keeps 1/K2 finite at k = 0; that entry is then zeroed so the
# projection below never touches the mean (k = 0) mode.
K2 = KX**2 + KY**2 + KZ**2 + 1e-12
invK2 = 1.0 / K2
invK2[0, 0, 0] = 0.0

# 2/3-rule dealiasing mask: keep |k| < (2/3) * (N/2) = N/3 on every axis.
# The cutoff is relative to the Nyquist wavenumber N/2; a cutoff of
# 2*N//3 would exceed every resolved |k| and keep all modes.
dealias_cutoff = N / 3
dealias = ((KX.abs() < dealias_cutoff)
           & (KY.abs() < dealias_cutoff)
           & (KZ.abs() < dealias_cutoff)).float()

# torch.randn uses torch's own generator, so it must be seeded through
# torch; np.random.seed alone does not make this run reproducible.
np.random.seed(42)
torch.manual_seed(42)
u_real = torch.randn(N, N, N, device=device)
v_real = torch.randn(N, N, N, device=device)
w_real = torch.randn(N, N, N, device=device)

u = torch.fft.rfftn(u_real) * 0.5
v = torch.fft.rfftn(v_real) * 0.5
w = torch.fft.rfftn(w_real) * 0.5

# Band-limit the initial field to the dealiased modes. step() only adds
# dealias-masked nonlinear terms, and the viscous and projection terms are
# diagonal in k, so the fields stay band-limited and the 2/3 rule actually
# removes aliasing errors.
u, v, w = dealias * u, dealias * v, dealias * w

# Leray projection onto divergence-free fields:
#   u_hat <- u_hat - k (k . u_hat) / |k|^2
# which gives k . u_hat = 0 for every k != 0.
kdotu = KX*u + KY*v + KZ*w
u = u - KX * kdotu * invK2
v = v - KY * kdotu * invK2
w = w - KZ * kdotu * invK2

def step(dt=0.004, nu=0.00008):
    """Advance the global spectral velocity (u, v, w) by one explicit Euler step.

    The advection term (u . grad) u is evaluated pseudo-spectrally: velocities
    and their gradients are transformed to physical space, multiplied there,
    transformed back and masked with ``dealias``. The viscous term
    ``nu * |k|^2 * u_hat`` is applied explicitly. The tentative velocity is
    then projected onto divergence-free fields.

    Parameters
    ----------
    dt : float
        Time step (default 0.004, the previously hard-coded value).
    nu : float
        Viscosity coefficient multiplying |k|^2 (default 0.00008).

    Reads the module-level N, KX, KY, KZ, K2, invK2 and dealias, and rebinds
    the globals u, v, w.
    """
    global u, v, w
    shape = (N, N, N)

    def to_phys(f_hat):
        # Inverse real FFT back to the full (N, N, N) physical grid.
        return torch.fft.irfftn(f_hat, s=shape)

    ur, vr, wr = to_phys(u), to_phys(v), to_phys(w)

    def advect(f_hat):
        # Spectrum of (u . grad) f, dealiased. Gradients are computed per
        # component so only three derivative fields are alive at a time.
        f_x = to_phys(1j * KX * f_hat)
        f_y = to_phys(1j * KY * f_hat)
        f_z = to_phys(1j * KZ * f_hat)
        return dealias * torch.fft.rfftn(ur*f_x + vr*f_y + wr*f_z)

    nl_x, nl_y, nl_z = advect(u), advect(v), advect(w)

    tent_u = u - dt*nl_x - nu*dt*K2*u
    tent_v = v - dt*nl_y - nu*dt*K2*v
    tent_w = w - dt*nl_z - nu*dt*K2*w

    # Leray projection: u_hat <- u_hat - k (k . u_hat) / |k|^2, which gives
    # k . u_hat = 0 (the role the pressure plays in incompressible flow).
    kdotu = KX*tent_u + KY*tent_v + KZ*tent_w
    u = tent_u - KX * kdotu * invK2
    v = tent_v - KY * kdotu * invK2
    w = tent_w - KZ * kdotu * invK2

plt.ion()
fig = plt.figure(figsize=(7,7))
img = None
t0 = time()
frame = 0          # 0-based index of the solver step just taken
render_every = 20  # draw one slice per this many solver steps
grid_shape = (N, N, N)

print(f"Starting {N}³ float32 run... buckle up.")
# Stop once the window is closed; otherwise pyplot's plt.title/plt.pause
# would implicitly open a new, empty figure and the loop would never end.
while plt.fignum_exists(fig.number):
    step()
    if frame % render_every == 0:
        # Velocity magnitude |u| on the mid-plane (index N//2 along axis 0).
        speed = (torch.fft.irfftn(u, s=grid_shape)**2 +
                 torch.fft.irfftn(v, s=grid_shape)**2 +
                 torch.fft.irfftn(w, s=grid_shape)**2).sqrt()[N//2].cpu().numpy()
        if img is None:
            img = plt.imshow(speed, cmap='turbo', vmin=0, vmax=3)
            plt.axis('off')
        else:
            img.set_data(speed)
        elapsed = time() - t0
        # frame is 0-based, so frame + 1 solver steps have completed. This is
        # a step rate, not a frame rate: a frame is drawn every render_every steps.
        steps_per_s = (frame + 1) / elapsed if elapsed > 0 else 0
        plt.title(f'{N}³ • {steps_per_s:.2f} steps/s • float32')
        fig.canvas.flush_events()
        plt.pause(0.001)
    frame += 1