# Differentiable Physics VII: Incompressible Navier-Stokes

## Navier-Stokes Equations

We limit the scope of the solver with the following constraints:

- We use the conservative form of the equations, meaning we are looking at a volume which is fixed in space and the fluid is moving through it.
- We use a 2D grid based approach, where each cell in the grid represents a volume element and has size $\Delta x \times \Delta y$.
- The fluid is incompressible. This means that the density $\rho$ is constant across the grid.

The following quantities are defined on the grid:

- $V$ : velocity vector field. $V = \begin{pmatrix} u \\ v \end{pmatrix}$
- $p_{i,j}$ : pressure scalar field
- $\rho$ : density, constant across the grid
- $\nu$: kinematic viscosity, constant across the grid
- $\tau_{i,j}$ : dye concentration, a scalar field (for visualization purposes)
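One possible way to store these fields in JAX, sketched below under our own assumptions, is a single array of shape `(nx, ny, 3)` that holds $u$, $v$ and $p$, with $x$ along axis 0 and $y$ along axis 1. This matches how the code in this notebook indexes `U[..., 0]`, `U[..., 1]` and `U[..., 2]` and differentiates in $x$ along axis 0. The helpers `make_grid` and `taylor_green_state` are hypothetical, the dye field is left out, and the Taylor-Green vortex is only a convenient, divergence-free initial condition chosen for illustration.

```python
import jax.numpy as jnp

def make_grid(nx, ny, lx=2 * jnp.pi, ly=2 * jnp.pi):
    # hypothetical helper: periodic grid, the point at x = lx coincides with x = 0
    dx, dy = lx / nx, ly / ny
    x = jnp.arange(nx) * dx
    y = jnp.arange(ny) * dy
    X, Y = jnp.meshgrid(x, y, indexing="ij")  # X varies along axis 0, Y along axis 1
    return X, Y, dx, dy

def taylor_green_state(X, Y, rho=1.0):
    # hypothetical helper: Taylor-Green vortex at t = 0, divergence-free on [0, 2pi)^2
    u = jnp.sin(X) * jnp.cos(Y)
    v = -jnp.cos(X) * jnp.sin(Y)
    p = rho / 4 * (jnp.cos(2 * X) + jnp.cos(2 * Y))
    return jnp.stack((u, v, p), axis=-1)  # channels: u, v, p

X, Y, dx, dy = make_grid(64, 32)
U = taylor_green_state(X, Y)
print(U.shape)  # (64, 32, 3)
```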

The incompressible Navier-Stokes equations are:

1. Momentum conservation:
$$\frac{\partial V}{\partial t} + (V \cdot \nabla) V = -\frac{1}{\rho} \nabla p + \nu \Delta V$$

2. Incompressibility or mass conservation:
$$\nabla \cdot V = 0$$
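The incompressibility constraint is easy to check numerically. The sketch below assumes a periodic grid and central differences; the helper `divergence` is hypothetical, and the Taylor-Green field $u = \sin x \cos y$, $v = -\cos x \sin y$ is used because it is exactly divergence-free.

```python
import jax.numpy as jnp

def divergence(u, v, dx, dy):
    # central differences on a periodic grid; x along axis 0, y along axis 1
    du_dx = (jnp.roll(u, -1, axis=0) - jnp.roll(u, 1, axis=0)) / (2 * dx)
    dv_dy = (jnp.roll(v, -1, axis=1) - jnp.roll(v, 1, axis=1)) / (2 * dy)
    return du_dx + dv_dy

n = 32
h = 2 * jnp.pi / n
x = jnp.arange(n) * h
X, Y = jnp.meshgrid(x, x, indexing="ij")
u = jnp.sin(X) * jnp.cos(Y)
v = -jnp.cos(X) * jnp.sin(Y)
print(jnp.max(jnp.abs(divergence(u, v, h, h))))  # close to zero (round-off only)
```

A field such as $u = \sin x$, $v = 0$ would instead give a clearly non-zero divergence.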

As we are interested in the 2D case, we can write the equations as:

$$\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} + v \frac{\partial u}{\partial y} = -\frac{1}{\rho} \frac{\partial p}{\partial x} + \nu \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right) \tag{1}$$

$$\frac{\partial v}{\partial t} + u \frac{\partial v}{\partial x} + v \frac{\partial v}{\partial y} = -\frac{1}{\rho} \frac{\partial p}{\partial y} + \nu \left( \frac{\partial^2 v}{\partial x^2} + \frac{\partial^2 v}{\partial y^2} \right) \tag{2}$$
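As a sketch, the pressure-free terms of equations (1) and (2) can be evaluated directly, here with central differences for the advective derivatives and the 5-point stencil for the Laplacian on a periodic grid. The helper names (`d_central`, `laplacian`, `momentum_rhs`) and the choice of stencils are assumptions for illustration; the notebook itself advances these terms with the MacCormack scheme below. For the Taylor-Green vortex the advective term of (1) is $\sin x \cos x$ and the viscous term is $-2\nu u$, which gives exact values to compare against.

```python
import jax.numpy as jnp

def d_central(f, h, axis):
    # central difference on a periodic grid
    return (jnp.roll(f, -1, axis=axis) - jnp.roll(f, 1, axis=axis)) / (2 * h)

def laplacian(f, dx, dy):
    # 5-point stencil on a periodic grid
    f_xx = (jnp.roll(f, -1, axis=0) - 2 * f + jnp.roll(f, 1, axis=0)) / dx**2
    f_yy = (jnp.roll(f, -1, axis=1) - 2 * f + jnp.roll(f, 1, axis=1)) / dy**2
    return f_xx + f_yy

def momentum_rhs(u, v, dx, dy, nu):
    # du/dt and dv/dt from equations (1) and (2) without the pressure gradient
    adv_u = u * d_central(u, dx, 0) + v * d_central(u, dy, 1)
    adv_v = u * d_central(v, dx, 0) + v * d_central(v, dy, 1)
    return -adv_u + nu * laplacian(u, dx, dy), -adv_v + nu * laplacian(v, dx, dy)

n, nu = 64, 0.1
h = 2 * jnp.pi / n
x = jnp.arange(n) * h
X, Y = jnp.meshgrid(x, x, indexing="ij")
u = jnp.sin(X) * jnp.cos(Y)
v = -jnp.cos(X) * jnp.sin(Y)
rhs_u, rhs_v = momentum_rhs(u, v, h, h, nu)
# exact values: -sin(x)cos(x) - 2 nu u  and  -sin(y)cos(y) - 2 nu v
```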

In the incompressible case there is no separate evolution equation for the pressure: the continuity equation only constrains the velocity field. We obtain an equation for $p$ by taking the divergence of the momentum conservation equation and applying the incompressibility condition:

$$\nabla \cdot \left(\frac{\partial V}{\partial t} + (V \cdot \nabla) V\right) = \nabla \cdot \left( -\frac{1}{\rho} \nabla p + \nu \Delta V  \right)$$

On the LHS, $\nabla \cdot \frac{\partial V}{\partial t} = \frac{\partial}{\partial t} \left( \nabla \cdot V \right) = 0$, so we get:

$$\nabla \cdot \left((V \cdot \nabla) V\right) = \nabla \cdot \left( -\frac{1}{\rho} \nabla p + \nu \Delta V \right)$$

On the RHS, $\nabla \cdot \left( \nu \Delta V \right) = \nu \Delta \left( \nabla \cdot V \right) = 0$, so we get:

$$\nabla \cdot \left((V \cdot \nabla) V\right) = -\frac{1}{\rho} \nabla \cdot \nabla p = -\frac{1}{\rho} \Delta p$$

which we can rearrange to the pressure Poisson equation:

$$\Delta p = -\rho \nabla \cdot \left((V \cdot \nabla) V\right)$$

In 2D this reads:

$$\frac{\partial^2 p}{\partial x^2} + \frac{\partial^2 p}{\partial y^2} = -\rho \left(  \frac{\partial u}{\partial x} \frac{\partial u}{\partial x} + 2 \frac{\partial u}{\partial y} \frac{\partial v}{\partial x} + \frac{\partial v}{\partial y} \frac{\partial v}{\partial y} \right) \tag{3}$$
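The products in equation (3) come from expanding the divergence term by term. Writing $(V \cdot \nabla) V = \left(u u_x + v u_y,\; u v_x + v v_y\right)$, with subscripts denoting partial derivatives:

$$
\begin{aligned}
\nabla \cdot \left((V \cdot \nabla) V\right) &= \frac{\partial}{\partial x}\left(u u_x + v u_y\right) + \frac{\partial}{\partial y}\left(u v_x + v v_y\right) \\
&= u_x^2 + u u_{xx} + v_x u_y + v u_{xy} + u_y v_x + u v_{xy} + v_y^2 + v v_{yy} \\
&= u_x^2 + 2 u_y v_x + v_y^2 + u \frac{\partial}{\partial x}\left(u_x + v_y\right) + v \frac{\partial}{\partial y}\left(u_x + v_y\right)
\end{aligned}
$$

Since $u_x + v_y = \nabla \cdot V = 0$, the last two terms vanish, which leaves exactly the bracket on the right-hand side of equation (3).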


Many solvers for the incompressible Navier-Stokes equations are based on the **projection method** (also called the fractional-step method). We use Chorin's projection method, which splits each time step into three steps:

1. Advance the momentum equation without the pressure term by one time step $\Delta t$, starting from $V^n$. The result is the intermediate velocity field $V^*$, which in general is not divergence-free:

$$\frac{\partial V}{\partial t} + (V \cdot \nabla) V = \nu \Delta V \tag{4}$$

2. Solve a Poisson equation for the pressure field $p^{n+1}$. Here we use equation (3), evaluated with the intermediate velocity $V^*$:

$$\Delta p^{n+1} = -\rho \nabla \cdot \left((V^* \cdot \nabla) V^* \right) \tag{5}$$

3. Correct the intermediate velocity field with the pressure gradient:

$$V^{n+1} = V^* - \frac{\Delta t}{\rho} \nabla p^{n+1} \tag{6}$$

The projection method is a semi-implicit method, because the momentum conservation equation is solved explicitly, while the Poisson equation is solved implicitly.
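Step 3 only makes $V^{n+1}$ divergence-free if the pressure is consistent with it: taking the divergence of $V^{n+1} = V^* - \frac{\Delta t}{\rho}\nabla p^{n+1}$ and requiring $\nabla \cdot V^{n+1} = 0$ gives $\Delta p^{n+1} = \frac{\rho}{\Delta t}\nabla \cdot V^*$, the right-hand side used in Chorin's original formulation. The sketch below illustrates this projection with an FFT-based Poisson solver on a periodic $2\pi \times 2\pi$ grid, which is a swapped-in technique and not the Jacobi solver used in this notebook. It absorbs $\Delta t / \rho$ into a potential $\phi = \frac{\Delta t}{\rho} p$, so that $\Delta \phi = \nabla \cdot V^*$ and $V^{n+1} = V^* - \nabla \phi$. The function `project_spectral` and the test field (Taylor-Green vortex plus the gradient of $\sin x \sin y$) are assumptions for illustration.

```python
import jax.numpy as jnp

def project_spectral(u, v, length=2 * jnp.pi):
    # hypothetical helper: periodic square grid with n x n points, x along axis 0
    n = u.shape[0]
    k = 2 * jnp.pi * jnp.fft.fftfreq(n, d=length / n)  # angular wavenumbers
    kx, ky = jnp.meshgrid(k, k, indexing="ij")
    k2 = kx**2 + ky**2
    k2 = jnp.where(k2 == 0, 1.0, k2)  # mean mode: its divergence is zero anyway

    u_hat, v_hat = jnp.fft.fft2(u), jnp.fft.fft2(v)
    div_hat = 1j * (kx * u_hat + ky * v_hat)  # Fourier transform of div V*
    phi_hat = -div_hat / k2                   # solves  Laplacian(phi) = div V*
    u_hat = u_hat - 1j * kx * phi_hat         # V^{n+1} = V* - grad(phi)
    v_hat = v_hat - 1j * ky * phi_hat
    return jnp.real(jnp.fft.ifft2(u_hat)), jnp.real(jnp.fft.ifft2(v_hat))

n = 32
x = jnp.arange(n) * (2 * jnp.pi / n)
X, Y = jnp.meshgrid(x, x, indexing="ij")
# V* = divergence-free Taylor-Green vortex + gradient of g = sin(x) sin(y)
u_tg, v_tg = jnp.sin(X) * jnp.cos(Y), -jnp.cos(X) * jnp.sin(Y)
u_star = u_tg + jnp.cos(X) * jnp.sin(Y)
v_star = v_tg + jnp.sin(X) * jnp.cos(Y)
u_new, v_new = project_spectral(u_star, v_star)
```

Because the gradient part is removed exactly in Fourier space, the projected field matches the Taylor-Green vortex up to round-off.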

We use the MacCormack method, which we have already seen in the previous notebook, to solve the momentum conservation equation. For the Poisson equation we use the Jacobi method, a simple iterative method. Faster-converging solvers exist (for example Gauss-Seidel, successive over-relaxation, multigrid or FFT-based methods), but the Jacobi method is easy to implement and updates every grid point independently, which suits array operations in JAX.

## Solve for the intermediate velocity field $V^*$

To advance the velocity from $V^n$ to the intermediate field $V^*$ we use the MacCormack method. The predictor step uses forward differences:

$$\bar{V}_{i,j}^{n+1, *} = V_{i,j}^n - \frac{\Delta t}{\Delta x} \left( E_{i+1, j}^{n} - E_{i, j}^{n} \right) - \frac{\Delta t}{\Delta y} \left( F_{i, j+1}^{n} - F_{i, j}^{n} \right)$$

The corrector step evaluates the fluxes $\bar{E}$, $\bar{F}$ at the predicted values and uses backward differences:

$$V_{i,j}^{n+1, *} = \frac{1}{2} \left[ \left( V_{i,j}^n + \bar{V}_{i,j}^{n+1,*} \right) - \frac{\Delta t}{ \Delta x} \left( \bar{E}_{i, j}^{n+1} - \bar{E}_{i-1, j}^{n+1} \right) - \frac{\Delta t}{\Delta y} \left( \bar{F}_{i, j}^{n+1} - \bar{F}_{i, j-1}^{n+1} \right) \right]$$

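where the fluxes $E$ and $F$ are chosen so that the momentum equation without the pressure term takes the conservative form $\frac{\partial V}{\partial t} + \frac{\partial E}{\partial x} + \frac{\partial F}{\partial y} = 0$:

$$E = \begin{bmatrix} u^2 - \nu \frac{\partial u}{\partial x} \\ u v - \nu \frac{\partial v}{\partial x} \end{bmatrix} \quad F = \begin{bmatrix} u v - \nu \frac{\partial u}{\partial y} \\ v^2 - \nu \frac{\partial v}{\partial y} \end{bmatrix}$$

For a divergence-free field this is equivalent to equations (1) and (2) without the pressure term, because for example $\frac{\partial (u^2)}{\partial x} + \frac{\partial (uv)}{\partial y} = u \frac{\partial u}{\partial x} + v \frac{\partial u}{\partial y} + u \left(\nabla \cdot V\right)$.

The code works on the velocity $V$ and on the state vector $U$, which additionally stores the pressure: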

$$V = \begin{bmatrix} u \\ v \end{bmatrix} \quad U = \begin{bmatrix} u \\ v \\ p \end{bmatrix}$$
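To see the predictor-corrector structure in isolation, here is a sketch of the same MacCormack update for the 1D linear advection equation $\frac{\partial \phi}{\partial t} + \frac{\partial E}{\partial x} = 0$ with flux $E = c\,\phi$ on a periodic grid. The equation, the names `mac_cormack_1d` and `c`, and the sine test profile are assumptions chosen for illustration. With CFL number $0.5$ and exactly one period of transport, the profile should return close to its initial shape.

```python
import jax.numpy as jnp
from jax import lax

def mac_cormack_1d(phi, c, dt, dx):
    # predictor: forward difference of the flux E = c * phi
    E = c * phi
    phi_bar = phi - dt / dx * (jnp.roll(E, -1) - E)
    # corrector: backward difference of the flux evaluated at the predicted values
    E_bar = c * phi_bar
    return 0.5 * ((phi + phi_bar) - dt / dx * (E_bar - jnp.roll(E_bar, 1)))

n, c = 100, 1.0
dx = 1.0 / n
dt = 0.5 * dx / c  # CFL number 0.5
steps = 200        # steps * dt = 1.0: one full period on [0, 1)
x = jnp.arange(n) * dx
phi0 = jnp.sin(2 * jnp.pi * x)
phi = lax.fori_loop(0, steps, lambda _, p: mac_cormack_1d(p, c, dt, dx), phi0)
print(float(jnp.max(jnp.abs(phi - phi0))))  # small: mostly the scheme's phase error
```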


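```python
import jax.numpy as jnp
from jax import value_and_grad, jit
from functools import partial

# Finite-difference operators on a periodic grid (jnp.roll wraps around the edges).
# `axis` selects the direction: axis 0 is x, axis 1 is y.

@partial(jit, static_argnames=("delta", "axis"))
def fd(f : jnp.ndarray, delta : float, axis : int) -> jnp.ndarray:
    # forward difference: (f[i+1] - f[i]) / delta
    return (jnp.roll(f, -1, axis=axis) - f) / delta

@partial(jit, static_argnames=("delta", "axis"))
def bd(f : jnp.ndarray, delta : float, axis : int) -> jnp.ndarray:
    # backward difference: (f[i] - f[i-1]) / delta
    return (f - jnp.roll(f, 1, axis=axis)) / delta

@partial(jit, static_argnames=("delta", "axis"))
def cf(f : jnp.ndarray, delta : float, axis : int) -> jnp.ndarray:
    # central difference: (f[i+1] - f[i-1]) / (2 delta)
    return (jnp.roll(f, -1, axis=axis) - jnp.roll(f, 1, axis=axis)) / (2 * delta)
```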

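```python
@partial(jit, static_argnames=("dt", "dx", "dy", "nu", "scheme"))
def construct_E_F(
    V : jnp.ndarray,
    dt : float,
    dx : float,
    dy : float,
    nu : float,
    scheme : tuple[str, str] = ("f", "f")
) -> tuple[jnp.ndarray, jnp.ndarray]:
    # dt is not needed for the fluxes; it is only part of the signature

    # derivative operators for the viscous part of the fluxes ("f" forward, "b" backward)
    d_dx = partial(fd, delta=dx, axis=0) if scheme[0] == "f" else partial(bd, delta=dx, axis=0)
    d_dy = partial(fd, delta=dy, axis=1) if scheme[1] == "f" else partial(bd, delta=dy, axis=1)

    u = V[..., 0]
    v = V[..., 1]

    # JAX arrays are immutable, so the fluxes are built with .at[...].set(...).
    # Channels beyond u and v (e.g. the pressure in U) get zero flux.
    E = jnp.zeros_like(V)
    F = jnp.zeros_like(V)

    E = E.at[..., 0].set(u * u - nu * d_dx(u))
    E = E.at[..., 1].set(u * v - nu * d_dx(v))

    F = F.at[..., 0].set(u * v - nu * d_dy(u))
    F = F.at[..., 1].set(v * v - nu * d_dy(v))

    return E, F
```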

```python
@partial(jit, static_argnames=("dt", "dx", "dy", "nu", "pred_scheme", "corr_scheme"))
def mac_cormack_2d(
    U : jnp.ndarray, # state vector (..., 3): u, v, p
    dt : float, # delta time
    dx : float, # delta x
    dy : float, # delta y
    nu : float, # viscosity
    pred_scheme : tuple[str, str], # predictor scheme per direction: "f" forward, "b" backward
    corr_scheme : tuple[str, str], # corrector scheme per direction
) -> jnp.ndarray:

    # difference operators used to differentiate the fluxes
    pd_dx = partial(fd, delta=dx, axis=0) if pred_scheme[0] == 'f' else partial(bd, delta=dx, axis=0)
    pd_dy = partial(fd, delta=dy, axis=1) if pred_scheme[1] == 'f' else partial(bd, delta=dy, axis=1)

    cd_dx = partial(fd, delta=dx, axis=0) if corr_scheme[0] == 'f' else partial(bd, delta=dx, axis=0)
    cd_dy = partial(fd, delta=dy, axis=1) if corr_scheme[1] == 'f' else partial(bd, delta=dy, axis=1)

    # predictor: U_bar = U^n - dt * (dE/dx + dF/dy)
    E, F = construct_E_F(U, dt, dx, dy, nu, pred_scheme)
    U_pred = U - dt * pd_dx(E) - dt * pd_dy(F)

    # corrector: fluxes re-evaluated at the predicted state
    E_pred, F_pred = construct_E_F(U_pred, dt, dx, dy, nu, corr_scheme)
    U_fwd = 0.5 * ((U + U_pred) - dt * cd_dx(E_pred) - dt * cd_dy(F_pred))

    return U_fwd
```

## Solve the Poisson equation for the pressure field $p$

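To solve the Poisson equation for the pressure field $p$ we use the Jacobi method. The Jacobi method is iterative: we start from an initial guess for the pressure field (for example the pressure from the previous time step) and repeatedly update every grid point from its four neighbours, either for a fixed number of iterations or until the update becomes small. Writing equation (5) as $\Delta p = b$ with $b = -\rho \nabla \cdot \left((V^* \cdot \nabla) V^*\right)$ and discretizing the Laplacian with the 5-point stencil on a grid with $\Delta x = \Delta y$, one Jacobi update reads:

$$p_{i,j}^{k+1} = \frac{1}{4} \left( p_{i+1,j}^{k} + p_{i-1,j}^{k} + p_{i,j+1}^{k} + p_{i,j-1}^{k} - \Delta x^2 \, b_{i,j} \right)$$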

where $k$ is the iteration number.
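A quick way to check a Jacobi solver is a case with a known answer. For the Taylor-Green vortex $u = \sin x \cos y$, $v = -\cos x \sin y$ on a periodic $2\pi \times 2\pi$ grid, the pressure $p = \frac{\rho}{4}(\cos 2x + \cos 2y)$ satisfies the pressure Poisson equation (3). The self-contained sketch below builds $b$ with central differences, runs the Jacobi update above from a zero initial guess and compares the result with the exact pressure. The helper `jacobi_poisson`, the grid size and the iteration count are assumptions for illustration.

```python
import jax.numpy as jnp
from jax import lax

def d_central(f, h, axis):
    # central difference on a periodic grid
    return (jnp.roll(f, -1, axis=axis) - jnp.roll(f, 1, axis=axis)) / (2 * h)

def jacobi_poisson(b, h, iterations):
    # hypothetical helper: Jacobi iterations for Laplacian(p) = b, 5-point stencil, dx = dy = h
    def step(_, p):
        neighbours = (jnp.roll(p, -1, axis=0) + jnp.roll(p, 1, axis=0)
                      + jnp.roll(p, -1, axis=1) + jnp.roll(p, 1, axis=1))
        return (neighbours - h**2 * b) / 4
    return lax.fori_loop(0, iterations, step, jnp.zeros_like(b))

n, rho = 32, 1.0
h = 2 * jnp.pi / n
x = jnp.arange(n) * h
X, Y = jnp.meshgrid(x, x, indexing="ij")
u = jnp.sin(X) * jnp.cos(Y)
v = -jnp.cos(X) * jnp.sin(Y)

u_x, u_y = d_central(u, h, 0), d_central(u, h, 1)
v_x, v_y = d_central(v, h, 0), d_central(v, h, 1)
b = -rho * (u_x**2 + 2 * u_y * v_x + v_y**2)  # right-hand side of equation (3)

p = jacobi_poisson(b, h, iterations=500)
p_exact = rho / 4 * (jnp.cos(2 * X) + jnp.cos(2 * Y))
print(float(jnp.max(jnp.abs(p - p_exact))))
```

Because the low-wavenumber modes of this solution decay quickly under Jacobi, a few hundred iterations are enough here; fields with larger-scale structure on finer grids need many more.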

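```python
from jax import lax

def jacobi_step(p : jnp.ndarray, b : jnp.ndarray, dx : float, dy : float) -> jnp.ndarray:
    # one Jacobi sweep for the 5-point discretization of  p_xx + p_yy = b  (periodic grid)
    p_x = jnp.roll(p, -1, axis=0) + jnp.roll(p, 1, axis=0)  # p[i+1, j] + p[i-1, j]
    p_y = jnp.roll(p, -1, axis=1) + jnp.roll(p, 1, axis=1)  # p[i, j+1] + p[i, j-1]
    return (dy**2 * p_x + dx**2 * p_y - dx**2 * dy**2 * b) / (2 * (dx**2 + dy**2))

def jacobi(U : jnp.ndarray, dx : float, dy : float, iterations : int, rho : float = 1.0) -> jnp.ndarray:
    u = U[..., 0]
    v = U[..., 1]

    # right-hand side of equation (5), with central differences
    u_x, u_y = cf(u, dx, 0), cf(u, dy, 1)
    v_x, v_y = cf(v, dx, 0), cf(v, dy, 1)
    b = -rho * (u_x * u_x + 2 * u_y * v_x + v_y * v_y)
    # on a periodic grid the Poisson problem is only solvable for a zero-mean source
    b = b - jnp.mean(b)

    # iterate, using the current pressure as initial guess
    p = lax.fori_loop(0, iterations, lambda _, p: jacobi_step(p, b, dx, dy), U[..., 2])

    return U.at[..., 2].set(p)
```

## Bring it all together

```python
def ns_solver_step(U : jnp.ndarray, dt : float, dx : float, dy : float, nu : float, iterations : int, rho : float = 1.0) -> jnp.ndarray:
    # 1. intermediate velocity V*: forward-difference predictor, backward-difference corrector
    U = mac_cormack_2d(U, dt, dx, dy, nu, ("f", "f"), ("b", "b"))
    # 2. pressure from the Poisson equation (5)
    U = jacobi(U, dx, dy, iterations, rho)
    # the periodic pressure is only defined up to a constant: fix it by removing the mean
    p = U[..., 2] - jnp.mean(U[..., 2])
    # 3. projection step: V^{n+1} = V* - dt / rho * grad p
    u = U[..., 0] - dt / rho * cf(p, dx, 0)
    v = U[..., 1] - dt / rho * cf(p, dy, 1)
    return jnp.stack((u, v, p), axis=-1)

def ns_solver(U : jnp.ndarray, dt : float, dx : float, dy : float, nu : float, iterations : int, steps : int, rho : float = 1.0) -> jnp.ndarray:
    U_ns = U
    for _ in range(steps):
        U_ns = ns_solver_step(U_ns, dt, dx, dy, nu, iterations, rho)
    return U_ns
```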