# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from muFFT import FFT
import matplotlib as mpl

# Larger, higher-resolution figures for the quiver plots below.
mpl.rcParams.update({'figure.dpi': 150, 'figure.figsize': [8, 8]})


# Physical and numerical parameters
N = 16  # grid points per direction
L = 2 * np.pi  # side length of the periodic square domain
nu = 0.01  # kinematic viscosity
dt = 0.1  # time-step size
T = 200  # final simulation time


# One shared muFFT plan on the N x N grid, used by the helpers and the main loop.
nb_grid_pts = (N,) * 2
fft = FFT(nb_grid_pts, engine="mpi", allow_temporary_buffer=True)

def FUCKING(U_hat):
    """Return the curl of a spectral 2D velocity field, evaluated in real space.

    Parameters
    ----------
    U_hat : numpy.ndarray
        Spectral velocity with ``U_hat[0]`` the x and ``U_hat[1]`` the y
        component, already scaled by ``fft.normalisation`` (as produced by
        ``fft.fft(y) * fft.normalisation`` in the main loop).

    Returns
    -------
    numpy.ndarray
        ``[x_, y_]`` in real space, where ``y_`` is the scalar vorticity
        dv/dx - du/dy and ``x_ == -y_``.
    """
    # Physical wavenumbers 2*pi*k/L. muFFT's ``fftfreq`` is the normalised
    # frequency k/N, so it has to be multiplied by N/L to obtain k/L.
    kx = 2 * np.pi * fft.fftfreq[0] * N / L
    ky = 2 * np.pi * fft.fftfreq[1] * N / L

    fieldx = 1j * ky * U_hat[0] - 1j * kx * U_hat[1]
    fieldy = 1j * kx * U_hat[1] - 1j * ky * U_hat[0]

    # U_hat already carries fft.normalisation from the forward transform, so
    # the plain inverse transform returns correctly scaled real-space data.
    # The shared module-level plan is reused rather than building a new
    # (MPI) plan on every call.
    x_ = fft.ifft(fieldx)
    y_ = fft.ifft(fieldy)

    return np.array([x_, y_])

def SHIT(u, curl):
    """Return the spectral nonlinear (Lamb-vector) term u x omega.

    Parameters
    ----------
    u : numpy.ndarray
        Real-space velocity ``[u_x, u_y]``.
    curl : numpy.ndarray
        Real-space curl in the layout returned by ``FUCKING``; ``curl[1]``
        holds the scalar vorticity omega = dv/dx - du/dy.

    Returns
    -------
    numpy.ndarray
        Spectral components of ``(u_y * omega, -u_x * omega)`` (the in-plane
        part of ``u x (omega e_z)``), scaled by ``fft.normalisation`` so they
        can be added directly to the normalised ``U_hat`` of the main loop.
    """
    omega = curl[1]

    # For planar flow, u x (omega e_z) = (u_y * omega, -u_x * omega).
    fieldx = u[1] * omega
    fieldy = -u[0] * omega

    # Same convention as the main loop: forward transform times normalisation.
    # The shared module-level plan is reused instead of creating one per call.
    x_ = fft.fft(fieldx) * fft.normalisation
    y_ = fft.fft(fieldy) * fft.normalisation

    return np.array([x_, y_])


def rk4(f, t: float, y: np.ndarray, dt: float) -> np.ndarray:
    """Return the classical fourth-order Runge-Kutta increment for one step.

    ``f(t, y)`` must return the time derivative of ``y``. The state is not
    advanced here; the caller adds the returned increment to ``y`` itself.
    """
    half = dt / 2
    t_mid = t + half

    slope_1 = f(t, y)
    slope_2 = f(t_mid, y + half * slope_1)
    slope_3 = f(t_mid, y + half * slope_2)
    slope_4 = f(t + dt, y + dt * slope_3)

    weighted_sum = slope_1 + 2 * slope_2 + 2 * slope_3 + slope_4
    return dt / 6 * weighted_sum

