# streamfunction_vorticity_jax_solvers

> Solvers based on the JAX package for the full Navier-Stokes equations in the streamfunction-vorticity form

In [None]:
#| default_exp streamfunction_vorticity_jax

# Imports

In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from uom_project import core

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True) # enable JAX to use double precision


In [None]:
from uom_project import streamfunction_vorticity_newton

import optimistix as optx

from fastcore.test import test_eq, test_close

# JAX Newton solver

In [None]:
#| exporti
@jax.jit
def f_jax(x, Re, U_wall_top):
    """Evaluate the discrete streamfunction-vorticity residual for the unknown vector `x`.

    Layout of `x` (length 2 * N**2 - 2, with N recovered from that length):
        - (N-1)**2 interior streamfunction values, reshaped to (N-1, N-1)
        - N-1 wall vorticity values for each of: left, right, bottom, top
        - (N-1)**2 interior vorticity values, reshaped to (N-1, N-1)
    Axis 0 of the 2-D blocks runs along x (left -> right) and axis 1 along y
    (bottom -> top), as used by the wall equations below.

    Args:
        x: 1-D array of unknowns in the layout above.
        Re: Reynolds number; it multiplies the nonlinear term only.
        U_wall_top: velocity of the wall at y = 1. It is added to an array of
            length N-1, so it must broadcast against that (the notebook passes
            either a scalar or a length N-1 array).

    Returns:
        1-D array with the same length and block ordering as `x`:
        Poisson residuals, then the left/right/bottom/top wall residuals,
        then the interior vorticity residuals.
    """
    # x.shape is a static Python tuple while tracing, so N is a plain int here
    N = int(np.sqrt(x.shape[0] // 2 + 1))
    h = 1 / N

    psi = x[:(N-1)**2].reshape(N-1, N-1)
    w_left   = x[(N-1)**2 + 0*(N-1) : (N-1)**2 + 1*(N-1)]
    w_right  = x[(N-1)**2 + 1*(N-1) : (N-1)**2 + 2*(N-1)]
    w_bottom = x[(N-1)**2 + 2*(N-1) : (N-1)**2 + 3*(N-1)]
    w_top    = x[(N-1)**2 + 3*(N-1) : (N-1)**2 + 4*(N-1)]
    w_middle = x[(N-1)**2 + 4*(N-1) :].reshape(N-1, N-1)

    # Calculate the equations coming from the Poisson equation
    # Five-point stencil on psi: -4 * centre + four neighbours. Neighbours that
    # would lie outside the interior block are simply not added (psi = 0 there).
    f_poisson = -4 * psi
    f_poisson = f_poisson.at[:-1, :].add(psi[1:, :])
    f_poisson = f_poisson.at[1:, :].add(psi[:-1, :])
    f_poisson = f_poisson.at[:, :-1].add(psi[:, 1:])
    f_poisson = f_poisson.at[:, 1:].add(psi[:, :-1])
    f_poisson = f_poisson + h ** 2 * w_middle

    # Calculate the sides first
    # Each wall equation couples the wall vorticity with the vorticity and
    # streamfunction of the interior line adjacent to that wall.
    # y = 0, U_wall = 0
    f_w_bottom = h ** 2 * (w_middle[:, 0] + 3 * w_bottom) + 8 * psi[:, 0]
    # y = 1, U_wall is known here
    f_w_top = h ** 2 * (w_middle[:, -1] + 3 * w_top) + 8 * (
        h * U_wall_top + psi[:, -1]
    )
    # x = 0
    f_w_left = h ** 2 * (w_middle[0, :] + 3 * w_left) + 8 * psi[0, :]
    # x = 1
    f_w_right = h ** 2 * (w_middle[-1, :] + 3 * w_right) + 8 * psi[-1, :]

    # Five-point stencil on the vorticity. Unlike psi, the outer neighbours of
    # the edge rows/columns are the (unknown) wall vorticity values.
    f_w_middle = -4 * w_middle
    f_w_middle = f_w_middle.at[:-1, :].add(w_middle[1:, :])
    f_w_middle = f_w_middle.at[-1:, :].add(w_right)
    f_w_middle = f_w_middle.at[1:, :].add(w_middle[:-1, :])
    f_w_middle = f_w_middle.at[:1, :].add(w_left)
    f_w_middle = f_w_middle.at[:, :-1].add(w_middle[:, 1:])
    f_w_middle = f_w_middle.at[:, -1].add(w_top)
    f_w_middle = f_w_middle.at[:, 1:].add(w_middle[:, :-1])
    f_w_middle = f_w_middle.at[:, 0].add(w_bottom)

    # Nonlinear term, Re / 4 * (d_x psi * d_y w - d_y psi * d_x w), built from
    # un-normalised central differences. The nine blocks below cover the
    # interior, the four edges and the four corners of the interior grid; on
    # edges and corners the missing psi neighbour is the zero wall value and
    # the missing w neighbour is the corresponding wall vorticity.

    # Points with all four neighbours inside the interior block
    f_w_middle = f_w_middle.at[1:-1, 1:-1].add(Re * (
        (psi[2:, 1:-1] - psi[:-2, 1:-1]) * (w_middle[1:-1, 2:] - w_middle[1:-1, :-2]) -
        (psi[1:-1, 2:] - psi[1:-1, :-2]) * (w_middle[2:, 1:-1] - w_middle[:-2, 1:-1])
    ) / 4)
    # Line next to x = 0 (corners excluded)
    f_w_middle = f_w_middle.at[:1, 1:-1].add(Re * (
        psi[1, 1:-1] * (w_middle[0, 2:] - w_middle[0, :-2]) -
        (psi[0, 2:] - psi[0, :-2]) * (w_middle[1, 1:-1] - w_left[1:-1])
    ) / 4)
    # Line next to x = 1 (corners excluded)
    f_w_middle = f_w_middle.at[-1:, 1:-1].add(-Re * (
        psi[-2, 1:-1] * (w_middle[-1, 2:] - w_middle[-1, :-2]) +
        (psi[-1, 2:] - psi[-1, :-2]) * (w_right[1:-1] - w_middle[-2, 1:-1])
    ) / 4)
    # Line next to y = 0 (corners excluded)
    f_w_middle = f_w_middle.at[1:-1, 0].add(Re * (
        (psi[2:, 0] - psi[:-2, 0]) * (w_middle[1:-1, 1] - w_bottom[1:-1]) -
        psi[1:-1, 1] * (w_middle[2:, 0] - w_middle[:-2, 0])
    ) / 4)
    # Line next to y = 1 (corners excluded)
    f_w_middle = f_w_middle.at[1:-1, -1].add(Re * (
        (psi[2:, -1] - psi[:-2, -1]) * (w_top[1:-1] - w_middle[1:-1, -2]) +
        psi[1:-1, -2] * (w_middle[2:, -1] - w_middle[:-2, -1])
    ) / 4)
    # Corner next to x = 0, y = 0
    f_w_middle = f_w_middle.at[0, 0].add(Re * (
        psi[1, 0] * (w_middle[0, 1] - w_bottom[0]) -
        psi[0, 1] * (w_middle[1, 0] - w_left[0])
    ) / 4)
    # Corner next to x = 1, y = 0
    f_w_middle = f_w_middle.at[-1, 0].add(-Re * (
        psi[-2, 0] * (w_middle[-1, 1] - w_bottom[-1]) +
        psi[-1, 1] * (w_right[0] - w_middle[-2, 0])
    ) / 4)
    # Corner next to x = 0, y = 1
    f_w_middle = f_w_middle.at[0, -1].add(Re * (
        psi[1, -1] * (w_top[0] - w_middle[0, -2]) +
        psi[0, -2] * (w_middle[1, -1] - w_left[-1])
    ) / 4)
    # Corner next to x = 1, y = 1
    f_w_middle = f_w_middle.at[-1, -1].add(-Re * (
        psi[-2, -1] * (w_top[-1] - w_middle[-1, -2]) -
        psi[-1, -2] * (w_right[-1] - w_middle[-2, -1])
    ) / 4)

    # Same block ordering as the input vector
    return jnp.concatenate([
        f_poisson.flatten(), f_w_left, f_w_right, f_w_bottom, f_w_top,
        f_w_middle.flatten(),
    ], axis=0)


In [None]:
N = 40
size = (N - 1) ** 2 + (N + 1) ** 2 - 4
x = np.random.randn(size)
Re = 10
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

res_1 = streamfunction_vorticity_newton.f(x, Re, U_wall_top)


%time x_jax = jax.device_put(x).astype(jnp.float64)
%time U_wall_top_jax = jax.device_put(U_wall_top)
%time res_2 = f_jax(x_jax, Re, U_wall_top_jax).block_until_ready() # Compile once

# Check MSE
test_eq(np.allclose(res_1, res_2), True)
test_close(res_1, res_2, eps=1e-8)

CPU times: user 480 µs, sys: 66 µs, total: 546 µs
Wall time: 2.34 ms
CPU times: user 515 µs, sys: 58 µs, total: 573 µs
Wall time: 6.65 ms
CPU times: user 142 µs, sys: 11 µs, total: 153 µs
Wall time: 383 µs


In [None]:
#| eval: false
%timeit streamfunction_vorticity_newton.f(x, Re, U_wall_top)

271 µs ± 153 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
#| eval: false
%timeit f_jax(x_jax, Re, U_wall_top_jax).block_until_ready()

48.1 µs ± 9.85 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
res_1.shape, res_2.shape

((3198,), (3198,))

In [None]:
#| eval: false
# Forward-mode Jacobian of f_jax with respect to its first argument (x),
# wrapped in jit. The result is a dense (len(x), len(x)) array.
jac = jax.jit(jax.jacfwd(f_jax))
%time jacobian = jac(x_jax, Re, U_wall_top_jax).block_until_ready(); # Compile once

# Display the Jacobian dimensions as the cell output
jacobian.shape

CPU times: user 1.11 s, sys: 152 ms, total: 1.26 s
Wall time: 1.54 s


(3198, 3198)

In [None]:
#| eval: false
%timeit -r 30 jac(x_jax, Re, U_wall_top_jax).block_until_ready()

271 ms ± 70.4 ms per loop (mean ± std. dev. of 30 runs, 1 loop each)


In [None]:
#| eval: false
# PyTorch counterpart of the Jacobian benchmark. The imports live in this
# non-evaluated cell rather than with the other imports at the top.
import torch
from torch.func import jacfwd
from uom_project import streamfunction_vorticity_pytorch

# NOTE: this rebinds `x` and `U_wall_top`, which held NumPy arrays in the
# earlier comparison cell; `size` and `N` are reused from that cell.
x = torch.randn(size, 1).double()
# NOTE(review): no .double() here, so this is presumably float32 while x is
# float64 - confirm f_pytorch is fine with the mixed precision.
U_wall_top = torch.sin(torch.pi * torch.arange(1, N) / N) ** 2

# Forward-mode Jacobian of the PyTorch residual, timed in the next cell
py_jac = jacfwd(streamfunction_vorticity_pytorch.f_pytorch)

In [None]:
#| eval: false
%timeit py_jac(x, Re, U_wall_top).squeeze()

652 ms ± 99.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
#| exporti

def reconstruct_w_jax(w_tmp, N):
    """Scatter the flat vorticity unknowns onto an (N+1, N+1) grid.

    `w_tmp` is read as four consecutive chunks of length N-1 followed by
    (N-1)**2 interior values. The chunks follow the wall ordering used in
    f_jax's unknown vector (left, right, bottom, top) and go to the first row,
    last row, first column and last column of the grid, corners excluded.
    The four corner entries are never written and stay zero.

    Args:
        w_tmp: 1-D array of length 4 * (N-1) + (N-1)**2.
        N: grid parameter; the returned grid has N+1 points per side.

    Returns:
        Array of shape (N+1, N+1).
    """
    edge = N - 1

    # Cut the flat vector into its five pieces up front
    first_row, last_row, first_col, last_col = (
        w_tmp[k * edge:(k + 1) * edge] for k in range(4)
    )
    interior = w_tmp[4 * edge:].reshape((edge, edge))

    # The five target regions are disjoint, so the write order is irrelevant
    grid = jnp.zeros((N + 1, N + 1))
    grid = grid.at[1:-1, 1:-1].set(interior)
    grid = grid.at[0, 1:-1].set(first_row)
    grid = grid.at[-1, 1:-1].set(last_row)
    grid = grid.at[1:-1, 0].set(first_col)
    grid = grid.at[1:-1, -1].set(last_col)

    return grid


@partial(jax.jit, static_argnums=(0, 5,))
def newton_iteration_step(
    f, x, Re, U_wall_top, f_current, algorithm="lineax_lu",
):
    """Perform a single Newton update of the iterate `x`.

    The function is jitted with `f` and `algorithm` marked as static
    (argument positions 0 and 5).

    Args:
        f: residual function, called as f(x, Re, U_wall_top).
        x: current iterate.
        Re: Reynolds number, forwarded to `f`.
        U_wall_top: top-wall velocity, forwarded to `f`.
        f_current: residual already evaluated at `x`.
        algorithm: name of the linear solver, forwarded to
            core.solve_sparse_linear_system_jax.

    Returns:
        Tuple (x_new, dx, f_new, f_norm): the updated iterate, the applied
        step, the residual at the updated iterate and the norm of that
        residual as computed by jnp.linalg.norm.
    """
    # Jacobian of the residual with respect to x, by forward-mode autodiff
    jacobian = jax.jacfwd(f)(x, Re, U_wall_top)

    # Newton step: solve J dx = -f
    dx = core.solve_sparse_linear_system_jax(
        A=jacobian,
        b=-f_current,
        algorithm=algorithm,
    )

    x_new = x + dx
    f_new = f(x=x_new, Re=Re, U_wall_top=U_wall_top)

    return x_new, dx, f_new, jnp.linalg.norm(f_new)


def newton_iterator_jax(
    f, N, Re, U_wall_top,
    algorithm="lineax_lu", TOL=1e-8, max_iter=10, quiet=True
):
    '''
    Run Newton iterations from a zero initial guess.

        - f: residual function, called as f(x=..., Re=..., U_wall_top=...)
        - N: grid parameter; the unknown vector has
             (N - 1) ** 2 + (N + 1) ** 2 - 4 entries
        - Re: Reynolds number, forwarded to f
        - U_wall_top: top-wall velocity, forwarded to f
        - algorithm: linear-solver name forwarded to newton_iteration_step
        - TOL: stop once the residual norm is <= TOL
        - max_iter: maximum number of Newton steps
        - quiet: if False, print progress for every step

    Returns (x, n_iter): the final iterate and the number of steps taken.
    When max_iter is reached without meeting TOL the last iterate is
    returned as-is; no error is raised.
    '''

    # Initial guess: all zeros.
    # Size (N - 1) ** 2         + (N + 1) ** 2    - 4
    # Size (for streamfunction) + (for vorticity) - (corners of vorticity)
    n_unknowns = (N - 1) ** 2 + (N + 1) ** 2 - 4
    x = jnp.zeros((n_unknowns, ), dtype=jnp.float64)
    f_current = f(x=x, Re=Re, U_wall_top=U_wall_top)

    # The initial guess may already satisfy the tolerance, in which case the
    # loop below is skipped entirely.
    converged = bool(jnp.linalg.norm(f_current) <= TOL)

    n_iter = 0 # number of iterations
    while not converged and n_iter < max_iter:
        n_iter += 1

        x, dx, f_current, f_norm = newton_iteration_step(
            f=f, x=x, Re=Re, U_wall_top=U_wall_top,
            f_current=f_current,
            algorithm=algorithm,
        )

        if not quiet:
            print(f"iter={n_iter}; residual={f_norm}; dx={jnp.linalg.norm(dx)}")

        converged = bool(f_norm <= TOL)

    if not quiet:
        print(f"n_iter={n_iter}")

    return x, n_iter


In [None]:
#| export
def newton_solver_jax(
    f, N, Re, U_wall_top,
    algorithm="lineax_lu", TOL=1e-8, max_iter=10, quiet=True
):
    """Solve the system with Newton's method and return the fields on the grid.

    Args:
        f: residual function, called as f(x=..., Re=..., U_wall_top=...).
        N: grid parameter; the returned fields have N+1 points per side.
        Re: Reynolds number.
        U_wall_top: velocity of the top wall.
        algorithm: linear-solver name used inside each Newton step.
        TOL: tolerance on the residual norm.
        max_iter: maximum number of Newton steps.
        quiet: if False, print the progress of the iteration.

    Returns:
        Tuple (w, psi, n_iter): vorticity and streamfunction as (N+1, N+1)
        arrays, and the number of Newton steps performed.
    """
    n_psi = (N - 1) ** 2

    solution, n_iter = newton_iterator_jax(
        f=f, N=N, Re=Re, U_wall_top=U_wall_top,
        algorithm=algorithm,
        TOL=TOL, max_iter=max_iter, quiet=quiet,
    )

    # Streamfunction: interior unknowns surrounded by a ring of zeros
    psi_interior = solution[:n_psi].reshape(N - 1, N - 1)
    psi = jnp.pad(psi_interior, (1, 1), mode="constant", constant_values=0)

    # Vorticity: the remaining unknowns, placed back on the grid
    w = reconstruct_w_jax(w_tmp=solution[n_psi:], N=N)
    w = w.reshape(N + 1, N + 1)

    return w, psi, n_iter


In [None]:
# Regression check at Re = 0: every nonlinear term in f_jax is multiplied by
# Re, so the residual is linear in the unknowns and one Newton step suffices.
N = 20
Re = 0.0 # i.e. viscosity mu = inf
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem with the JAX solver
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, n_iter = newton_solver_jax(
    f=f_jax,
    U_wall_top=U_wall_top_jax,
    N=N, Re=Re, quiet=False,
    algorithm="lineax_lu",
)
# Both fields must match the reference, and exactly one step must be taken
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(n_iter, 1)

iter=1; residual=1.8164849050301797e-14; dx=39.71450891546252
n_iter=1
iter=1; residual=1.5969150036293684e-14; dx=39.71450891546271
n_iter=1


In [None]:
# Regression check at Re = 10 with a sin^2 velocity profile on the top wall
N = 20
Re = 10.0
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem with the JAX solver
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, n_iter = newton_solver_jax(
    f=f_jax,
    U_wall_top=U_wall_top_jax,
    N=N, Re=Re, quiet=False,
    algorithm="lineax_lu",
)
# Both fields must match the reference; three Newton steps are expected
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(n_iter, 3)

iter=1; residual=1.0836381666439978; dx=39.71450891546252
iter=2; residual=0.0015404929313136678; dx=2.273926820301393
iter=3; residual=8.799686196636605e-10; dx=0.001464259783738233
n_iter=3
iter=1; residual=1.0836381666440216; dx=39.71450891546271
iter=2; residual=0.0015404929313102245; dx=2.2739268203012495
iter=3; residual=8.799683858705381e-10; dx=0.0014642597837290449
n_iter=3


In [None]:
# Regression check at Re = 10 with the top-wall velocity given as a single
# scalar instead of a length N-1 array
N = 20
Re = 10.0
U_wall_top = 1.0

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem with the JAX solver
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, n_iter = newton_solver_jax(
    f=f_jax,
    U_wall_top=U_wall_top_jax,
    N=N, Re=Re, quiet=False,
    algorithm="lineax_lu",
)
# Both fields must match the reference; three Newton steps are expected
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(n_iter, 3)

iter=1; residual=5.302183624903474; dx=92.25374180247326
iter=2; residual=0.006521635545205214; dx=4.11756391107747
iter=3; residual=3.950959477471047e-09; dx=0.0036893485573113297
n_iter=3
iter=1; residual=5.302183624903549; dx=92.25374180247387
iter=2; residual=0.006521635545207267; dx=4.117563911077535
iter=3; residual=3.950963700483887e-09; dx=0.0036893485572719108
n_iter=3


In [None]:
#| eval: false
# Benchmark setup shared by the timing cells that follow
N = 40
Re = 10.0
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2
U_wall_top_jax = jax.device_put(U_wall_top)

# Residual and Jacobian callables for the reference NumPy solver
fun = partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top)

# NOTE: this rebinds `jac`, which an earlier cell bound to the jitted JAX Jacobian
jac = streamfunction_vorticity_newton.get_jacobian

# NOTE: the lineax version compiles about 2x faster than the jax version
%time newton_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="jax_base");
%time newton_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="lineax_lu");

