In [None]:
# Mount Google Drive when running in Colab. No cell shown here reads from
# Drive, so outside Colab the mount is skipped instead of aborting "Run all".
try:
    from google.colab import drive
except ImportError:
    drive = None
    print('google.colab not available; skipping Drive mount.')
else:
    drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output, display

torch.set_default_dtype(torch.float32)

# Run on the GPU when one is available; 256³ FFTs are very slow on CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Periodic box [0, 2π)³ sampled on an N³ grid.
N = 256
L = 2 * np.pi
dx = L / N
z_dim = N // 2 + 1                      # length of the halved (rfft) axis: 129

# Integer wavenumbers: rfftfreq/fftfreq return k / (N·d), and here
# d = dx / (2π) = 1 / N, so the values are exactly k. The last axis is the
# rfft (half-spectrum) axis, hence rfftfreq with z_dim entries.
kx = torch.fft.rfftfreq(N, d=dx/(2*np.pi), device=DEVICE)[:z_dim]
ky = torch.fft.fftfreq(N, d=dx/(2*np.pi), device=DEVICE)
kz = torch.fft.fftfreq(N, d=dx/(2*np.pi), device=DEVICE)

# Spectral arrays are indexed [z, y, x]; x is the rfft axis.
KX = kx[None, None, :]                  # (1, 1, z_dim)
KY = ky[None, :, None]                  # (1, N, 1)
KZ = kz[:, None, None]                  # (N, 1, 1)

# |k|² (the tiny offset avoids 0/0 at the origin) and 1/|k|² with the mean
# mode zeroed, so projections leave the k = 0 component untouched.
K2 = KX**2 + KY**2 + KZ**2 + 1e-12
invK2 = 1.0 / K2
invK2[0, 0, 0] = 0.0

# 2/3-rule dealiasing: keep only |k| < (2/3)·(N/2) = N/3 along each axis.
# (The largest |k| on the grid is N/2, so a 2N/3 cutoff would keep everything.)
K_CUT = N / 3
dealias = (KX.abs() < K_CUT) & (KY.abs() < K_CUT) & (KZ.abs() < K_CUT)

# Random initial velocity spectra. The fields are drawn with torch, so torch's
# generator must be seeded for reproducibility (numpy is seeded for good measure).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
u = torch.randn(N, N, z_dim, dtype=torch.complex64, device=DEVICE) * 0.5
v = torch.randn(N, N, z_dim, dtype=torch.complex64, device=DEVICE) * 0.5
w = torch.randn(N, N, z_dim, dtype=torch.complex64, device=DEVICE) * 0.5

# Leray projection onto divergence-free fields, û ← û − k (k·û)/|k|²
# (∇·u ↔ i k·û, so removing the k-parallel part makes k·û = 0), then
# band-limit with the dealiasing mask.
k_dot_u = KX*u + KY*v + KZ*w
u = dealias * (u - KX * k_dot_u * invK2)
v = dealias * (v - KY * k_dot_u * invK2)
w = dealias * (w - KZ * k_dot_u * invK2)

def step(dt=0.018, nu=0.004 / 0.018):
    """Advance the global spectral velocity (u, v, w) by one time step.

    Pseudo-spectral incompressible Navier–Stokes in rotational form,

        ∂u/∂t = P[u × ω] + ν ∇²u,    ω = ∇ × u,

    where P is the projection onto divergence-free fields (it removes the
    pressure and kinetic-energy gradients). The nonlinear term is advanced
    with forward Euler; viscosity is integrated exactly through the factor
    exp(-ν |k|² Δt), so the viscous part cannot blow up for any Δt. The
    updated spectra are multiplied by the ``dealias`` mask.

    Reads module globals N, KX, KY, KZ, K2, invK2, dealias and rebinds the
    globals u, v, w (rfft-layout spectra paired with KX, KY, KZ respectively).

    Args:
        dt: time step Δt.
        nu: kinematic viscosity ν (the default gives ν·Δt = 0.004).
    """
    global u, v, w
    grid = (N, N, N)

    def to_grid(spec):
        # Spectral (rfft layout) -> real field on the N³ grid.
        return torch.fft.irfftn(spec, s=grid)

    # Velocity and vorticity ω = ∇×u on the grid (∂/∂x ↔ i·KX, etc.).
    vel_x, vel_y, vel_z = to_grid(u), to_grid(v), to_grid(w)
    vort_x = to_grid(1j * (KY * w - KZ * v))
    vort_y = to_grid(1j * (KZ * u - KX * w))
    vort_z = to_grid(1j * (KX * v - KY * u))

    # Nonlinear term u × ω, evaluated pointwise, back to spectral space.
    nl_x = torch.fft.rfftn(vel_y * vort_z - vel_z * vort_y)
    nl_y = torch.fft.rfftn(vel_z * vort_x - vel_x * vort_z)
    nl_z = torch.fft.rfftn(vel_x * vort_y - vel_y * vort_x)

    # Project out the gradient part: n̂ ← n̂ − k (k·n̂)/|k|².
    k_dot_nl = (KX * nl_x + KY * nl_y + KZ * nl_z) * invK2
    nl_x = nl_x - KX * k_dot_nl
    nl_y = nl_y - KY * k_dot_nl
    nl_z = nl_z - KZ * k_dot_nl

    # Forward Euler for the nonlinear term, exact viscous decay, dealiasing.
    damp = dealias * torch.exp(-nu * dt * K2)
    u = damp * (u + dt * nl_x)
    v = damp * (v + dt * nl_y)
    w = damp * (w + dt * nl_z)

# Animate a mid-plane (z = N/2) slice of the speed |u|. Colab's inline
# backend does not redraw on plt.pause, so each frame is re-displayed
# explicitly with display() / clear_output().
N_STEPS = 600        # total time steps (finite so "Run all" terminates)
PLOT_EVERY = 12      # draw one frame every PLOT_EVERY steps
VMAX = None          # colour-scale upper limit; None -> autoscale each frame

fig, ax = plt.subplots(figsize=(8, 8))
t0 = time()

for frame in range(N_STEPS):
    step()
    if frame % PLOT_EVERY == 0:
        speed = (torch.fft.irfftn(u, s=(N, N, N))**2 +
                 torch.fft.irfftn(v, s=(N, N, N))**2 +
                 torch.fft.irfftn(w, s=(N, N, N))**2).sqrt()
        mid_slice = speed[N // 2].cpu().numpy()
        ax.clear()
        ax.imshow(mid_slice, cmap='turbo', vmin=0,
                  vmax=VMAX if VMAX is not None else mid_slice.max())
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        # `frame` counts solver steps, so this is a step rate, not a frame rate.
        ax.set_title(f'{N}³ • {(frame + 1) / (time() - t0):.2f} steps/s')
        clear_output(wait=True)
        display(fig)

# Avoid a duplicate static render of the figure when the cell finishes.
plt.close(fig)