In [39]:
import dataclasses
from dataclasses import dataclass, field
from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp

In [9]:
# Enable double-precision (64-bit) arrays in JAX for the whole notebook.
jax.config.update("jax_enable_x64", True)

In [40]:
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Species:
    tracer: jax.Array

    @classmethod
    def zeros(cls) -> "Species":
        return Species(tracer=jnp.zeros(()))

    def add(self, name, value) -> "Species":
        return dataclasses.replace(self, **{name: value + getattr(self, name)})


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Cells:
    n_cells: int
    # The x coordinate of the center of each cell. Shape = (n_cells,)
    centers: jax.Array
    # The x coordinate of the point between cells,
    # and the first and last boundary. Shape = (n_cells + 1,)
    nodes: jax.Array

    # TODO add area for each node

    def cell_length(self) -> jax.Array:
        return self.node[1:] - self.node[:-1]

    @property
    def face_distances(self):
        """
        Returns the distance between cell boundaries (dx for flux divergence).
        Shape = (n_cells,)
        """
        return self.nodes[1:] - self.nodes[:-1]

    @property
    def center_distances(self):
        """
        Returns the distance between cell centers (dx for slope computation).
        Shape = (n_cells - 1,)
        """
        return self.centers[1:] - self.centers[:-1]

    @classmethod
    def equally_spaced(cls, length, n_cells):
        nodes = jnp.linspace(0, length, n_cells + 1)
        return Cells(
            n_cells=n_cells,
            nodes=nodes,
            centers=(nodes[1:] - nodes[:-1]) / 2 + nodes[:-1],
        )


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class System:
    porosity: jax.Array
    velocity: jax.Array
    cells: Cells
    advection: Advection
    dispersion: Dispersion
    bcs: list[BoundaryCondition] = field(default_factory=list)  # avoid shared mutable default!

    # TODO get advection and dispersion
    # TODO add reactions
    # TODO rename to Model, but maybe some attrisbutes in System or so?
    # TODO porosity depending in position?

    # retardation_factor: Species




# TODO delete?
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class BoundaryCondition:
    def left_flux(self, system: System) -> jax.Array:
        raise NotImplementedError

    def right_flux(self, system: System) -> jax.Array:
        raise NotImplementedError

@dataclass(frozen=True)
class FixedConcentrationCondition(BoundaryCondition):
    """Prescribed concentration outside the domain, applied with upwinding.

    The flux is ``velocity * c``: ``c`` is the prescribed concentration when
    fluid flows into the domain through that face, and the interior value
    otherwise.
    """

    # Concentration carried by inflowing fluid.
    concentration: float

    def _upwinded(self, inflow, interior_value, velocity):
        # Select between the inflow flux and the outflow flux elementwise.
        inflow_flux = velocity * self.concentration
        outflow_flux = velocity * interior_value
        return jnp.where(inflow, inflow_flux, outflow_flux)

    def left_flux(self, interior_value: jax.Array, system: System) -> jax.Array:
        velocity = system.velocity
        # Positive velocity enters the domain through the left face.
        return self._upwinded(velocity > 0, interior_value, velocity)

    def right_flux(self, interior_value: jax.Array, system: System) -> jax.Array:
        velocity = system.velocity
        # Negative velocity enters the domain through the right face.
        return self._upwinded(velocity < 0, interior_value, velocity)


@dataclass(frozen=True)
class FixedFluxCondition(BoundaryCondition):
    """Boundary condition imposing the same constant flux on either side."""

    # Prescribed flux, returned unchanged for both boundaries.
    flux_value: float

    def _prescribed(self) -> jax.Array:
        # The interior state is irrelevant: the flux is imposed externally.
        return jnp.array(self.flux_value)

    def left_flux(self, interior_value: jax.Array, system: System) -> jax.Array:
        return self._prescribed()

    def right_flux(self, interior_value: jax.Array, system: System) -> jax.Array:
        return self._prescribed()