CPU times: user 4.83 s, sys: 438 ms, total: 5.27 s
Wall time: 4.42 s
CPU times: user 3.78 s, sys: 354 ms, total: 4.14 s
Wall time: 2.88 s


In [None]:
#| eval: false
%timeit streamfunction_vorticity_newton.newton_solver(f=fun, get_jacobian=jac, N=N, Re=Re)

3.4 s ± 679 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
#| eval: false
%timeit -r 30 newton_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="jax_base")
# %timeit newton_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="jax_base")

The slowest run took 4.07 times longer than the fastest. This could mean that an intermediate result is being cached.
1.95 s ± 786 ms per loop (mean ± std. dev. of 30 runs, 1 loop each)


On a GPU:

520 ms ± 2.56 ms per loop (mean ± std. dev. of 30 runs, 1 loop each)

In [None]:
#| eval: false
%timeit -r 30 newton_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="lineax_lu");
# %timeit newton_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="lineax_lu");

2.25 s ± 668 ms per loop (mean ± std. dev. of 30 runs, 1 loop each)


On a GPU:

525 ms ± 2.81 ms per loop (mean ± std. dev. of 30 runs, 1 loop each)

# JAX root solver

In [None]:
# NOTE: other algorithms are available in optimistix, but even Newton is
# slower than my baseline, so I did not bother checking the others
JAX_ROOT_SOLVING_ALGORITHMS = {
    "newton": optx.Newton,
}

