# streamfunction_vorticity_jax_solvers

> Solvers based on the JAX package for the full Navier-Stokes equations in the streamfunction-vorticity form

In [None]:
#| default_exp streamfunction_vorticity_jax

# Imports

In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from uom_project import core

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp

from jax import config
config.update("jax_enable_x64", True)

In [None]:
from uom_project import poisson_solvers, streamfunction_vorticity_newton

from fastcore.test import test_eq, test_close

# JAX solver

In [None]:
#| exporti
@jax.jit
def f_jax(x, Re, U_wall_top):
    """Evaluate the residual of the discretised streamfunction-vorticity system.

    Parameters
    ----------
    x : array of shape (2 * N**2 - 2, 1)
        Stacked unknowns, in this order:
        psi at the (N-1)**2 interior nodes, then the wall vorticities
        w_left, w_right, w_bottom, w_top (N-1 values each), then the vorticity
        at the (N-1)**2 interior nodes. N is recovered from the length of x.
    Re : scalar
        Reynolds number; it multiplies every nonlinear term below.
    U_wall_top : scalar or array of N-1 values
        Velocity of the wall at y = 1 (the notebook calls this function with both).

    Returns
    -------
    Array of shape (2 * N**2 - 2, 1) holding, in order, the Poisson residuals,
    the left / right / bottom / top wall-vorticity residuals and the interior
    vorticity residuals.

    Notes
    -----
    The first array axis runs along x and the second along y (see the
    ``x = 0`` / ``y = 0`` comments on the wall equations). Boundary values of
    psi never appear in the stencils, i.e. they are treated as zero.
    """
    # x has 2 * N**2 - 2 entries, hence x.shape[0] // 2 + 1 == N**2.
    # Array shapes are concrete Python ints while tracing, so NumPy is fine here.
    N = int(np.sqrt(x.shape[0] // 2 + 1))
    h = 1 / N

    # Unpack the column vector; `[:, 0]` drops the trailing singleton axis
    psi = x[:(N-1)**2].reshape(N-1, N-1)
    w_left   = x[(N-1)**2 + 0*(N-1) : (N-1)**2 + 1*(N-1)][:, 0]
    w_right  = x[(N-1)**2 + 1*(N-1) : (N-1)**2 + 2*(N-1)][:, 0]
    w_bottom = x[(N-1)**2 + 2*(N-1) : (N-1)**2 + 3*(N-1)][:, 0]
    w_top    = x[(N-1)**2 + 3*(N-1) : (N-1)**2 + 4*(N-1)][:, 0]
    w_middle = x[(N-1)**2 + 4*(N-1) :].reshape(N-1, N-1)

    # Calculate the equations coming from the Poisson equation
    # (five-point stencil on psi; neighbours on the boundary contribute nothing)
    f_poisson = -4 * psi
    f_poisson = f_poisson.at[:-1, :].add(psi[1:, :])
    f_poisson = f_poisson.at[1:, :].add(psi[:-1, :])
    f_poisson = f_poisson.at[:, :-1].add(psi[:, 1:])
    f_poisson = f_poisson.at[:, 1:].add(psi[:, :-1])
    f_poisson = f_poisson + h ** 2 * w_middle

    # Calculate the sides first
    # Each wall equation ties the wall vorticity to psi and w one node inside.
    # y = 0, U_wall = 0
    f_w_bottom = h ** 2 * (w_middle[:, 0] + 3 * w_bottom) + 8 * psi[:, 0]
    # y = 1, U_wall is known here
    f_w_top = h ** 2 * (w_middle[:, -1] + 3 * w_top) + 8 * (
        h * U_wall_top + psi[:, -1]
    )
    # x = 0
    f_w_left = h ** 2 * (w_middle[0, :] + 3 * w_left) + 8 * psi[0, :]
    # x = 1
    f_w_right = h ** 2 * (w_middle[-1, :] + 3 * w_right) + 8 * psi[-1, :]

    # Five-point stencil on the interior vorticity; here the boundary
    # neighbours are the wall unknowns w_left / w_right / w_bottom / w_top.
    f_w_middle = -4 * w_middle
    f_w_middle = f_w_middle.at[:-1, :].add(w_middle[1:, :])
    f_w_middle = f_w_middle.at[-1:, :].add(w_right)
    f_w_middle = f_w_middle.at[1:, :].add(w_middle[:-1, :])
    f_w_middle = f_w_middle.at[:1, :].add(w_left)
    f_w_middle = f_w_middle.at[:, :-1].add(w_middle[:, 1:])
    f_w_middle = f_w_middle.at[:, -1].add(w_top)
    f_w_middle = f_w_middle.at[:, 1:].add(w_middle[:, :-1])
    f_w_middle = f_w_middle.at[:, 0].add(w_bottom)

    # Nonlinear terms (scaled by Re / 4): products of centred differences of
    # psi and w. Handled region by region because next to a wall one neighbour
    # lies on the boundary: psi there is zero (so the term is dropped) and w
    # there is taken from the corresponding wall unknown.
    # Nodes with all four neighbours in the interior
    f_w_middle = f_w_middle.at[1:-1, 1:-1].add(Re * (
        (psi[2:, 1:-1] - psi[:-2, 1:-1]) * (w_middle[1:-1, 2:] - w_middle[1:-1, :-2]) -
        (psi[1:-1, 2:] - psi[1:-1, :-2]) * (w_middle[2:, 1:-1] - w_middle[:-2, 1:-1])
    ) / 4)
    # Row next to x = 0 (corners excluded)
    f_w_middle = f_w_middle.at[:1, 1:-1].add(Re * (
        psi[1, 1:-1] * (w_middle[0, 2:] - w_middle[0, :-2]) -
        (psi[0, 2:] - psi[0, :-2]) * (w_middle[1, 1:-1] - w_left[1:-1])
    ) / 4)
    # Row next to x = 1 (corners excluded)
    f_w_middle = f_w_middle.at[-1:, 1:-1].add(-Re * (
        psi[-2, 1:-1] * (w_middle[-1, 2:] - w_middle[-1, :-2]) +
        (psi[-1, 2:] - psi[-1, :-2]) * (w_right[1:-1] - w_middle[-2, 1:-1])
    ) / 4)
    # Column next to y = 0 (corners excluded)
    f_w_middle = f_w_middle.at[1:-1, 0].add(Re * (
        (psi[2:, 0] - psi[:-2, 0]) * (w_middle[1:-1, 1] - w_bottom[1:-1]) -
        psi[1:-1, 1] * (w_middle[2:, 0] - w_middle[:-2, 0])
    ) / 4)
    # Column next to y = 1 (corners excluded)
    f_w_middle = f_w_middle.at[1:-1, -1].add(Re * (
        (psi[2:, -1] - psi[:-2, -1]) * (w_top[1:-1] - w_middle[1:-1, -2]) +
        psi[1:-1, -2] * (w_middle[2:, -1] - w_middle[:-2, -1])
    ) / 4)
    # Corner nodes: two neighbours lie on walls
    f_w_middle = f_w_middle.at[0, 0].add(Re * (
        psi[1, 0] * (w_middle[0, 1] - w_bottom[0]) -
        psi[0, 1] * (w_middle[1, 0] - w_left[0])
    ) / 4)
    f_w_middle = f_w_middle.at[-1, 0].add(-Re * (
        psi[-2, 0] * (w_middle[-1, 1] - w_bottom[-1]) +
        psi[-1, 1] * (w_right[0] - w_middle[-2, 0])
    ) / 4)
    f_w_middle = f_w_middle.at[0, -1].add(Re * (
        psi[1, -1] * (w_top[0] - w_middle[0, -2]) +
        psi[0, -2] * (w_middle[1, -1] - w_left[-1])
    ) / 4)
    f_w_middle = f_w_middle.at[-1, -1].add(-Re * (
        psi[-2, -1] * (w_top[-1] - w_middle[-1, -2]) -
        psi[-1, -2] * (w_right[-1] - w_middle[-2, -1])
    ) / 4)

    # Stack the residuals in the same order as the unknowns and restore the
    # column-vector shape
    return jnp.concatenate([
        f_poisson.flatten(), f_w_left, f_w_right, f_w_bottom, f_w_top,
        f_w_middle.flatten(),
    ], axis=0)[:, None]


In [None]:
N = 40
size = (N - 1) ** 2 + (N + 1) ** 2 - 4
x = np.random.randn(size, 1)
Re = 10
kernel_matrix = poisson_solvers.construct_laplacian_kernel_matrix(N=N-1, h=1)
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

res_1 = streamfunction_vorticity_newton.f(x, Re, U_wall_top, kernel_matrix)


%time x_jax = jax.device_put(x).astype(jnp.float64)
%time U_wall_top_jax = jax.device_put(U_wall_top)
%time res_2 = f_jax(x_jax, Re, U_wall_top_jax).block_until_ready() # Compile once

# Check MSE
test_eq(np.allclose(res_1, res_2), True)
test_close(res_1, res_2, eps=1e-8)

CPU times: user 11.6 ms, sys: 5.17 ms, total: 16.8 ms
Wall time: 25.4 ms
CPU times: user 486 µs, sys: 56 µs, total: 542 µs
Wall time: 1.87 ms
CPU times: user 668 ms, sys: 30.7 ms, total: 699 ms
Wall time: 1.03 s


In [None]:
#| eval: false
%timeit streamfunction_vorticity_newton.f(x, Re, U_wall_top, kernel_matrix)

221 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
#| eval: false
%timeit f_jax(x_jax, Re, U_wall_top_jax).block_until_ready()

66.6 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
res_1.shape, res_2.shape

((3198, 1), (3198, 1))

In [None]:
#| eval: false
# Forward-mode Jacobian of the residual with respect to x (the first argument).
# f_jax maps a (size, 1) array to a (size, 1) array, so the raw Jacobian
# carries two singleton axes, which .squeeze() removes.
jac = jax.jit(jax.jacfwd(f_jax))
%time jac(x_jax, Re, U_wall_top_jax).squeeze().block_until_ready(); # Compile once

CPU times: user 1.27 s, sys: 180 ms, total: 1.45 s
Wall time: 2.09 s


In [None]:
#| eval: false
%timeit jac(x_jax, Re, U_wall_top_jax).squeeze().block_until_ready()

376 ms ± 76.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# PyTorch forward-mode Jacobian, set up so it can be timed in the next cell.
import torch
from torch.func import jacfwd
from uom_project import streamfunction_vorticity_pytorch

# NOTE: this rebinds `x` and `U_wall_top` to torch tensors; `size` and `N`
# are the values left over from the earlier residual-comparison cell.
x = torch.randn(size, 1).double()
U_wall_top = torch.sin(torch.pi * torch.arange(1, N) / N) ** 2
print(kernel_matrix.shape)

py_jac = jacfwd(streamfunction_vorticity_pytorch.f_pytorch)

(1521, 1521)


In [None]:
#| eval: false
%timeit py_jac(x, Re, U_wall_top).squeeze()

572 ms ± 86.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
#| exporti

def reconstruct_w_jax(w_tmp, N):
    """Scatter the stacked vorticity unknowns onto an (N+1, N+1) grid.

    `w_tmp` is a column vector holding four wall segments of N-1 values each,
    followed by the (N-1)**2 interior values. The first two segments fill the
    first and last rows, the next two fill the first and last columns, and the
    remainder fills the interior. The four corner entries are left at zero.
    """
    n_inner = N - 1

    # Four consecutive wall segments, each a column slice of length N - 1
    first_row, last_row, first_col, last_col = (
        w_tmp[k * n_inner:(k + 1) * n_inner] for k in range(4)
    )
    interior = w_tmp[4 * n_inner:].reshape((n_inner, n_inner))

    grid = jnp.zeros((N + 1, N + 1))
    # Row targets are (1, N-1) blocks, hence the transpose of the column slices
    grid = grid.at[:1, 1:-1].set(first_row.T)
    grid = grid.at[-1:, 1:-1].set(last_row.T)
    # Column targets are (N-1, 1) blocks and take the slices as they are
    grid = grid.at[1:-1, :1].set(first_col)
    grid = grid.at[1:-1, -1:].set(last_col)
    grid = grid.at[1:-1, 1:-1].set(interior)

    return grid


@partial(jax.jit, static_argnums=(0, 1, 6,))
def newton_iteration_step(
    f, get_jacobian, x, Re, U_wall_top, f_current, algorithm,
):
    """Perform one jitted Newton update of the state `x`.

    - f: residual function, called as f(x=..., Re=..., U_wall_top=...)
    - get_jacobian: Jacobian function, called as get_jacobian(x, Re, U_wall_top)
    - x: current state
    - Re: Reynolds number
    - U_wall_top: velocity of the top wall
    - f_current: residual evaluated at `x`
    - algorithm: forwarded to core.solve_sparse_linear_system_jax

    `f`, `get_jacobian` and `algorithm` are static arguments of the jit, so
    they must be hashable.

    Returns the updated state, the Newton increment, the residual at the
    updated state and the norm of that residual.
    """
    # Solve J dx = -f for the Newton increment
    newton_matrix = get_jacobian(x, Re, U_wall_top).squeeze()
    dx = core.solve_sparse_linear_system_jax(
        A=newton_matrix, b=-f_current, algorithm=algorithm,
    )

    x_next = x + dx
    f_next = f(x=x_next, Re=Re, U_wall_top=U_wall_top)

    return x_next, dx, f_next, jnp.linalg.norm(f_next)


def newton_iterator_jax(
    f, get_jacobian, N, Re, U_wall_top,
    algorithm="base", TOL=1e-8, max_iter=10, quiet=True
):
    '''
        Run Newton iterations from a zero initial guess.

        - f: evaluates the residual, called as f(x=..., Re=..., U_wall_top=...)
        - get_jacobian: evaluates the Jacobian given x, Re, U_wall_top
        - N: grid parameter; the state has (N - 1) ** 2 + (N + 1) ** 2 - 4 entries
        - Re: Reynolds number
        - U_wall_top: velocity of the top wall
        - algorithm: linear-solver choice forwarded to newton_iteration_step
        - TOL: the iteration stops once the residual norm is <= TOL
        - max_iter: maximum number of Newton steps
        - quiet: when False, print the progress of every iteration

        Returns the final state (a column vector) and the number of Newton
        steps taken. The last iterate is returned even if TOL was not reached.
    '''
    verbose = not quiet

    # Size (N - 1) ** 2         + (N + 1) ** 2    - 4
    # Size (for streamfunction) + (for vorticity) - (corners of vorticity)
    n_unknowns = (N - 1) ** 2 + (N + 1) ** 2 - 4
    x = jnp.zeros((n_unknowns, 1), dtype=jnp.float64)

    f_current = f(x=x, Re=Re, U_wall_top=U_wall_top)
    f_norm = jnp.linalg.norm(f_current)

    n_iter = 0
    # The initial guess may already satisfy the tolerance; then no step is taken
    converged = f_norm <= TOL
    while not converged and n_iter < max_iter:
        n_iter += 1

        x, dx, f_current, f_norm = newton_iteration_step(
            f=f, get_jacobian=get_jacobian,
            x=x, Re=Re, U_wall_top=U_wall_top,
            f_current=f_current,
            algorithm=algorithm,
        )

        if verbose:
            print(f"iter={n_iter}; residual={f_norm}; dx={jnp.linalg.norm(dx)}")
        converged = f_norm <= TOL

    if verbose:
        print(f"n_iter={n_iter}")

    return x, n_iter


In [None]:
#| export
def newton_solver_jax(
    f, get_jacobian, N, Re, U_wall_top,
    algorithm="base", TOL=1e-8, max_iter=10, quiet=True
):
    """Solve the discretised system with Newton's method and return grid fields.

    All arguments are forwarded unchanged to `newton_iterator_jax`.

    Returns `(w, psi, n_iter)`: the vorticity and the streamfunction, both as
    (N+1, N+1) arrays, and the number of Newton steps taken. The boundary
    values of `psi` and the four corner values of `w` are zero.
    """
    solution, n_iter = newton_iterator_jax(
        f=f, get_jacobian=get_jacobian, N=N, Re=Re, U_wall_top=U_wall_top,
        algorithm=algorithm,
        TOL=TOL, max_iter=max_iter, quiet=quiet
    )

    # The first (N - 1) ** 2 unknowns are the interior streamfunction values
    n_psi = (N - 1) ** 2

    # Surround the interior streamfunction with a one-node frame of zeros
    psi_interior = solution[:n_psi].reshape(N - 1, N - 1)
    psi = jnp.pad(psi_interior, 1, mode="constant", constant_values=0)

    # The remaining unknowns are the wall and interior vorticity values
    w = reconstruct_w_jax(w_tmp=solution[n_psi:], N=N)
    w = w.reshape(N + 1, N + 1)

    return w, psi, n_iter


In [None]:
# Re = 0 removes every Re-weighted (nonlinear) term from the residual, so the
# system is linear and Newton is expected to converge in a single step.
N = 20
Re = 0 # i.e. viscosity mu = inf
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem through the JAX solver with a forward-mode Jacobian
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, n_iter = newton_solver_jax(
    f=f_jax,
    get_jacobian=jax.jit(jax.jacfwd(f_jax)),
    U_wall_top=U_wall_top_jax,
    N=N, Re=Re, quiet=False,
    algorithm="lineax_base",
)
# Both solvers must agree on the fields, and one iteration must be enough
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(n_iter, 1)

iter=1; residual=2.914118169851718e-14; dx=39.71450891546215
n_iter=1
iter=1; residual=1.5969150036293684e-14; dx=39.71450891546271
n_iter=1


In [None]:
# Re = 10 with a sin^2 velocity profile on the top wall: the nonlinear terms
# are active, so several Newton steps are needed.
N = 20
Re = 10
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem through the JAX solver with a forward-mode Jacobian
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, n_iter = newton_solver_jax(
    f=f_jax,
    get_jacobian=jax.jit(jax.jacfwd(f_jax)),
    U_wall_top=U_wall_top_jax,
    N=N, Re=Re, quiet=False,
    algorithm="lineax_base",
)
# Both solvers must agree on the fields; the JAX solver must take 3 steps
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(n_iter, 3)

iter=1; residual=1.083638166644002; dx=39.71450891546215
iter=2; residual=0.001540492931316756; dx=2.273926820301316
iter=3; residual=8.799678364852772e-10; dx=0.001464259783746122
n_iter=3
iter=1; residual=1.0836381666440216; dx=39.71450891546271
iter=2; residual=0.0015404929313102245; dx=2.2739268203012495
iter=3; residual=8.799683858705381e-10; dx=0.0014642597837290449
n_iter=3


In [None]:
# Re = 10 with a scalar top-wall velocity, i.e. the same value at every node
# of the top wall.
N = 20
Re = 10
U_wall_top = 1.0

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem through the JAX solver with a forward-mode Jacobian
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, n_iter = newton_solver_jax(
    f=f_jax,
    get_jacobian=jax.jit(jax.jacfwd(f_jax)),
    U_wall_top=U_wall_top_jax,
    N=N, Re=Re, quiet=False,
    algorithm="lineax_base",
)
# Both solvers must agree on the fields; the JAX solver must take 3 steps
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(n_iter, 3)

iter=1; residual=5.3021836249034475; dx=92.2537418024731
iter=2; residual=0.006521635545184045; dx=4.1175639110773234
iter=3; residual=3.950960452092918e-09; dx=0.003689348557317348
n_iter=3
iter=1; residual=5.302183624903549; dx=92.25374180247387
iter=2; residual=0.006521635545207267; dx=4.117563911077535
iter=3; residual=3.950963700483887e-09; dx=0.0036893485572719108
n_iter=3


In [None]:
# Shared setup for the solver timings in the following cells.
N = 40
Re = 10
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2
U_wall_top_jax = jax.device_put(U_wall_top)

# NumPy residual with the wall velocity bound in advance
fun = partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top)

# NOTE: `jac` is rebound here to the NumPy Jacobian; the JAX one is `jac_jax`
jac = streamfunction_vorticity_newton.get_jacobian
jac_jax = jax.jit(jax.jacfwd(f_jax))

# One timed run per linear-solver algorithm before the %timeit cells
%time newton_solver_jax(f=f_jax, get_jacobian=jac_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="base");
%time newton_solver_jax(f=f_jax, get_jacobian=jac_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="lineax_base");

CPU times: user 4.91 s, sys: 429 ms, total: 5.34 s
Wall time: 5.14 s
CPU times: user 4.22 s, sys: 362 ms, total: 4.58 s
Wall time: 5.28 s


In [None]:
#| eval: false
%timeit streamfunction_vorticity_newton.newton_solver(f=fun, get_jacobian=jac, N=N, Re=Re)

4.26 s ± 851 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
#| eval: false
%timeit newton_solver_jax(f=f_jax, get_jacobian=jac_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="base")

2.26 s ± 1.05 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
#| eval: false
%timeit newton_solver_jax(f=f_jax, get_jacobian=jac_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="lineax_base")

1.71 s ± 215 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Export

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()