In [41]:
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Advection:
    limiter_type: str = "minmod"  # Options: "minmod", "upwind", "MC"

    def rate(self, state: Species, system: System) -> Species:
        def flat_rate(concentration):
            cells = system.cells
            dx_center = cells.center_distances  # (n_cells - 1,)
            dx_face = cells.face_distances      # (n_cells,)

            # Compute slope with chosen limiter
            slope = self.compute_slope(concentration, dx_center)

            # Compute left and right states at interfaces
            QL = concentration - 0.5 * slope * dx_face # left state for cell i
            QR = concentration + 0.5 * slope * dx_face # right state for cell i 

            # Internal faces (i = 1 to n_cells - 1)
            left_state  = QR[:-1]   # right side of cell i-1
            right_state = QL[1:]    # left side of cell i
            internal_flux = self.upwind_flux(left_state, right_state, system.velocity)

            # Handle boundary fluxes using boundary condition object
            left_flux = np.array(0.0)
            right_flux = np.array(0.0)

            full_flux = jnp.concatenate([left_flux[None], internal_flux, right_flux[None]])

            # Compute flux divergence (flux differences over each cell)
            # TODO use area and porocity
            # TODO better names
            flux_div = (full_flux[:-1] - full_flux[1:]) / dx_face

            return flux_div

        return jax.tree.map(flat_rate, state)

    def compute_slope(self, concentration, dx_center):
        """Compute limited slope with padding at boundaries."""
        delta = concentration[1:] - concentration[:-1]
        raw_slope = delta / dx_center
    
        if self.limiter_type == "minmod":
            a = raw_slope[:-1]
            b = raw_slope[1:]
            limited = self.minmod(a, b)
            slope = jnp.concatenate([jnp.array([0.0]), limited, jnp.array([0.0])])
            return slope
    
        elif self.limiter_type == "upwind":
            return jnp.zeros_like(concentration)
    
        elif self.limiter_type == "MC":
            a = raw_slope[:-1]
            b = raw_slope[1:]
            limited = self.mc_limiter(a, b)
            slope = jnp.concatenate([jnp.array([0.0]), limited, jnp.array([0.0])])
            return slope
    
        else:
            raise ValueError(f"Unknown limiter type: {self.limiter_type}")

    @staticmethod
    def upwind_flux(left_state, right_state, velocity):
        """Simple upwind flux function."""
        # TODO porocity?
        return jnp.where(velocity >= 0, velocity * left_state, velocity * right_state)

    @staticmethod
    def minmod(a, b):
        """Standard minmod limiter."""
        cond = (jnp.sign(a) == jnp.sign(b))
        s = jnp.sign(a)
        return jnp.where(cond, s * jnp.minimum(jnp.abs(a), jnp.abs(b)), 0.0)

    @staticmethod
    def minmod3(a, b, c):
        """Minmod of three values, for MC limiter."""
        cond1 = (jnp.sign(a) == jnp.sign(b)) & (jnp.sign(b) == jnp.sign(c))
        s = jnp.sign(a)
        return jnp.where(cond1, s * jnp.minimum(jnp.abs(a), jnp.minimum(jnp.abs(b), jnp.abs(c))), 0.0)

    def mc_limiter(self, a, b):
        """Monotonized central (MC) limiter."""
        return self.minmod3(2 * a, 0.5 * (a + b), 2 * b)


In [42]:
@dataclass(frozen=True)
class BoundaryCondition:
    """Adds a boundary flux to the first or last cell of one species' rate.

    Attributes:
      is_active: called as ``is_active(t, system)``; the condition contributes
        only while it returns True.
      species_selector: picks one species' array out of a Species-structured
        pytree (used for the state, the rate and the application counter).
      left: True for the first cell, False for the last cell.
    """

    is_active: Callable[[jax.Array, "System"], jax.Array]
    species_selector: Callable
    left: bool

    def apply(self, t, system, state, rate, apply_count):
        """Return ``(rate, apply_count)`` with this condition's contribution added.

        ``compute_flux`` returns the flux into the boundary cell through its
        outer face; it is divided by that cell's width to obtain a rate, in
        the same way the transport operators turn face fluxes into rates.
        ``apply_count`` is incremented at the boundary cell while active.
        """
        species_rate = self.species_selector(rate)
        species_state = self.species_selector(state)
        species_apply_count = self.species_selector(apply_count)

        location = 0 if self.left else -1
        flux = self.compute_flux(t, system, species_state[location])
        cell_width = system.cells.face_distances[location]

        # jnp.where (unlike lax.select) tolerates Python bools and mixed dtypes.
        active = self.is_active(t, system)
        new_val = species_rate.at[location].add(jnp.where(active, flux / cell_width, 0.0))
        new_count = species_apply_count.at[location].add(jnp.where(active, 1, 0))

        new_rate = eqx.tree_at(self.species_selector, rate, new_val)
        new_apply_count = eqx.tree_at(self.species_selector, apply_count, new_count)

        return new_rate, new_apply_count

    def compute_flux(self, t, system, state):
        """Flux into the boundary cell; implemented by subclasses."""
        raise NotImplementedError()