def root_solver_jax(
    f, N, Re, U_wall_top, rtol=1e-8, atol=1e-8, max_steps=10, algorithm="newton",
):
    """Solve the system with an optimistix root finder, starting from zeros.

    Args:
        f: residual function, called as f(x, Re, U_wall_top).
        N: grid parameter; the returned fields have N+1 points per side.
        Re: Reynolds number.
        U_wall_top: velocity of the top wall.
        rtol: relative tolerance passed to the optimistix solver.
        atol: absolute tolerance passed to the optimistix solver.
        max_steps: maximum number of solver steps passed to optx.root_find.
        algorithm: key into JAX_ROOT_SOLVING_ALGORITHMS.

    Returns:
        Tuple (w, psi, solution): vorticity and streamfunction as (N+1, N+1)
        arrays, and the object returned by optx.root_find.

    Raises:
        KeyError: if `algorithm` is not a key of JAX_ROOT_SOLVING_ALGORITHMS.
    """
    solver = JAX_ROOT_SOLVING_ALGORITHMS[algorithm](atol=atol, rtol=rtol)

    # optimistix calls fn(y, args); unpack args into f's positional signature
    def residual(y, args):
        return f(y, *args)

    n_psi = (N - 1) ** 2

    solution = optx.root_find(
        fn=residual,
        args=(Re, U_wall_top),
        y0=jnp.zeros(n_psi + (N + 1) ** 2 - 4, dtype=jnp.float64),
        solver=solver,
        max_steps=max_steps,
        options={
            # Use the public norm functions (optx.two_norm / optx.max_norm)
            # instead of reaching into the private optx._misc module.
            # NOTE(review): optx.Newton also takes a `norm` constructor
            # argument - confirm that this "norm" option is actually honoured.
            "norm": optx.two_norm,
            # "norm": optx.max_norm,
        }
    )

    psi, w = solution.value[:n_psi], solution.value[n_psi:]

    # Get final psi: interior unknowns surrounded by a ring of zeros
    psi = psi.reshape(N - 1, N - 1)
    psi = jnp.pad(psi, (1, 1), mode="constant", constant_values=0)

    # Get final w
    w = reconstruct_w_jax(w_tmp=w, N=N)
    w = w.reshape(N + 1, N + 1)

    return w, psi, solution