def analytical_solution(t, A, nu, X, Y):
    """Evaluate the decaying Taylor-Green vortex velocity at time ``t``.

    Returns the tuple ``(u_x, u_y)`` with
        u_x =  A cos(X) sin(Y) exp(-2 nu t)
        u_y = -A sin(X) cos(Y) exp(-2 nu t)
    evaluated element-wise on the coordinate arrays ``X`` and ``Y``.
    """
    decay = np.exp(-2 * nu * t)
    sin_x, cos_x = np.sin(X), np.cos(X)
    sin_y, cos_y = np.sin(Y), np.cos(Y)

    velocity_x = A * cos_x * sin_y * decay
    velocity_y = -A * sin_x * cos_y * decay
    return velocity_x, velocity_y

def rhs(t, y):
    """Return the spectral time derivative of the velocity field.

    Parameters
    ----------
    t : float
        Current time. Not used, but kept so the signature matches the
        ``f(t, y)`` callable expected by ``rk4``.
    y : numpy.ndarray
        Spectral velocity, already scaled by ``fft.normalisation``
        (i.e. ``fft.fft(u) * fft.normalisation``).

    Returns
    -------
    numpy.ndarray
        Currently just the spectral term returned by ``SHIT``; the pressure
        and viscous contributions are not included yet (see TODO below).
    """
    # y already includes fft.normalisation, so the plain inverse transform
    # recovers the real-space velocity.
    y_ = fft.ifft(y)

    # Use the argument rather than a module-level U_hat: a multi-stage
    # integrator calls rhs on intermediate states that differ from it.
    curl = FUCKING(y)
    dy = SHIT(y_, curl)

    # TODO: dealiasing, pressure projection (removing the component of dy
    # along the wavevector) and the viscous term -nu * |k|**2 * y are still
    # missing.
    return dy

# fft = FFT((16, 16), engine='pocketfft')

# Grid setup: N equally spaced points per direction on the periodic interval [0, L).
x = np.linspace(0, L, N, endpoint=False)
y = np.linspace(0, L, N, endpoint=False)
X, Y = np.meshgrid(x, y)

# Initial condition: Taylor-Green vortex with amplitude A for u and B for v.
# NOTE: analytical_solution uses A for both components, so keep B == A for
# the comparison in the main loop to be meaningful.
A = 1.0
B = 1.0

# Despite the ``_hat`` suffix, these are real-space velocity components.
u_hat = A * np.cos(X) * np.sin(Y)
v_hat = -B * np.sin(X) * np.cos(Y)

# Real-space state vector [u, v]; y_ is a separate copy of it.
y = np.stack([u_hat, v_hat])
y_ = y.copy()

# Main time-stepping loop.
# Count integer steps instead of accumulating ``t += dt``: the floating-point
# sum drifts, which can add an extra step and makes an output-time test based
# on int(t) unreliable.
nb_steps = int(round(T / dt))
output_every = max(1, int(round(25 / dt)))  # report every 25 time units

t = 0
for step in range(1, nb_steps + 1):

    t_0 = t

    u_analytical_old, v_analytical_old = analytical_solution(t_0, A, nu, X, Y)

    # Forward transform, scaled so that a plain fft.ifft(U_hat) recovers y.
    U_hat = fft.fft(y) * fft.normalisation

    # rk4 returns the increment over one step of length dt, so it can be
    # added to U_hat directly.
    dy_hat = rk4(rhs, t, U_hat, dt)

    U_hat += dy_hat
    t = step * dt

    # U_hat is already normalised; applying fft.normalisation again here
    # would shrink the field by that factor on every step.
    y_old = y
    y = fft.ifft(U_hat)

    u_analytical, v_analytical = analytical_solution(t, A, nu, X, Y)

    # Real-space change over this step: numerical vs. analytical.
    dy = y - y_old
    dy_analytical = np.array([u_analytical - u_analytical_old, v_analytical - v_analytical_old])

    u_numerical = y[0]
    v_numerical = y[1]

    if step % output_every == 0:

        # Compute error
        error = np.linalg.norm(y - np.array([u_analytical, v_analytical]))
        print(f"Time: {t}, Error: {error}")

        print("dy:")
        print(dy)
        print("dy_analytical:")
        print(dy_analytical)

        # Plot comparison
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.title("Analytical Solution")
        plt.quiver(X, Y, u_analytical, v_analytical, scale=30)

        plt.subplot(1, 2, 2)
        plt.title("Numerical Solution")
        plt.quiver(X, Y, u_numerical, v_numerical, scale=30)

        plt.show()