In [1]:
import os

# Thread-count hints are read when NumPy's BLAS/MKL and pyFFTW are loaded, so
# they must be in the environment *before* those imports; assigning them after
# the libraries are already imported has no effect on them.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["PYFFTW_NUM_THREADS"] = "4"  # default thread count for pyfftw.interfaces

import numpy as np
import matplotlib.pyplot as plt
import pyfftw
import pyfftw.interfaces.numpy_fft as fft

os.makedirs("diagnostics", exist_ok=True)
os.makedirs("plots", exist_ok=True)

# Reuse FFTW plans across calls (the same transform shapes recur every step).
pyfftw.interfaces.cache.enable()
# Process-wide monkeypatch: every later np.fft.rfft2 / np.fft.irfft2 call
# (including those inside run_simulation) goes through pyFFTW. Other np.fft
# functions such as fftfreq are untouched.
np.fft.rfft2 = fft.rfft2
np.fft.irfft2 = fft.irfft2

# Sweep parameters: edit these lists to choose which runs are executed.
grid_list = [128, 256, 512]
Re_list = [100, 1000, 10000]
delta0 = 1/28  # shear-layer thickness parameter (base profile width and nu)
dt_scale = 1e-3
# dt = delta0 * dt_scale * factor; the label is used in output directory names.
dt_label_map = {4: "dt4", 2: "dt2", 1: "dt1", 0.5: "dt0_5"}
# bits -> (real dtype, complex dtype).
# NOTE: np.longdouble is platform dependent (80-bit extended on typical x86
# Linux builds, identical to float64 on e.g. Windows/MSVC) -- check
# np.finfo(np.longdouble) before trusting the "80" label.
PRECISION_MAP = {32: (np.float32, np.complex64), 64: (np.float64, np.complex128), 80: (np.longdouble, np.clongdouble)}
siml_factor = 800  # simulate 1/800th of the reference end time
T_END = 14.285714285714285 * (1/siml_factor)  # reference end time ~= 400 * delta0 (= 400/28)
DX_DUMPS = 10  # approximate number of diagnostic dumps per run

def _shear_layer_thickness(u_x, y):
    """Return ``(max U - min U) / max|dU/dy|`` for ``U = mean(u_x, axis=1)``.

    ``dU/dy`` comes from ``np.gradient`` on the ``y`` samples, i.e. one-sided
    differences at both ends (no periodic wrap).
    """
    u_mean = np.mean(u_x, axis=1)
    dUdy = np.gradient(u_mean, y)
    return (u_mean.max() - u_mean.min()) / np.max(np.abs(dUdy))


def _save_frames(base_plot_dir, frame_i, X, Y, omega, u_x, u_y, Lx):
    """Save the vorticity plots for one diagnostic dump.

    Two views are written: the fields as stored (``"unrolled"``) and the same
    fields shifted by half a period in x through their periodic extension
    (``"tiled_crop"``). Each view is drawn three ways (filled contours only,
    with contour lines, with a decimated velocity quiver), one sub-folder per
    combination, file name ``frame_{frame_i:04d}.png``.
    """
    Nx = omega.shape[1]
    start, end = Nx // 2, Nx // 2 + Nx

    def shifted(a, offset=0):
        # Two copies side by side, then a window starting half-way along x.
        # X needs +Lx on the second copy so coordinates stay increasing.
        return np.concatenate([a, a + offset], axis=1)[:, start:end]

    plot_modes = [
        ("unrolled", X, Y, omega, u_x, u_y),
        ("tiled_crop", shifted(X, Lx), shifted(Y), shifted(omega), shifted(u_x), shifted(u_y)),
    ]
    combos = [("contour_only", False, False), ("lines_only", True, False), ("quiver_only", False, True)]
    stride = 30  # quiver decimation in both directions
    for mode_name, X_now, Y_now, omega_now, u_x_now, u_y_now in plot_modes:
        for folder, draw_lines, draw_quiver in combos:
            out_dir = os.path.join(base_plot_dir, f"{folder}_{mode_name}")
            os.makedirs(out_dir, exist_ok=True)
            fig = plt.figure(figsize=(6, 4.5), dpi=300)
            try:
                plt.contourf(X_now, Y_now, omega_now, levels=1000, cmap="jet", vmin=-60, vmax=0)
                if draw_lines:
                    plt.contour(X_now, Y_now, omega_now, levels=[-60, 0], colors="white", linewidths=0.5)
                if draw_quiver:
                    plt.quiver(X_now[::stride, ::stride], Y_now[::stride, ::stride], u_x_now[::stride, ::stride], u_y_now[::stride, ::stride], pivot="mid", scale=25, width=0.005, alpha=0.8)
                plt.axis("off")
                plt.tight_layout()
                plt.savefig(os.path.join(out_dir, f"frame_{frame_i:04d}.png"))
            finally:
                # Always release the figure, even if drawing/saving fails or is interrupted.
                plt.close(fig)


