In [1]:
import jax.numpy as jnp
from pde_opt.numerics.utils.derivatives import laplacian
from pde_opt.numerics.utils.boundary_conditions import get_neighbor_indices

In [2]:
# Build one Laplacian operator per dimensionality. `jnp` and `laplacian` are
# already imported in the notebook's first cell, so they are not re-imported
# here. Keys of `boundary_conditions` are array axis indices; each value is a
# (bc_type, bc_value) pair.

# 1D with periodic BCs
lap_1d = laplacian(dx=0.1, boundary_conditions={0: ('periodic', None)})

# 2D with mixed BCs
lap_2d = laplacian(dx=(0.1, 0.1), boundary_conditions={
    0: ('dirichlet', 0.0),      # y-axis: zero Dirichlet
    1: ('periodic', None)       # x-axis: periodic
})

# 3D with Neumann BCs
lap_3d = laplacian(dx=(0.1, 0.1, 0.1), boundary_conditions={
    0: ('neumann', 0.0),        # z-axis: zero Neumann
    1: ('neumann', 0.0),        # y-axis: zero Neumann
    2: ('periodic', None)       # x-axis: periodic
})

In [3]:
# Small integer 1D sample field for exercising lap_1d.
test = jnp.asarray((1, 2, 3))


In [4]:
# Apply the periodic 1D operator (dx=0.1) to [1, 2, 3]. The explicit
# circulant-matrix product in the next cell serves as the reference value.
lap_1d(test)

Array([ 299.99997,    0.     , -299.99997], dtype=float32)

In [5]:
# Reference check: on 3 periodic points the 1D second-difference Laplacian is
# the circulant matrix below scaled by 1/dx**2 (dx = 0.1, as passed to lap_1d).
mat = jnp.array([[-2, 1, 1], [1, -2, 1], [1, 1, -2]]) / 0.1**2
# Compare numerically instead of eyeballing the two printed outputs; the
# tolerance absorbs float32 rounding (e.g. 299.99997 vs 300.).
assert jnp.allclose(lap_1d(test), mat @ test, rtol=1e-5, atol=1e-4)
mat @ test

Array([ 300.,    0., -300.], dtype=float32, weak_type=True)

In [14]:
# Dirichlet data on axis 1 (the x-axis, per the axis convention used in the
# operator-construction cell) for the 2D field of shape (3, 3) used below.
# boundary_values has one entry per point along the 3-long direction.
# NOTE(review): confirm the expected length/shape of array-valued boundary
# data against laplacian's documentation.
boundary_values = jnp.array([1.0, 1.0, 1.0])

# Only axis 1 is listed; axis 0 presumably falls back to laplacian's default
# boundary condition (TODO confirm which one).
lap_op = laplacian(dx=(0.1, 0.1), boundary_conditions={
    1: ('dirichlet', boundary_values)  # x-axis boundary
})

In [15]:
# 3x3 field of ones used to exercise lap_op.
tmp = jnp.ones(shape=(3, 3))


In [16]:
# Apply the Dirichlet operator to the unit field. The axis-1 boundary values
# (all 1.0) equal the field values. Axis 0 has no entry in
# boundary_conditions, so the nonzero first/last rows in the output presumably
# come from its default BC — TODO confirm.
# NOTE(review): the meaning of the second positional argument (0.0) is not
# visible in this notebook (possibly a time argument); confirm against the
# signature of the callable returned by laplacian.
lap_op(tmp, 0.0)

Array([[-100., -100., -100.],
       [   0.,    0.,    0.],
       [-100., -100., -100.]], dtype=float32)