Skip to content

Commit

Permalink
extend reconstruction interface
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Dec 12, 2022
1 parent e848542 commit 4f52871
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 62 deletions.
10 changes: 6 additions & 4 deletions pyshocks/advection/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def upwind_flux(

from pyshocks.reconstruction import reconstruct

ul, ur = reconstruct(scheme.rec, grid, bc.boundary_type, u)
al, ar = reconstruct(scheme.rec, grid, bc.boundary_type, scheme.velocity)
a = scheme.velocity
ul, ur = reconstruct(scheme.rec, grid, bc.boundary_type, u, u, a)
al, ar = reconstruct(scheme.rec, grid, bc.boundary_type, a, a, a)

aavg = (ar[:-1] + al[1:]) / 2
fnum = jnp.where(aavg > 0, ur[:-1], ul[1:])
Expand Down Expand Up @@ -150,9 +151,10 @@ def esweno_lf_flux(

from pyshocks.reconstruction import reconstruct

# FIXME: what the hell is this? Why are we reconstructing f?
f = scheme.velocity * u
ul, ur = reconstruct(scheme.rec, grid, bc.boundary_type, u)
fl, fr = reconstruct(scheme.rec, grid, bc.boundary_type, f)
ul, ur = reconstruct(scheme.rec, grid, bc.boundary_type, u, u, u)
fl, fr = reconstruct(scheme.rec, grid, bc.boundary_type, f, u, scheme.velocity)

a = jnp.max(jnp.abs(scheme.velocity))
fnum = 0.5 * (fl[1:] + fr[:-1]) - 0.5 * a * (ul[1:] - ur[:-1])
Expand Down
2 changes: 1 addition & 1 deletion pyshocks/burgers/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _numerical_flux_burgers_engquist_osher(
from pyshocks.scalar import scalar_flux_engquist_osher

return scalar_flux_engquist_osher(
scheme, grid, bc.boundary_type, t, u, omega=scheme.omega
scheme, grid, bc.boundary_type, t, u, u, omega=scheme.omega
)


Expand Down
4 changes: 2 additions & 2 deletions pyshocks/burgers/ssweno.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def _apply_operator_burgers_ssweno242(
fm = (f - alpha * u) / 2

# reconstruct
fp, _ = reconstruct(scheme.rec, grid, bc.boundary_type, fp)
_, fm = reconstruct(scheme.rec, grid, bc.boundary_type, fm)
fp, _ = reconstruct(scheme.rec, grid, bc.boundary_type, fp, u, u)
_, fm = reconstruct(scheme.rec, grid, bc.boundary_type, fm, u, u)

# {{{ inviscid flux

Expand Down
5 changes: 3 additions & 2 deletions pyshocks/continuity/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@ def _numerical_flux_continuity_godunov(

from pyshocks.reconstruction import reconstruct

ul, ur = reconstruct(scheme.rec, grid, bc.boundary_type, u)
al, ar = reconstruct(scheme.rec, grid, bc.boundary_type, scheme.velocity)
a = scheme.velocity
ul, ur = reconstruct(scheme.rec, grid, bc.boundary_type, u, u, a)
al, ar = reconstruct(scheme.rec, grid, bc.boundary_type, a, a, a)

aavg = (ar[:-1] + al[1:]) / 2
fnum = jnp.where(aavg > 0, ar[:-1] * ur[:-1], al[1:] * ul[1:])
Expand Down
146 changes: 100 additions & 46 deletions pyshocks/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
if TYPE_CHECKING:
from pyshocks.schemes import BoundaryType

# {{{

# {{{ interface


@dataclass(frozen=True)
Expand Down Expand Up @@ -75,28 +76,41 @@ def stencil_width(self) -> int:

@singledispatch
def reconstruct(
rec: Reconstruction, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: Reconstruction,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
r"""Reconstruct face values from the cell-averaged values *u*.
r"""Reconstruct face values from the cell-averaged values *f* of a
function of *u*.
In this implementation, we use the convention that::
i - 1 i i + 1
--------------|--------------|--------------
u^R_{i - 1} u^R_i
u^L_i u^L_{i + 1}
f^R_{i - 1} f^R_i
f^L_i f^L_{i + 1}
i.e. :math:`u^R_i` refers to the right reconstructed value in cell :math:`i`
and represents the value at :math:`x_{i + \frac{1}{2}}` and :math:`u^L_i`
i.e. :math:`f^R_i` refers to the right reconstructed value in cell :math:`i`
and represents the value at :math:`x_{i + \frac{1}{2}}` and :math:`f^L_i`
refers the left reconstructed value in the cell :math:`i` and represents
the value at :math:`x_{i - \frac{1}{2}}`.
Note that this notation can be directly interpreted in a finite difference
setting, where the :math:`u_i, f_i` and :math:`w_i` are point values at
cell faces used to reconstruct an in-cell quantity.
:arg rec: a reconstruction algorithm.
:arg grid: the computational grid representation.
:arg bc: generic type of the boundary required to build the reconstruction
:arg u: variable to reconstruct at the left and right cell faces.
:arg f: variable to reconstruct as a function of *u*.
:arg u: base variable used in the reconstruction.
:arg vavespeed: wave speed with which *u* is transported, that can be used
to additionally upwind the reconstruction of *f*.
:returns: a :class:`tuple` of ``(ul, ur)`` containing a reconstructed
:returns: a :class:`tuple` of ``(fl, fr)`` containing a reconstructed
state on the left and right side of each cell face. The dimension
of the returned arrays matches :attr:`pyshocks.Grid.x`.
"""
Expand All @@ -106,7 +120,7 @@ def reconstruct(
# }}}


# {{{ first-order
# {{{ first-order: constant reconstruction


@dataclass(frozen=True)
Expand All @@ -117,7 +131,7 @@ class ConstantReconstruction(Reconstruction):
.. math::
(u^R_i, u^L_i) = (u_i, u_i).
(f^R_i, f^L_i) = (f_i, f_i).
which results in a first-order scheme.
"""
Expand All @@ -137,16 +151,21 @@ def stencil_width(self) -> int:

@reconstruct.register(ConstantReconstruction)
def _reconstruct_first_order(
rec: ConstantReconstruction, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: ConstantReconstruction,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
assert grid.nghosts >= rec.stencil_width
return u, u
return f, f


# }}}


# {{{ MUSCL
# {{{ MUSCL: Flux Limiter


@dataclass(frozen=True)
Expand All @@ -160,13 +179,15 @@ class MUSCL(Reconstruction):
.. math::
\begin{aligned}
u^R_i =\,\, &
u_i + \frac{\phi(r_i)}{2} (u_{i + 1} - u_i), \\
u^L_i =\,\, &
u_i - \frac{\phi(r_i^{-1})}{2} (u_i - u_{i - 1}), \\
f^R_i =\,\, &
f_i + \frac{\phi(r_i)}{2} (f_{i + 1} - f_i), \\
f^L_i =\,\, &
f_i - \frac{\phi(r_i^{-1})}{2} (f_i - f_{i - 1}), \\
\end{aligned}
where :math:`\phi` is a limiter given by :attr:`lm`.
where :math:`\phi` is a limiter given by :attr:`lm`. Note that here
:math:`f \equiv f(u)` and the limiter :math:`\phi` is computed based on
the underlying variable :math:`u`.
.. attribute:: lm
Expand All @@ -191,7 +212,12 @@ def stencil_width(self) -> int:

@reconstruct.register(MUSCL)
def _reconstruct_muscl(
rec: MUSCL, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: MUSCL,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
from pyshocks import UniformGrid

Expand All @@ -204,6 +230,7 @@ def _reconstruct_muscl(

from pyshocks.limiters import evaluate, local_slope_ratio

# FIXME: should this be f?
r = jnp.pad(local_slope_ratio(u, atol=rec.atol), 1)
phi_r = evaluate(rec.lm, r)

Expand All @@ -214,16 +241,16 @@ def _reconstruct_muscl(
inv_r = jnp.where(jnp.abs(r) < rec.atol, rec.atol, 1 / r)
phi_inv_r = evaluate(rec.lm, inv_r)

ur = jnp.pad(u[:-1] + 0.5 * phi_r[:-1] * (u[1:] - u[:-1]), (0, 1))
ul = jnp.pad(u[1:] - 0.5 * phi_inv_r[1:] * (u[1:] - u[:-1]), (1, 0))
ur = jnp.pad(f[:-1] + 0.5 * phi_r[:-1] * (f[1:] - f[:-1]), (0, 1))
ul = jnp.pad(f[1:] - 0.5 * phi_inv_r[1:] * (f[1:] - f[:-1]), (1, 0))

return ul, ur


# }}}


# {{{ MUSCL-slope
# {{{ MUSCL: Slope Limiter


@dataclass(frozen=True)
Expand All @@ -235,12 +262,21 @@ class MUSCLS(MUSCL):
approximation in each cell. The results are similar in most cases, but
it is a bit simpler to construct. Note that not all limiters support
:func:`~pyshocks.limiters.slope_limit`.
Also worth noting is that, unlike the :class:`MUSCL` reconstruction, the
limiter uses the function value *f* to construct the slope, not the
underlying variable *u*. This is imposed by the formulation.
"""


@reconstruct.register(MUSCLS)
def _reconstruct_muscls(
rec: MUSCLS, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: MUSCLS,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
from pyshocks import UniformGrid

Expand All @@ -253,10 +289,10 @@ def _reconstruct_muscls(

from pyshocks.limiters import slope_limit

du = slope_limit(rec.lm, grid, u)
du = slope_limit(rec.lm, grid, f)

ur = u + 0.5 * grid.dx * du
ul = u - 0.5 * grid.dx * du
ur = f + 0.5 * grid.dx * du
ul = f - 0.5 * grid.dx * du

return ul, ur

Expand Down Expand Up @@ -316,16 +352,21 @@ def order(self) -> int:
return 3


def _reconstruct_weno_js_side(rec: WENOJS, u: jnp.ndarray) -> jnp.ndarray:
omega = weno.weno_js_weights(rec.s, u, eps=rec.eps)
uhat = weno.weno_reconstruct(rec.s, u)
def _reconstruct_weno_js_side(rec: WENOJS, f: jnp.ndarray) -> jnp.ndarray:
omega = weno.weno_js_weights(rec.s, f, eps=rec.eps)
uhat = weno.weno_reconstruct(rec.s, f)

return jnp.sum(omega * uhat, axis=0)


@reconstruct.register(WENOJS)
def _reconstruct_wenojs(
rec: WENOJS, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: WENOJS,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
from pyshocks import UniformGrid

Expand All @@ -334,8 +375,8 @@ def _reconstruct_wenojs(
if not isinstance(grid, UniformGrid):
raise NotImplementedError("WENO-JS is only implemented for uniform grids")

ur = _reconstruct_weno_js_side(rec, u)
ul = _reconstruct_weno_js_side(rec, u[::-1])[::-1]
ur = _reconstruct_weno_js_side(rec, f)
ul = _reconstruct_weno_js_side(rec, f[::-1])[::-1]

return ul, ur

Expand Down Expand Up @@ -373,16 +414,21 @@ def stencil_width(self) -> int:
return 2


def _reconstruct_es_weno_side(rec: ESWENO32, u: jnp.ndarray) -> jnp.ndarray:
omega = weno.es_weno_weights(rec.s, u, eps=rec.eps)
uhat = weno.weno_reconstruct(rec.s, u)
def _reconstruct_es_weno_side(rec: ESWENO32, f: jnp.ndarray) -> jnp.ndarray:
omega = weno.es_weno_weights(rec.s, f, eps=rec.eps)
uhat = weno.weno_reconstruct(rec.s, f)

return jnp.sum(omega * uhat, axis=0)


@reconstruct.register(ESWENO32)
def _reconstruct_esweno32(
rec: ESWENO32, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: ESWENO32,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
from pyshocks import UniformGrid

Expand All @@ -391,8 +437,8 @@ def _reconstruct_esweno32(
if not isinstance(grid, UniformGrid):
raise NotImplementedError("ESWENO is only implemented for uniform grids")

ur = _reconstruct_es_weno_side(rec, u)
ul = _reconstruct_es_weno_side(rec, u[::-1])[::-1]
ur = _reconstruct_es_weno_side(rec, f)
ul = _reconstruct_es_weno_side(rec, f[::-1])[::-1]

return ul, ur

Expand Down Expand Up @@ -434,20 +480,23 @@ def stencil_width(self) -> int:


def _reconstruct_ss_weno_side(
rec: SSWENO242, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: SSWENO242,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
) -> jnp.ndarray:
if grid.nghosts >= rec.stencil_width:
w = u
w = f
else:
assert grid.nghosts == 0

# FIXME: put this in the weno code to "prepare" for WENO
from pyshocks.schemes import BoundaryType

if bc == BoundaryType.Periodic:
w = jnp.pad(u, rec.stencil_width, mode="wrap")
w = jnp.pad(f, rec.stencil_width, mode="wrap")
else:
w = jnp.pad(u, rec.stencil_width, constant_values=jnp.inf)
w = jnp.pad(f, rec.stencil_width, constant_values=jnp.inf)

omega = weno.ss_weno_242_weights(rec.si, w, eps=rec.eps)
what = weno.weno_reconstruct(rec.si, w)
Expand All @@ -457,15 +506,20 @@ def _reconstruct_ss_weno_side(

@reconstruct.register(SSWENO242)
def _reconstruct_ssweno242(
rec: SSWENO242, grid: Grid, bc: "BoundaryType", u: jnp.ndarray
rec: SSWENO242,
grid: Grid,
bc: "BoundaryType",
f: jnp.ndarray,
u: jnp.ndarray,
wavespeed: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
from pyshocks import UniformGrid

if not isinstance(grid, UniformGrid):
raise NotImplementedError("SSWENO is only implemented for uniform grids")

ur = _reconstruct_ss_weno_side(rec, grid, bc, u)
ul = _reconstruct_ss_weno_side(rec, grid, bc, u[::-1])[::-1]
ur = _reconstruct_ss_weno_side(rec, grid, bc, f)
ul = _reconstruct_ss_weno_side(rec, grid, bc, f[::-1])[::-1]

return ul, ur

Expand Down
Loading

0 comments on commit 4f52871

Please sign in to comment.