@dataclass(frozen=True)
class FixedConcentrationBoundary(BoundaryCondition):
    """Prescribed concentration just outside the left or right boundary.

    ``compute_flux`` returns the net flux into the boundary cell. That flux is:
    - an advective part, ``velocity * c`` with ``c`` taken from the upwind
      side;
    - plus a dispersive part, driven by the difference between the boundary
      cell and the prescribed concentration over half a cell width.
    """

    # Constant concentration, or a function of time ``t -> concentration``.
    fixed_concentration: float | Callable[[jax.Array], jax.Array]

    def compute_flux(self, t, system, boundary_cell_state):
        # TODO use area and porosity
        # Any non-callable (int, float, array) is treated as a constant.
        if callable(self.fixed_concentration):
            fixed_concentration = self.fixed_concentration(t)
        else:
            fixed_concentration = self.fixed_concentration

        velocity = system.velocity
        if self.left:
            # Positive velocity brings the prescribed concentration in from the left.
            c_interface = jnp.where(velocity > 0, fixed_concentration, boundary_cell_state)
            advection = velocity * c_interface
        else:
            # Positive velocity carries the boundary cell out through the right.
            c_interface = jnp.where(velocity > 0, boundary_cell_state, fixed_concentration)
            advection = -velocity * c_interface

        # Dispersive flux between the boundary cell center and the boundary
        # face (half a cell), with the same coefficient as `Dispersion`.
        dispersion = system.dispersion
        pore_diffusion = self.species_selector(dispersion.pore_diffusion)
        coeff = jnp.abs(velocity) * dispersion.dispersivity + pore_diffusion
        location = 0 if self.left else -1
        dx = system.cells.face_distances[location] / 2
        diff = boundary_cell_state - fixed_concentration

        return advection - diff * coeff / dx


def apply_bcs(bcs, t, system, state, rate):
    """Add the contribution of every boundary condition in ``bcs`` to ``rate``.

    Raises at runtime (via ``eqx.error_if``) if more than one boundary
    condition is active on the same cell of the same species.
    """
    if not bcs:
        return rate

    # Per-cell application counter with the same pytree structure as `rate`.
    apply_count = jax.tree.map(jnp.zeros_like, rate)
    for bc in bcs:
        rate, apply_count = bc.apply(t, system, state, rate, apply_count)

    over_applied = jax.tree.reduce(
        jnp.logical_or,
        jax.tree.map(lambda count: jnp.any(count > 1), apply_count),
        False,
    )
    rate = eqx.error_if(
        rate, over_applied, "More than one boundary condition applied to the same cell"
    )
    return rate

# Tracer held at concentration 5 at the left boundary, switched on for t > 10.
bcs = [
    FixedConcentrationBoundary(
        left=True,
        is_active=lambda t, system: t > 10,
        species_selector=lambda species: getattr(species, "tracer"),
        fixed_concentration=lambda t: 5,
    ),
]

In [43]:
# Show the boundary-condition list defined in the previous cell.
bcs

[FixedConcentrationBoundary(is_active=<function <lambda> at 0x75920c533060>, species_selector=<function <lambda> at 0x75920c587060>, left=True, fixed_concentration=<function <lambda> at 0x75920c587100>)]

In [44]:
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Dispersion:
    # Longitudinal dispersivity
    dispersivity: jax.Array
    # pore diffusion coefficient
    pore_diffusion: Species

    def rate(self, state: Species, system: System) -> Species:
        def flat_rate(concentration, pore_diffusion):
            cells = system.cells
            # TODO add areas and porosity
            coef = jnp.abs(system.velocity) * self.dispersivity + pore_diffusion

            diffs = jnp.diff(concentration)
            dc_dx = diffs / (cells.centers[1:] - cells.centers[:-1])
            dx = cells.nodes[1:] - cells.nodes[:-1]
            flux = -dc_dx * coef

            return (
                jnp.concatenate(
                    [
                        jnp.array(0.0)[None],
                        flux,
                    ]
                )
                - jnp.concatenate(
                    [
                        flux,
                        jnp.array(0.0)[None],
                    ]
                )
            ) / dx

        return jax.tree.map(
            flat_rate, state, self.pore_diffusion
        )

In [45]:
# 200 equally wide cells on a domain of length 10 (dx = 0.05; presumably metres — TODO confirm units).
cells = Cells.equally_spaced(10, 200)

In [47]:
# Longitudinal dispersivity 0.1 and a tracer pore diffusion coefficient of
# 1e-9 scaled by 3600 * 24 — presumably converting m^2/s to m^2/day, i.e. the
# model's time unit is days (TODO confirm).
dispersion = Dispersion(
    dispersivity=0.1,
    pore_diffusion=Species(
        tracer=1e-9 * 3600 * 24,
    ),
)

In [48]:
# Slope-limited advection using the minmod limiter.
advection = Advection(
    limiter_type="minmod"
)

