# Case 1

Periodic boundary conditions: the domain wraps around in both x and y.

In [None]:
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg
from jax.scipy.ndimage import map_coordinates
from typing import NamedTuple
import matplotlib.pyplot as plt

In [None]:
class FluidConfig:
    """Static parameters for the 2-D fluid simulation (grid, time step, forcing)."""

    # Grid resolution: number of cells along x and along y (N_x x N_y grid).
    N_x: int = 128
    N_y: int = 128

    # Physical domain size. Earlier versions of this code implicitly used a
    # cell length of 1; the cell spacing is now derived from the domain size.
    Lx: float = 1.0
    Ly: float = 1.0
    dx: float = Lx / N_x
    dy: float = Ly / N_y

    DT: float = 0.1            # time step
    VISCOSITY: float = 0.001   # diffusion strength
    DIFFUSION_ITER: int = 20   # presumably the CG iteration cap for the diffusion solve
    PRESSURE_ITER: int = 50    # presumably the CG iteration cap for the pressure solve

    # Kolmogorov flow forcing. 0.0 adds no forcing; values such as 1.0 or
    # 80.0 have been used to switch it on.
    FORCE_SCALE: float = 0.0
    WAVE_NUMBER: float = 4.0

class FluidState(NamedTuple):
    """Immutable bundle of the simulated fields.

    Declared as a NamedTuple so the whole state can be passed around (and
    replaced via ``_replace``) as a single value.
    """

    u: jnp.ndarray        # velocity component along x
    v: jnp.ndarray        # velocity component along y
    density: jnp.ndarray  # scalar density field

In [None]:
def laplacian(f, dx=None, dy=None):
    """Return the 5-point finite-difference Laplacian of ``f`` on a periodic grid.

    Axis 1 of ``f`` is treated as x and axis 0 as y. ``jnp.roll`` wraps values
    around the array edges, which imposes periodic boundary conditions.

    Args:
        f: 2-D array of field values.
        dx: cell spacing along axis 1. Defaults to ``FluidConfig.dx``.
        dy: cell spacing along axis 0. Defaults to ``FluidConfig.dy``.

    Returns:
        Array with the same shape as ``f``.
    """
    # Callers pass (f, dx, dy); the bare module-level names dx/dy the old
    # one-argument version relied on are not defined, so fall back to config.
    if dx is None:
        dx = FluidConfig.dx
    if dy is None:
        dy = FluidConfig.dy
    f_xm = jnp.roll(f, 1, axis=1)   # f[i, j-1]
    f_xp = jnp.roll(f, -1, axis=1)  # f[i, j+1]
    f_ym = jnp.roll(f, 1, axis=0)   # f[i-1, j]
    f_yp = jnp.roll(f, -1, axis=0)  # f[i+1, j]
    lap_x = (f_xm - 2.0 * f + f_xp) / (dx * dx)
    lap_y = (f_ym - 2.0 * f + f_yp) / (dy * dy)
    return lap_x + lap_y

def divergence(u, v, dx, dy):
    """Return du/dx + dv/dy using central differences with periodic wrap.

    x derivatives are taken along axis 1 (spacing ``dx``) and y derivatives
    along axis 0 (spacing ``dy``); each uses the neighbours one cell away on
    either side, i.e. a stencil of width 2*h.
    """
    def centered(field, axis, h):
        ahead = jnp.roll(field, -1, axis=axis)
        behind = jnp.roll(field, 1, axis=axis)
        return (ahead - behind) / (2.0 * h)

    return centered(u, 1, dx) + centered(v, 0, dy)

def gradient(p, dx, dy):
    """Return (dp/dx, dp/dy) of ``p`` using central differences with periodic wrap.

    dp/dx is taken along axis 1 with spacing ``dx``; dp/dy along axis 0 with
    spacing ``dy``.
    """
    def _centered(axis, h):
        forward = jnp.roll(p, -1, axis=axis)
        backward = jnp.roll(p, 1, axis=axis)
        return (forward - backward) / (2.0 * h)

    dp_dx = _centered(1, dx)
    dp_dy = _centered(0, dy)
    return dp_dx, dp_dy

