In [13]:
import numpy as np
from numpy.random import seed
seed(1)
import pyfftw
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import time as tm
import os
from scipy.fft import fft2, ifft2, fftfreq

### ⚡ FFT-based Poisson Solver (`fps`)

This function solves the 2D Poisson equation

$$
\nabla^2 \psi = f
$$

on a periodic domain using the **Fast Fourier Transform (FFT)**.

**Algorithm:**
1. Compute FFT of the right-hand side:
   $$
   \hat{f}(k_x,k_y) = \mathcal{F}[f(x,y)]
   $$
2. In Fourier space, the equation becomes:
   $$
   -(k_x^2+k_y^2)\hat{\psi}(k_x,k_y) = \hat{f}(k_x,k_y)
   $$
   hence
   $$
   \hat{\psi}(k_x,k_y) = -\frac{\hat{f}(k_x,k_y)}{k_x^2+k_y^2}
   $$
   (with $\hat{\psi}(0,0)=0$ to enforce zero mean).
3. Take the inverse FFT to obtain $\psi(x,y)$.
4. Add **ghost cells** to enforce periodic boundary conditions.

**Why valuable?**
- $O(N \log N)$ complexity (fast for large grids).
- Spectrally accurate for periodic domains.
- Much more efficient than Jacobi or Gauss–Seidel.
- Critical for fluid dynamics problems (e.g. vortex merger) where accuracy of $\psi$ drives correct advection of vorticity.


In [None]:
def fps(nx, ny, dx, dy, f):
    """
    Fast Poisson solver using scipy FFT.
    Solves: ∇²ψ = f   on a periodic domain.

    Parameters
    ----------
    nx, ny : int
        Number of unique (non-duplicated) grid points per direction. The
        period implied by the wavenumber scaling is nx*dx by ny*dy.
    dx, dy : float
        Grid spacing in each direction.
    f : ndarray
        Right-hand side on the padded grid. Only f[1:nx+1, 1:ny+1] is read;
        the duplicated periodic point and the ghost layers are ignored.

    Returns
    -------
    u : ndarray, shape (nx+3, ny+3)
        Zero-mean solution ψ. Indices 1..nx (1..ny) hold the FFT result,
        index nx+1 (ny+1) repeats index 1, and the ghost layers at index 0
        and nx+2 (ny+2) are filled periodically, so the array is fully
        initialized on return. Calling bc() on it afterwards is a no-op.
    """
    # One full period only: drop the ghost layers and the duplicated end point.
    rhs = f[1:nx+1, 1:ny+1]

    # Forward transform of the right-hand side.
    f_hat = fft2(rhs)

    # Angular wavenumbers 2*pi*m/(n*d) for each direction.
    kx = fftfreq(nx, d=dx/(2*np.pi))
    ky = fftfreq(ny, d=dy/(2*np.pi))
    kx, ky = np.meshgrid(kx, ky, indexing='ij')
    ksq = kx**2 + ky**2

    # The mean mode has k = 0; use a dummy divisor and zero the mode below.
    ksq[0, 0] = 1.0

    # -(kx^2 + ky^2) * psi_hat = f_hat
    psi_hat = -f_hat / ksq
    psi_hat[0, 0] = 0.0   # enforce zero mean

    # Back to physical space; the imaginary part is round-off for real input.
    ut = np.real(ifft2(psi_hat))

    u = np.empty((nx+3, ny+3))
    u[1:nx+1, 1:ny+1] = ut

    # Duplicated periodic point (index n+1 coincides with index 1).
    u[1:nx+1, ny+1] = u[1:nx+1, 1]
    u[nx+1, 1:ny+2] = u[1, 1:ny+2]

    # Ghost layers: columns first, then whole rows so the corners are set too.
    # Without this the outer frame would be left as uninitialized memory.
    u[1:nx+2, 0] = u[1:nx+2, ny]
    u[1:nx+2, ny+2] = u[1:nx+2, 2]
    u[0, :] = u[nx, :]
    u[nx+2, :] = u[2, :]
    return u


In [15]:
# -------------------------------
# Periodic boundary condition
# -------------------------------
def bc(nx,ny,u):
    """Fill the ghost layers of the padded array `u` periodically, in place.

    Column 0 receives column ny and column ny+2 receives column 2; then
    row 0 receives row nx and row nx+2 receives row 2. Rows are copied
    after columns, so the corner cells pick up the already-updated column
    values. The same array object is returned.
    """
    # Second axis first ...
    for ghost_col, source_col in ((0, ny), (ny + 2, 2)):
        u[:, ghost_col] = u[:, source_col]
    # ... then the first axis, copying whole rows (corners included).
    for ghost_row, source_row in ((0, nx), (nx + 2, 2)):
        u[ghost_row, :] = u[source_row, :]
    return u

### ⚡ RHS with Arakawa Jacobian (`rhs`)

This function computes the RHS of the vorticity transport equation:

$$
\frac{\partial \omega}{\partial t} = - J(\psi, \omega) + \frac{1}{Re}\nabla^2 \omega
$$

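The **Arakawa scheme** discretizes the Jacobian, written here with the sign convention $J(\psi,\omega) \equiv \partial_x\omega\,\partial_y\psi - \partial_y\omega\,\partial_x\psi$ (i.e. the advection term $u\,\partial_x\omega + v\,\partial_y\omega$ with $u=\partial_y\psi$, $v=-\partial_x\psi$), as an average of three discrete forms: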

$$
J(\psi, \omega) = \tfrac{1}{3}(J_1 + J_2 + J_3)
$$

1. **Centered form**  
$$
J_1 = \frac{1}{4\Delta x \Delta y} 
\Big[ (\omega_{i+1,j}-\omega_{i-1,j})(\psi_{i,j+1}-\psi_{i,j-1})
    - (\omega_{i,j+1}-\omega_{i,j-1})(\psi_{i+1,j}-\psi_{i-1,j}) \Big]
$$  

2. **Advective form**  
$$
J_2 = \frac{1}{4\Delta x \Delta y} 
\Big[ \omega_{i+1,j}(\psi_{i+1,j+1}-\psi_{i+1,j-1})
     - \omega_{i-1,j}(\psi_{i-1,j+1}-\psi_{i-1,j-1})
     - \omega_{i,j+1}(\psi_{i+1,j+1}-\psi_{i-1,j+1})
     + \omega_{i,j-1}(\psi_{i+1,j-1}-\psi_{i-1,j-1}) \Big]
$$  

3. **Energy–enstrophy form**  
$$
J_3 = \frac{1}{4\Delta x \Delta y} 
\Big[ \omega_{i+1,j+1}(\psi_{i,j+1}-\psi_{i+1,j})
     - \omega_{i-1,j-1}(\psi_{i-1,j}-\psi_{i,j-1})
     - \omega_{i-1,j+1}(\psi_{i,j+1}-\psi_{i-1,j})
     + \omega_{i+1,j-1}(\psi_{i+1,j}-\psi_{i,j-1}) \Big]
$$  

The diffusion term is:

$$
\nabla^2 \omega_{i,j} \approx
\frac{\omega_{i+1,j} - 2\omega_{i,j} + \omega_{i-1,j}}{\Delta x^2}
+
\frac{\omega_{i,j+1} - 2\omega_{i,j} + \omega_{i,j-1}}{\Delta y^2}
$$

Finally, the RHS becomes:

$$
f_{i,j} = -J(\psi,\omega)_{i,j} + \frac{1}{Re}\nabla^2 \omega_{i,j}
$$

**Key property:** In the inviscid limit, the Arakawa scheme discretely conserves both kinetic energy and enstrophy, which suppresses the nonlinear instability that simple central differences develop in long-time turbulence simulations.


In [None]:
# -------------------------------
# RHS using Arakawa scheme
# -------------------------------
def rhs(nx,ny,dx,dy,re,w,s,x,y,ts):
    """Right-hand side of the vorticity equation: -J + (1/re) * Laplacian(w).

    The Jacobian is the average of three discrete forms (Arakawa scheme)
    built from the vorticity `w` and the stream function `s`; the Laplacian
    of `w` uses second-order central differences.

    Parameters
    ----------
    nx, ny : int
        Grid size; `w` and `s` are indexed up to nx+2 / ny+2.
    dx, dy : float
        Grid spacing.
    re : float
        Divisor applied to the Laplacian term.
    w, s : ndarray
        Padded fields; their ghost layers must already be filled because
        the stencils read one cell on each side of 1..nx+1 / 1..ny+1.
    x, y, ts :
        Accepted but not used by this function.

    Returns
    -------
    f : ndarray, shape (nx+3, ny+3)
        Result on indices 1..nx+1 / 1..ny+1; the outer frame is zero.
    """
    def shifted(field, di, dj):
        # View of `field` at stencil offset (di, dj) from the block 1..n+1.
        return field[1+di:nx+2+di, 1+dj:ny+2+dj]

    inv_dx2 = 1.0/(dx*dx)
    inv_dy2 = 1.0/(dy*dy)
    jac_scale = 1.0/(4.0*dx*dy)
    third = 1.0/3.0

    # Vorticity at the centre and its eight neighbours.
    w_c = shifted(w, 0, 0)
    w_ip, w_im = shifted(w, 1, 0), shifted(w, -1, 0)
    w_jp, w_jm = shifted(w, 0, 1), shifted(w, 0, -1)
    w_pp, w_mm = shifted(w, 1, 1), shifted(w, -1, -1)
    w_mp, w_pm = shifted(w, -1, 1), shifted(w, 1, -1)

    # Stream function at the eight neighbours.
    s_ip, s_im = shifted(s, 1, 0), shifted(s, -1, 0)
    s_jp, s_jm = shifted(s, 0, 1), shifted(s, 0, -1)
    s_pp, s_mm = shifted(s, 1, 1), shifted(s, -1, -1)
    s_mp, s_pm = shifted(s, -1, 1), shifted(s, 1, -1)

    # Three discrete Jacobian forms.
    j1 = jac_scale*((w_ip - w_im)*(s_jp - s_jm)
                    - (w_jp - w_jm)*(s_ip - s_im))
    j2 = jac_scale*(w_ip*(s_pp - s_pm)
                    - w_im*(s_mp - s_mm)
                    - w_jp*(s_pp - s_mp)
                    + w_jm*(s_pm - s_mm))
    j3 = jac_scale*(w_pp*(s_jp - s_ip)
                    - w_mm*(s_im - s_jm)
                    - w_mp*(s_jp - s_im)
                    + w_pm*(s_ip - s_jm))
    jac = (j1 + j2 + j3)*third

    # Five-point Laplacian of the vorticity.
    lap = inv_dx2*(w_ip - 2.0*w_c + w_im) \
        + inv_dy2*(w_jp - 2.0*w_c + w_jm)

    f = np.zeros((nx+3, ny+3))
    f[1:nx+2, 1:ny+2] = -jac + lap/re
    return f

In [17]:
# -------------------------------
# Initial condition (vortex merger)
# -------------------------------
def vm_ic(nx,ny,x,y):
    """Initial vorticity for the vortex-merger case: two Gaussian vortices.

    The vortices are centred at (pi - pi/4, pi) and (pi + pi/4, pi) and
    both use the decay factor pi in the exponent. The sum is written to
    indices 1..nx+1 / 1..ny+1 of a padded (nx+3, ny+3) array, sampling
    `x` and `y` at [0:nx+1, 0:ny+1]; the ghost layers are then filled
    by bc().
    """
    decay = np.pi
    (xa, ya), (xb, yb) = ((np.pi - np.pi/4.0, np.pi),
                          (np.pi + np.pi/4.0, np.pi))

    xs = x[0:nx+1, 0:ny+1]
    ys = y[0:nx+1, 0:ny+1]

    vortex_a = np.exp(-decay*((xs - xa)**2 + (ys - ya)**2))
    vortex_b = np.exp(-decay*((xs - xb)**2 + (ys - yb)**2))

    field = np.empty((nx+3, ny+3))
    field[1:nx+2, 1:ny+2] = vortex_a + vortex_b
    return bc(nx, ny, field)

In [18]:
# -------------------------------
# Simulation parameters
# -------------------------------
nd = 128          # grid intervals per direction
nt = 3000         # number of time steps
re = 560.0        # divisor of the diffusion term in rhs()
dt = 0.01         # time step
save_freq = 20    # store a snapshot every save_freq steps

nx = ny = nd
lx = ly = 2.0*np.pi
dx = lx/nx
dy = ly/ny
x, y = np.meshgrid(
    np.linspace(0.0, 2.0*np.pi, nx+1),
    np.linspace(0.0, 2.0*np.pi, ny+1),
    indexing='ij',
)

# Index expression for the non-ghost part of every padded field.
core = np.s_[1:nx+2, 1:ny+2]

# -------------------------------
# Initialization
# -------------------------------
w = vm_ic(nx, ny, x, y)
# Stream function from the Poisson problem with right-hand side -w.
s = bc(nx, ny, fps(nx, ny, dx, dy, -w))
t = np.empty_like(w)   # intermediate stage field
r = np.empty_like(w)   # right-hand side buffer

snapshots = [np.copy(w)]
aa, bb = 1/3, 2/3      # stage-3 weights

# -------------------------------
# Time stepping (RK3)
# -------------------------------
for k in range(1, nt+1):
    time = k*dt
    r = rhs(nx, ny, dx, dy, re, w, s, x, y, time)

    # stage 1: forward Euler predictor
    t[core] = w[core] + dt*r[core]
    t = bc(nx, ny, t)
    s = bc(nx, ny, fps(nx, ny, dx, dy, -t))
    r = rhs(nx, ny, dx, dy, re, t, s, x, y, time)

    # stage 2: 3/4 old state + 1/4 Euler step from stage 1
    t[core] = 0.75*w[core] + 0.25*t[core] + 0.25*dt*r[core]
    t = bc(nx, ny, t)
    s = bc(nx, ny, fps(nx, ny, dx, dy, -t))
    r = rhs(nx, ny, dx, dy, re, t, s, x, y, time)

    # stage 3: 1/3 old state + 2/3 Euler step from stage 2
    w[core] = aa*w[core] + bb*t[core] + bb*dt*r[core]
    w = bc(nx, ny, w)
    # keep the stream function in sync for the next step
    s = bc(nx, ny, fps(nx, ny, dx, dy, -w))

    if k % save_freq == 0:
        snapshots.append(np.copy(w))

In [19]:
# -------------------------------
# Animation
# -------------------------------
fig, ax = plt.subplots(figsize=(6,6))
cmap = 'jet'

def animate(i):
    """Redraw the axes with stored snapshot `i` as a filled contour plot."""
    ax.clear()
    # Plot only indices 1..nx+1 / 1..ny+1; transpose so the first array
    # axis runs along the horizontal plot axis.
    vorticity = snapshots[i][1:nx+2,1:ny+2].T
    cs = ax.contourf(vorticity, 80, cmap=cmap)
    ax.set_title(f"Step {i*save_freq}, t={i*save_freq*dt:.2f}")
    return cs

ani = animation.FuncAnimation(
    fig,
    animate,
    frames=len(snapshots),
    interval=100,
    blit=False,
)
# Close the static figure so only the animation appears in the output.
plt.close(fig)
# Show inline in Jupyter
display(HTML(ani.to_jshtml()))