In [50]:
# Pass the configured operator *instances* (not the classes Advection /
# Dispersion) and the boundary conditions defined above. The velocity
# -1/365 presumably means 1 length unit per year towards the left (TODO confirm).
system = System(
    porosity=0.3,
    velocity=-1 / 365,
    cells=cells,
    advection=advection,
    dispersion=dispersion,
    bcs=bcs,
)

In [55]:
import diffrax
from functools import reduce

def rhs(t, state, system: System):
    """Right-hand side dc/dt = advection + dispersion, plus boundary-condition terms.

    Follows diffrax's ``vector_field(t, y, args)`` convention; ``system`` is
    the ``args`` passed to the solver.
    """
    advective = system.advection.rate(state, system)
    dispersive = system.dispersion.rate(state, system)
    combined = jax.tree.map(jnp.add, advective, dispersive)
    return apply_bcs(system.bcs, t, system, state, combined)


# First CPU device; the jitted solver built by `make_solver` is pinned to it.
cpu_device = jax.devices("cpu")[0]


def make_solver(
    *,
    t_max,
    t_points,
    rtol=1e-8,
    atol=1e-8,
    solver=None,
    t0=0,
    dt0=None,
    max_steps=1024 * 32 * 64,
):
    """Build a jitted ``solve(y0, args)`` that integrates ``rhs`` from ``t0`` to ``t_max``.

    Args:
      t_max: final integration time.
      t_points: times at which the solution is saved.
      rtol, atol: tolerances of the adaptive PID step-size controller.
      solver: diffrax solver; defaults to ``diffrax.Tsit5()``.
      t0: initial time.
      dt0: initial step size, passed to ``diffeqsolve`` (None lets diffrax choose it).
      max_steps: maximum number of solver steps.

    Returns:
      ``solve(y0, args)`` returning the ``diffrax.Solution``; ``args`` is
      forwarded to ``rhs`` as its third argument (the ``System``).
    """
    if solver is None:
        # Explicit Runge-Kutta solver; implicit alternatives (e.g. diffrax.Kvaerno3)
        # may be worth trying for stiff setups.
        solver = diffrax.Tsit5()

    term = diffrax.ODETerm(rhs)
    stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol)
    saveat = diffrax.SaveAt(ts=t_points)

    # NOTE(review): `device=` must be accepted by the installed equinox version's
    # filter_jit — confirm when upgrading equinox.
    @eqx.filter_jit(device=cpu_device)
    def solve(y0: Species, args):
        return diffrax.diffeqsolve(
            term,
            solver,
            t0=t0,
            t1=t_max,
            dt0=dt0,
            y0=y0,
            saveat=saveat,
            args=args,
            stepsize_controller=stepsize_controller,
            max_steps=max_steps,
        )

    return solve

In [56]:
# Save the solution at 123 equally spaced times in [0, 5000]; t_max must not be
# smaller than the last save time.
t_points = jnp.linspace(0, 5000, 123)
# Much looser tolerances (1e-3) than make_solver's 1e-8 defaults.
solver = make_solver(t_max=5000, t_points=t_points, rtol=1e-3, atol=1e-3)

In [57]:
# Initial condition: no tracer anywhere in the domain.
val0 = jnp.zeros(cells.n_cells)
# Alternative initial condition: tracer concentration 10 in cells 10..19.
#val0 = val0.at[slice(10,20)].set(10.0)

state = Species(
    tracer=val0,
)

# Integrate; `system` is forwarded to `rhs` as the solver's `args`.
solution = solver(state, system)

ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<function rhs>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Sum of the tracer concentration over all cells at each saved time; on this
# equally spaced grid it is proportional to the total tracer amount.
solution.ys.tracer.sum(1)

In [None]:
# Concentration profiles at every 10th saved time.
fig, ax = plt.subplots()
ax.plot(cells.centers, solution.ys.tracer.T[:, ::10])
ax.set(xlabel="x", ylabel="tracer concentration", title="Tracer profiles (every 10th saved time)");

In [None]:
import numpy as np

In [None]:
# Numerical dispersion coefficient due to upstream weighting, |v| * dx / 2
# (see EnviMod2 script page 91; that derivation is for a fully implicit,
# first-order upwind scheme — TODO confirm relevance for the limiter used here).
# Uses the actual cell width instead of a hard-coded 0.1 (this grid has dx = 10 / 200 = 0.05).
np.abs(system.velocity) * cells.face_distances[0] / 2

In [None]:
# Mechanical dispersion coefficient dispersivity * |v|, for comparison with the numerical one.
dispersion.dispersivity * np.abs(system.velocity)