def run_simulation(Nx, Re, dt, dt_label, bits, do_plots):
    """Integrate a perturbed tanh shear layer with a pseudo-spectral vorticity solver.

    The perturbation vorticity ``w`` is advanced in rfft2 space while the base
    vorticity ``omega_base`` is held fixed. Advection ``-u . grad(omega_base + w)``
    is explicit and 2/3-rule dealiased; viscous diffusion ``nu * lap(w)`` is
    implicit inside each stage of a three-stage SSP Runge-Kutta (Shu-Osher) step.

    Parameters
    ----------
    Nx : int
        Grid points per direction (square ``Nx x Nx`` grid on the unit square).
    Re : float
        Reynolds number; ``nu = u_ref * delta0 / Re``.
    dt : float
        Time step.
    dt_label : str
        Label embedded in the run name / output paths.
    bits : int
        Key into ``PRECISION_MAP`` selecting the real floating-point type.
    do_plots : bool
        Also write vorticity PNGs under ``plots/<run_name>/`` at every dump.

    Writes ``diagnostics/<run_name>/diagnostics_<run_name>.csv`` with one row
    per dump (roughly ``DX_DUMPS`` rows). Returns ``None``.
    """
    DTYPE, _ = PRECISION_MAP[bits]
    Ny = Nx
    Lx = Ly = DTYPE(1.0)
    dx, dy = Lx / Nx, Ly / Ny
    x = np.linspace(0, Lx, Nx, endpoint=False, dtype=DTYPE)
    y = np.linspace(0, Ly, Ny, endpoint=False, dtype=DTYPE)
    X, Y = np.meshgrid(x, y)  # shape (Ny, Nx): axis 0 is y, axis 1 is x
    u_ref = DTYPE(1.0)
    c_n = DTYPE(1e-3)  # perturbation amplitude
    nu = (u_ref * DTYPE(delta0)) / DTYPE(Re)

    # Angular wavenumbers on the full grid.
    kx = (np.fft.fftfreq(Nx, d=dx) * 2 * np.pi).astype(DTYPE)
    ky = (np.fft.fftfreq(Ny, d=dy) * 2 * np.pi).astype(DTYPE)
    KX, KY = np.meshgrid(kx, ky)
    # True |k|^2 (0 for the mean mode) for the Laplacian and the diffusion term;
    # a copy with the mean mode set to 1 is used only to invert the Laplacian.
    K2_lap = (KX**2 + KY**2).astype(DTYPE)
    K2_full = K2_lap.copy()
    K2_full[0, 0] = DTYPE(1.0)
    kx_cut = DTYPE(2/3) * np.max(np.abs(kx))
    ky_cut = DTYPE(2/3) * np.max(np.abs(ky))
    dealias_full = (np.abs(KX) < kx_cut) & (np.abs(KY) < ky_cut)
    # rfft2 output keeps only the first Nx//2 + 1 columns along x.
    Nx_r = Nx // 2 + 1
    K2_r = K2_full[:, :Nx_r]
    K2_lap_r = K2_lap[:, :Nx_r]
    dealias_r = dealias_full[:, :Nx_r]
    KX_r, KY_r = KX[:, :Nx_r], KY[:, :Nx_r]

    def spectral_derivative(f, axis):
        """Spectral d/dy (axis=0) or d/dx (axis=1) of a real periodic field."""
        factor = 1j * (KY_r if axis == 0 else KX_r)
        return np.fft.irfft2(factor * np.fft.rfft2(f), s=(Ny, Nx)).astype(DTYPE)

    def spectral_laplacian(f):
        """Spectral Laplacian of a real periodic field."""
        return np.fft.irfft2(-K2_lap_r * np.fft.rfft2(f), s=(Ny, Nx)).astype(DTYPE)

    # Base flow: tanh shear layer centred at y = 0.5, and its vorticity -dU/dy.
    # NOTE(review): u_x_base runs from about -1 at y=0 to about +1 as y -> 1, so
    # it is not periodic in y. It is only used pointwise (never differentiated),
    # but the periodic box then implies a velocity jump at y=0 that omega_base
    # does not represent -- confirm this is intended.
    u_x_base = u_ref * np.tanh((2 * Y - DTYPE(1.0)) / DTYPE(delta0))
    u_y_base = np.zeros_like(u_x_base, dtype=DTYPE)
    omega_base = -u_ref * (2 / DTYPE(delta0)) * (1 / np.cosh((2 * Y - DTYPE(1.0)) / DTYPE(delta0)))**2
    # Perturbation streamfunction: Gaussian in y times two cosine modes in x.
    psi_init = u_ref * np.exp(-((Y - DTYPE(0.5)) / DTYPE(delta0))**2) * (np.cos(8 * np.pi * X) + np.cos(20 * np.pi * X))
    omega_pert = -c_n * spectral_laplacian(psi_init)
    u_x_init = u_x_base + c_n * spectral_derivative(psi_init, axis=0)
    delta_vort0 = _shear_layer_thickness(u_x_init, y)

    # Step count from the float64 dt *before* casting to DTYPE, so every precision
    # runs the same number of steps. round() keeps ceil() from adding a spurious
    # step when T_END/dt is an integer in exact arithmetic but lands a few ulp above it.
    num_steps = int(np.ceil(round(T_END / float(dt), 9)))
    dt = DTYPE(dt)
    DX_CONTOUR = max(1, num_steps // DX_DUMPS)

    # Shu-Osher SSPRK3 with diffusion implicit in each Euler sub-step:
    #   w_i = a_i * w_n + b_i * (w_{i-1} + dt*N(w_{i-1}) - dt*nu*K2*w_i)
    # so stage i divides by (1 + b_i*dt*nu*K2) with b = 1, 1/4, 2/3. Using the full
    # dt in every stage's denominator over-damps (effective viscosity ~ 11/6 nu).
    # Coefficients are built in DTYPE so 1/3, 2/3 keep full extended precision.
    a2, b2 = DTYPE(3) / DTYPE(4), DTYPE(1) / DTYPE(4)
    a3, b3 = DTYPE(1) / DTYPE(3), DTYPE(2) / DTYPE(3)
    visc = dt * nu * K2_lap_r
    denom1 = 1 + visc
    denom2 = 1 + b2 * visc
    denom3 = 1 + b3 * visc

    def velocity_from_hat(w_hat):
        """Total velocity: base flow + c_n * (dpsi/dy, -dpsi/dx), psi_hat = w_hat / (c_n*K2)."""
        psi = np.fft.irfft2(w_hat / (c_n * K2_r), s=(Ny, Nx))
        u_x = u_x_base + c_n * spectral_derivative(psi, axis=0)
        u_y = u_y_base - c_n * spectral_derivative(psi, axis=1)
        return u_x, u_y

    def nonlinear_hat(w_hat, u_x, u_y):
        """Dealiased rfft2 of -(u . grad)(omega_base + w)."""
        w_total = np.fft.irfft2(w_hat, s=(Ny, Nx)) + omega_base
        advection = u_x * spectral_derivative(w_total, axis=1) + u_y * spectral_derivative(w_total, axis=0)
        return np.fft.rfft2(-advection) * dealias_r

    def imex_ssprk3(w_hat):
        """Advance the perturbation vorticity (rfft2 space) by one step dt."""
        u_x, u_y = velocity_from_hat(w_hat)
        w1 = (w_hat + dt * nonlinear_hat(w_hat, u_x, u_y)) / denom1
        u_x, u_y = velocity_from_hat(w1)
        w2 = (a2 * w_hat + b2 * (w1 + dt * nonlinear_hat(w1, u_x, u_y))) / denom2
        u_x, u_y = velocity_from_hat(w2)
        return (a3 * w_hat + b3 * (w2 + dt * nonlinear_hat(w2, u_x, u_y))) / denom3

    run_name = f"{Nx}x{Nx}_Re{Re}_{dt_label}_prec{bits}"
    diag_dir = os.path.join("diagnostics", run_name)
    os.makedirs(diag_dir, exist_ok=True)
    diag_file = os.path.join(diag_dir, f"diagnostics_{run_name}.csv")
    with open(diag_file, "w") as fh:
        fh.write("step,time,KE,enstrophy,palinstrophy,div_L2,vort_norm\n")
    base_plot_dir = os.path.join("plots", run_name) if do_plots else None

    w_hat = np.fft.rfft2(omega_pert)
    for step in range(1, num_steps + 1):
        w_hat = imex_ssprk3(w_hat)
        if step % DX_CONTOUR != 0:
            continue
        # Physical-space fields are only needed at dump steps.
        t = DTYPE(step) * dt  # avoids accumulated round-off from repeated t += dt
        omega = omega_base + np.fft.irfft2(w_hat, s=(Ny, Nx))
        u_x, u_y = velocity_from_hat(w_hat)
        dw_dx = spectral_derivative(omega, axis=1)
        dw_dy = spectral_derivative(omega, axis=0)
        enstrophy = DTYPE(0.5) * np.sum(omega**2, dtype=DTYPE) * dx * dy
        palinstro = DTYPE(0.5) * np.sum(dw_dx**2 + dw_dy**2, dtype=DTYPE) * dx * dy
        KE = DTYPE(0.5) * np.sum(u_x**2 + u_y**2, dtype=DTYPE) * dx * dy
        divergence = spectral_derivative(u_x, axis=1) + spectral_derivative(u_y, axis=0)
        div_L2 = np.sqrt(np.sum(divergence**2, dtype=DTYPE) * dx * dy)
        vort_norm = _shear_layer_thickness(u_x, y) / delta_vort0
        # Re-open per dump so rows already written survive an interrupted run.
        with open(diag_file, "a") as fh:
            fh.write(f"{step},{float(t)},{float(KE)},{float(enstrophy)},{float(palinstro)},{float(div_L2)},{float(vort_norm)}\n")
        if do_plots:
            _save_frames(base_plot_dir, step // DX_CONTOUR, X, Y, omega, u_x, u_y, Lx)

# Run every combination of grid size, Reynolds number, time step and precision,
# in nested order (precision varies fastest, grid size slowest). Plots are
# written only for the dt_factor == 0.5, 80-bit runs.
sweep = [
    (n_grid, reynolds, factor, label, prec)
    for n_grid in grid_list
    for reynolds in Re_list
    for factor, label in dt_label_map.items()
    for prec in PRECISION_MAP
]
for Nx, Re, dt_factor, dt_label, bits in sweep:
    dt_val = delta0 * dt_scale * dt_factor
    do_plots = bits == 80 and dt_factor == 0.5
    print(f"Running: Nx={Nx}, Re={Re}, dt={dt_label}, prec={bits}, plots={do_plots}")
    run_simulation(Nx, Re, dt_val, dt_label, bits, do_plots)

Running: Nx=128, Re=100, dt=dt4, prec=32, plots=False
Running: Nx=128, Re=100, dt=dt4, prec=64, plots=False
Running: Nx=128, Re=100, dt=dt4, prec=80, plots=False
Running: Nx=128, Re=100, dt=dt2, prec=32, plots=False
Running: Nx=128, Re=100, dt=dt2, prec=64, plots=False
Running: Nx=128, Re=100, dt=dt2, prec=80, plots=False
Running: Nx=128, Re=100, dt=dt1, prec=32, plots=False
Running: Nx=128, Re=100, dt=dt1, prec=64, plots=False
Running: Nx=128, Re=100, dt=dt1, prec=80, plots=False
Running: Nx=128, Re=100, dt=dt0_5, prec=32, plots=False
Running: Nx=128, Re=100, dt=dt0_5, prec=64, plots=False
Running: Nx=128, Re=100, dt=dt0_5, prec=80, plots=True
Running: Nx=128, Re=1000, dt=dt4, prec=32, plots=False
Running: Nx=128, Re=1000, dt=dt4, prec=64, plots=False
Running: Nx=128, Re=1000, dt=dt4, prec=80, plots=False
Running: Nx=128, Re=1000, dt=dt2, prec=32, plots=False
Running: Nx=128, Re=1000, dt=dt2, prec=64, plots=False
Running: Nx=128, Re=1000, dt=dt2, prec=80, plots=False
Running: Nx=128, 


KeyboardInterrupt