In [None]:
#| eval: false

# Root-solver regression check at Re = 0 against the NumPy Newton solver
N = 20
Re = 0.0 # i.e. viscosity mu = inf
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem with the optimistix-based solver
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, solution = root_solver_jax(
    f=f_jax, U_wall_top=U_wall_top_jax,
    N=N, Re=Re,
    algorithm="newton",
)

# Fields must match; optimistix is expected to report 2 steps here, one more
# than the single iteration printed by the reference solver
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(solution.stats["num_steps"].item(), 2)

iter=1; residual=1.8164849050301797e-14; dx=39.71450891546252
n_iter=1


In [None]:
#| eval: false

# Root-solver regression check at Re = 10 with a sin^2 top-wall profile
N = 20
Re = 10.0
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem with the optimistix-based solver
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, solution = root_solver_jax(
    f=f_jax, U_wall_top=U_wall_top_jax,
    N=N, Re=Re,
    algorithm="newton",
)

# Fields must match; optimistix is expected to report 4 steps here, one more
# than the three iterations printed by the reference solver
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(solution.stats["num_steps"].item(), 4)

iter=1; residual=1.0836381666439978; dx=39.71450891546252
iter=2; residual=0.0015404929313136678; dx=2.273926820301393
iter=3; residual=8.799686196636605e-10; dx=0.001464259783738233
n_iter=3