In [None]:
def advect(f, u, v, dt, dx, dy):
    """Semi-Lagrangian advection of ``f`` by the velocity field (u, v).

    Every cell traces back along its velocity for one step ``dt`` and samples
    ``f`` at the departure point with linear interpolation (``order=1``).
    ``mode='wrap'`` makes the lookup periodic in both directions.

    Args:
        f: 2-D field to advect; axis 0 is y, axis 1 is x.
        u, v: velocity components, same shape as ``f``.
        dt: time step.
        dx, dy: cell spacing along x (axis 1) and y (axis 0).

    Returns:
        The advected field, same shape as ``f``.
    """
    # Take the grid size from the field itself instead of undefined globals,
    # so non-square grids are indexed correctly as well.
    n_rows, n_cols = f.shape
    # With indexing='ij' the first output varies along axis 0 (rows = y),
    # so it must be built from the row count and the second from the columns.
    mesh_y, mesh_x = jnp.meshgrid(jnp.arange(n_rows), jnp.arange(n_cols), indexing='ij')
    # Departure points, expressed in index (cell) units.
    src_x = mesh_x - (u * dt) / dx
    src_y = mesh_y - (v * dt) / dy
    coords = jnp.stack([src_y, src_x], axis=0)
    return map_coordinates(f, coords, order=1, mode='wrap')

def diffuse(f, viscosity, dt, n_iter, dx, dy):
    """Implicit (backward-Euler) diffusion step.

    Solves (I - viscosity*dt*L) f_new = f, with L = ``laplacian``, using at
    most ``n_iter`` conjugate-gradient iterations. The solver's convergence
    info is discarded. A viscosity of exactly 0 returns ``f`` itself.
    """
    if viscosity != 0:
        alpha = viscosity * dt

        def helmholtz_op(x):
            return x - alpha * laplacian(x, dx, dy)

        return cg(helmholtz_op, f, maxiter=n_iter)[0]
    # Zero viscosity: diffusion is a no-op, so skip the solve.
    return f

def pressure(div, n_iter, dx, dy):
    """Solve the periodic Poisson problem L p = div for the pressure ``p``.

    L is ``laplacian`` with spacings ``dx``/``dy``. At most ``n_iter``
    conjugate-gradient iterations are run and the solver's convergence info is
    discarded. Because the periodic Laplacian maps constant fields to zero,
    ``p`` is only determined up to an additive constant.
    """
    solution, _info = cg(lambda q: laplacian(q, dx, dy), div, maxiter=n_iter)
    return solution

def project(u, v, n_iter, dx, dy):
    """Remove (most of) the divergence from the velocity field (u, v).

    Computes div = du/dx + dv/dy, solves L p = div with the CG Poisson solver
    ``pressure`` (at most ``n_iter`` iterations), and subtracts grad p from the
    velocity.

    NOTE: ``divergence``/``gradient`` use wide (2*h) central differences while
    the Poisson operator is the compact 5-point Laplacian, and CG is capped at
    ``n_iter`` steps, so the result is only approximately divergence-free.

    Returns:
        (u_new, v_new)
    """
    div = divergence(u, v, dx, dy)
    # ``pressure`` is the Poisson solver with signature (div, n_iter, dx, dy);
    # the name ``solve_pressure`` used previously is not defined here.
    p = pressure(div, n_iter, dx, dy)
    dp_dx, dp_dy = gradient(p, dx, dy)
    return u - dp_dx, v - dp_dy

def apply_kolmogorov_forcing(u, v, dt, scale, k, Ly=None):
    """Add the Kolmogorov body force scale * sin(2*pi*k*y/Ly) to ``u``.

    ``k`` is the number of full sine periods across the domain height, so the
    forcing is continuous across the periodic boundary in y. For Ly = 2*pi this
    reduces to the classical sin(k*y). The previous sin(k*y) form with Ly = 1
    was not periodic on the domain.

    Args:
        u, v: velocity components; axis 0 is y.
        dt: time step.
        scale: forcing amplitude.
        k: number of wavelengths across the domain height.
        Ly: domain height; defaults to ``FluidConfig.Ly``.

    Returns:
        (u_new, v), where ``v`` is returned unchanged.
    """
    if Ly is None:
        Ly = FluidConfig.Ly
    n_y = u.shape[0]  # take the row count from the field instead of config
    y = jnp.linspace(0.0, Ly, n_y, endpoint=False)
    # Column vector so the forcing broadcasts along x (axis 1).
    forcing_grid = jnp.sin(2.0 * jnp.pi * k * y / Ly)[:, None]
    u_new = u + dt * scale * forcing_grid
    return u_new, v