# core

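> Core routines used throughout the project: setup of a 2D Poisson test problem, linear-system solvers for the SciPy, PyTorch and JAX backends, and post-processing of the stream function (force and velocity).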

In [None]:
#| default_exp core

# Imports

In [None]:
# Autoreload modules
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import numpy as np
from scipy import sparse

import torch

import jax
import lineax as lx

# Code

In [None]:
#| export
def setup_poisson_problem(N):
    """
    Build the data for a 2D Poisson problem on the unit square.

    Only square domains are supported: both axes use ``N + 1`` equally spaced
    points on ``[0, 1]``.

    The right-hand side is ``w = 2*pi**2 * sin(pi x) sin(pi y)``. This is
    ``-laplacian(u)`` for the returned manufactured solution
    ``u = sin(pi x) sin(pi y)``.

    Parameters
    ----------
    N : int
        Number of grid intervals per axis.

    Returns
    -------
    tuple
        ``(w, exact_solution, nx, ny, x_grid, y_grid)``. ``nx == ny == N + 1``,
        and the grids come from ``np.meshgrid(..., indexing="ij")``, so axis 0
        is x and axis 1 is y.
    """
    n_points = N + 1
    axis = np.linspace(0, 1, n_points)

    x_grid, y_grid = np.meshgrid(axis, axis, indexing="ij")

    sin_x = np.sin(np.pi * x_grid)
    sin_y = np.sin(np.pi * y_grid)

    # Source term ("vorticity") of the Poisson problem.
    w = 2 * np.pi ** 2 * sin_x * sin_y
    exact_solution = sin_x * sin_y

    return w, exact_solution, n_points, n_points, x_grid, y_grid

In [None]:
#| exporti
# Import the submodule explicitly so that the ``sparse.linalg`` attribute is
# guaranteed to be bound. Relying on ``from scipy import sparse`` alone depends
# on the SciPy version and on what else happens to have been imported first.
import scipy.sparse.linalg  # noqa: F401

# Maps an ``algorithm`` name to the SciPy sparse solver used by
# ``solve_sparse_linear_system``. ``spsolve`` is a direct solve; ``cg`` and
# ``bicgstab`` are iterative and return an ``(x, info)`` tuple.
SPARSE_ALGORITHM_DICT = {
    "base": sparse.linalg.spsolve,
    "cg": sparse.linalg.cg,
    "bicgstab": sparse.linalg.bicgstab,
}

In [None]:
#| export
def solve_sparse_linear_system(A, b, algorithm="base", **kwargs):
    """
    Solve the sparse linear system ``A x = b`` with a SciPy solver.

    Parameters
    ----------
    A : scipy sparse matrix
        System matrix.
    b : array-like
        Right-hand side.
    algorithm : str
        Key into ``SPARSE_ALGORITHM_DICT``: ``"base"`` (``spsolve``),
        ``"cg"`` or ``"bicgstab"``.
    **kwargs
        Forwarded unchanged to the selected SciPy solver.

    Returns
    -------
    The solution. For ``"base"`` this is whatever ``spsolve`` returns.
    For the iterative solvers it is the first element of their
    ``(x, info)`` tuple.

    Warns
    -----
    RuntimeWarning
        If an iterative solver reports ``info != 0``. In that case the
        returned vector is not a converged solution.

    Raises
    ------
    KeyError
        If ``algorithm`` is not a key of ``SPARSE_ALGORITHM_DICT``.
    """
    import warnings

    solver = SPARSE_ALGORITHM_DICT[algorithm]

    if algorithm == "base":
        # A 1-D right-hand side is passed to spsolve as a sparse (n, 1) column.
        # NOTE(review): kept for backward compatibility; confirm whether this
        # conversion is still needed.
        if len(b.shape) == 1:
            b = sparse.csr_matrix(b[:, None])
        return solver(A, b, **kwargs)

    # Iterative solvers return (x, info): info == 0 means success, info > 0
    # means the iteration limit was reached, and info < 0 signals illegal
    # input or breakdown.
    soln, info = solver(A, b, **kwargs)
    if info != 0:
        reason = "did not converge" if info > 0 else "illegal input or breakdown"
        warnings.warn(
            f"sparse solver '{algorithm}' {reason} (info={info})",
            RuntimeWarning,
            stacklevel=2,
        )
    return soln


In [None]:
#| exporti
def lu_solve(A, b):
    """
    Solve ``A x = b`` through an explicit LU factorization.

    ``torch.linalg.lu_solve`` only accepts a right-hand side with a trailing
    column dimension ``(*, n, k)``. A 1-D ``b`` is therefore promoted to a
    column, and the result is squeezed back to 1-D. This matches how
    ``torch.linalg.solve`` (the ``"base"`` algorithm) treats a vector
    right-hand side.
    """
    LU, pivots = torch.linalg.lu_factor(A)
    if b.ndim == 1:
        return torch.linalg.lu_solve(LU, pivots, b.unsqueeze(-1)).squeeze(-1)
    return torch.linalg.lu_solve(LU, pivots, b)


# Maps an ``algorithm`` name to the PyTorch solver used by
# ``solve_sparse_linear_system_pytorch``.
TORCH_ALGORITHM_DICT = {
    "base": torch.linalg.solve,
    "lu_solve": lu_solve,
    "lstsq": torch.linalg.lstsq,
}


In [None]:
#| export
def solve_sparse_linear_system_pytorch(A, b, algorithm="base", **kwargs):
    """
    Solve ``A x = b`` with one of the PyTorch solvers in ``TORCH_ALGORITHM_DICT``.

    Parameters
    ----------
    A, b : torch.Tensor
        System matrix and right-hand side. NOTE(review): despite the function
        name, the registered ``torch.linalg`` routines presumably expect dense
        tensors; confirm.
    algorithm : str
        ``"base"`` (``torch.linalg.solve``), ``"lu_solve"`` or ``"lstsq"``.
    **kwargs
        Forwarded unchanged to the selected solver.

    Returns
    -------
    torch.Tensor
        The solution. For ``"lstsq"``, the ``solution`` field of the
        returned named tuple is extracted.

    Raises
    ------
    KeyError
        If ``algorithm`` is not registered.
    """
    result = TORCH_ALGORITHM_DICT[algorithm](A, b, **kwargs)

    # ``torch.linalg.lstsq`` returns a named tuple; only the solution is needed.
    return result.solution if algorithm == "lstsq" else result


In [None]:
#| exporti
# Maps an ``algorithm`` name to the JAX solver used by
# ``solve_sparse_linear_system_jax``. ``"base"`` is the dense jax.numpy solve;
# ``"lineax_base"`` goes through ``lineax.linear_solve`` with its default solver.
JAX_ALGORITHM_DICT = dict(
    base=jax.numpy.linalg.solve,
    lineax_base=lx.linear_solve,
)

In [None]:
#| export
def solve_sparse_linear_system_jax(A, b, algorithm="base", **kwargs):
    """
    Solve ``A x = b`` with one of the JAX solvers in ``JAX_ALGORITHM_DICT``.

    Parameters
    ----------
    A : jax array
        System matrix.
    b : jax array
        Right-hand side. On the lineax path it is squeezed, so it must reduce
        to a 1-D vector (e.g. shape ``(n,)`` or ``(n, 1)``).
    algorithm : str
        ``"base"`` uses ``jax.numpy.linalg.solve`` directly. Any other key
        wraps ``A`` in ``lx.MatrixLinearOperator`` and calls the lineax solver.
    **kwargs
        Forwarded unchanged to the selected solver.

    Returns
    -------
    For ``"base"``, the result of ``jax.numpy.linalg.solve``. Otherwise, the
    lineax solution's ``value`` reshaped to a column ``(n, 1)``.

    Raises
    ------
    KeyError
        If ``algorithm`` is not registered.
    """
    solver = JAX_ALGORITHM_DICT[algorithm]

    if algorithm == "base":
        return solver(A, b, **kwargs)

    # lineax works on linear operators and 1-D vectors.
    operator = lx.MatrixLinearOperator(A)
    solution = solver(operator, b.squeeze(), **kwargs)

    return solution.value[:, None]


In [None]:
#| export
def calculate_force(psi):
    """
    Integrate a one-sided second derivative of ``psi`` along its last column.

    For every interior row (``1:-1`` along axis 0), this evaluates the
    second-order one-sided stencil ``2 f0 - 5 f1 + 4 f2 - f3`` using the last
    four columns. That equals ``h**2`` times the second derivative across
    axis 1 at the last column. The values are summed and divided by ``h``,
    which amounts to a rectangle-rule integral of that second derivative
    along the boundary.

    The spacing ``h = 1 / (psi.shape[0] - 1)`` is taken from axis 0 and used
    for both directions. NOTE(review): this assumes a unit square with equal
    spacing in x and y.
    """
    n_cols = psi.shape[1]
    spacing = 1 / (psi.shape[0] - 1)

    wall = psi[1:-1, n_cols - 1]
    first_inner = psi[1:-1, n_cols - 2]
    second_inner = psi[1:-1, n_cols - 3]
    third_inner = psi[1:-1, n_cols - 4]

    stencil = 2 * wall - 5 * first_inner + 4 * second_inner - third_inner

    return np.sum(stencil) / spacing

In [None]:
#| export
def calculate_velocity(psi):
    """
    Recover velocity components from a stream function on a uniform grid.

    This assumes axis 0 is x and axis 1 is y, matching the ``indexing="ij"``
    grid from ``setup_poisson_problem``. Interior points use second-order
    central differences:

    - ``u = d(psi)/dy``
    - ``v = -d(psi)/dx``

    All boundary values are zero, except the interior of the last column,
    where ``u = sin(pi x)**2`` is imposed (presumably the moving-lid velocity
    profile; confirm).

    The spacing ``h = 1 / (psi.shape[0] - 1)`` is used in both directions.
    NOTE(review): this assumes a square unit domain.

    Returns
    -------
    (u, v)
        Arrays created with ``np.zeros_like(psi)``, so they share psi's shape
        and dtype.
    """
    nx = psi.shape[0]
    h = 1 / (nx - 1)
    denom = 2 * h

    u = np.zeros_like(psi)
    v = np.zeros_like(psi)

    interior = (slice(1, -1), slice(1, -1))
    u[interior] = (psi[1:-1, 2:] - psi[1:-1, :-2]) / denom
    v[interior] = -(psi[2:, 1:-1] - psi[:-2, 1:-1]) / denom

    # Prescribed tangential velocity on the last column (interior x nodes only).
    x_interior = np.arange(1, nx - 1) * h
    u[1:-1, -1] = np.sin(np.pi * x_interior) ** 2

    return u, v

# Export

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()