In [None]:
#| eval: false

# Root-solver regression check at Re = 10 with a scalar top-wall velocity
N = 20
Re = 10.0
U_wall_top = 1.0

# Reference solution from the NumPy Newton solver
w, psi, _ = streamfunction_vorticity_newton.newton_solver(
    f=partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top),
    get_jacobian=streamfunction_vorticity_newton.get_jacobian,
    N=N, Re=Re, quiet=False
)

# Same problem with the optimistix-based solver
U_wall_top_jax = jax.device_put(U_wall_top)
w2, psi2, solution = root_solver_jax(
    f=f_jax, U_wall_top=U_wall_top_jax,
    N=N, Re=Re,
    algorithm="newton",
)

# Fields must match; optimistix is expected to report 4 steps here, one more
# than the three iterations printed by the reference solver
test_eq(np.allclose(w, w2), True)
test_eq(np.allclose(psi, psi2), True)
test_eq(solution.stats["num_steps"].item(), 4)

iter=1; residual=5.302183624903474; dx=92.25374180247326
iter=2; residual=0.006521635545205214; dx=4.11756391107747
iter=3; residual=3.950959477471047e-09; dx=0.0036893485573113297
n_iter=3


In [None]:
#| eval: false

# Benchmark setup for the root-solver timing cells that follow
N = 20
Re = 10.0
U_wall_top = np.sin(np.pi * np.arange(1, N) / N) ** 2
U_wall_top_jax = jax.device_put(U_wall_top)

# Residual and Jacobian callables for the reference NumPy solver
fun = partial(streamfunction_vorticity_newton.f, U_wall_top=U_wall_top)

jac = streamfunction_vorticity_newton.get_jacobian

# Single timed run of the optimistix-based solver
%time root_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="newton");

CPU times: user 3.87 s, sys: 147 ms, total: 4.01 s
Wall time: 4.2 s


In [None]:
#| eval: false
%timeit streamfunction_vorticity_newton.newton_solver(f=fun, get_jacobian=jac, N=N, Re=Re)

The slowest run took 4.26 times longer than the fastest. This could mean that an intermediate result is being cached.
683 ms ± 470 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
#| eval: false
%timeit root_solver_jax(f=f_jax, U_wall_top=U_wall_top_jax, N=N, Re=Re, algorithm="newton")

3.8 s ± 192 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Export

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()