In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx

from dataclasses import dataclass

In [None]:
import jax

# Use 64-bit floats/ints in JAX (32-bit is the default); set here, before any
# arrays are created below, so they default to float64.
jax.config.update("jax_enable_x64", True)

In [None]:
@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Species:
    tracer: jax.Array

    @classmethod
    def zeros(cls) -> "Species":
        return Species(tracer=jnp.zeros(()))

    def add(self, name, value) -> "Species":
        return dataclasses.replace(self, **{name: value + getattr(self, name)})


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Cells:
    n_cells: int
    # The x coordinate of the center of each cell. Shape = (n_cells,)
    centers: jax.Array
    # The x coordinate of the point between cells,
    # and the first and last boundary. Shape = (n_cells + 1,)
    nodes: jax.Array

    def cell_length(self) -> jax.Array:
        return self.node[1:] - self.node[:-1]

    @classmethod
    def equally_spaced(cls, length, n_cells):
        nodes = jnp.linspace(0, length, n_cells + 1)
        return Cells(
            n_cells=n_cells,
            nodes=nodes,
            centers=(nodes[1:] - nodes[:-1]) / 2 + nodes[:-1],
        )


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class System:
    """Parameters of the 1D transport problem, passed to every ``rate`` method."""

    # Porosity of the medium.
    # NOTE(review): not read by Advection or Dispersion in this file — confirm
    # whether the rates should account for it.
    porosity: jax.Array
    # Flow velocity. Its sign selects the upwind direction in Advection and its
    # magnitude scales the mechanical dispersion in Dispersion.
    # Presumably length per day (the notebook passes 1 / 365) — TODO confirm units.
    velocity: jax.Array
    # Spatial grid; provides node and center coordinates to the rate operators.
    cells: Cells
    # TODO(review): retardation factor not implemented yet.
    # retardation_factor: Species


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class BoundaryCondition:
    def left_flux(self) -> jax.Array:
        return jnp.array(0.0)

    def right_flux(self) -> jax.Array:
        return jnp.array(0.0)


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Advection:
    boundary_condition: Species

    def rate(self, state: Species, system: System) -> Species:
        # Rate for a single species
        def flat_rate(concentration, bc: BoundaryCondition):
            # Positive velocity â€“ central differentiation
            flux = system.velocity * (concentration[:-2] - concentration[2:]) / 2
            boundary_flux = bc.left_flux()[None]
            first_cell_rate = boundary_flux - system.velocity * (concentration[0] + concentration[1]) / 2
            # For the last cell, use upstream weighting, so I can compute the flux over the boundary
            last_cell_rate = jnp.array(system.velocity * (concentration[-2] - concentration[-1]))[None]
            pos_flux = jnp.concatenate([first_cell_rate, flux, last_cell_rate])

            # Negative velocity
            flux = system.velocity * (concentration[1:] - concentration[:-1])
            boundary_flux = bc.right_flux()[None]
            last_cell_rate = boundary_flux - system.velocity * concentration[-1]
            neg_flux = jnp.concatenate([flux, last_cell_rate])

            flux = jnp.where(system.velocity < 0, neg_flux, pos_flux)
            return flux / (cells.nodes[1:] - cells.nodes[:-1])

        return jax.tree.map(flat_rate, state, self.boundary_condition)


@jax.tree_util.register_dataclass
@dataclass(frozen=True)
class Dispersion:
    # Longitudinal dispersivity
    dispersivity: jax.Array
    # pore diffusion coefficient
    pore_diffusion: Species
    boundary_condition: Species

    def rate(self, state: Species, system: System) -> Species:
        def flat_rate(concentration, pore_diffusion, bc: BoundaryCondition):
            coef = jnp.abs(system.velocity) * self.dispersivity + pore_diffusion

            diffs = jnp.diff(concentration)
            dc_dx = diffs / (cells.centers[1:] - cells.centers[:-1])
            dx = cells.nodes[1:] - cells.nodes[:-1]
            flux = -dc_dx * coef

            return (
                jnp.concatenate(
                    [
                        bc.left_flux()[None],
                        flux,
                    ]
                )
                - jnp.concatenate(
                    [
                        flux,
                        bc.right_flux()[None],
                    ]
                )
            ) / dx

        return jax.tree.map(
            flat_rate, state, self.pore_diffusion, self.boundary_condition
        )

In [None]:
# 100 equally wide cells spanning x = 0 ... 10.
cells = Cells.equally_spaced(length=10, n_cells=100)

In [None]:
# Transport parameters passed to every operator's ``rate``.
system = System(
    cells=cells,
    velocity=1 / 365,  # presumably 1 length unit per year expressed per day — TODO confirm
    porosity=0.3,
)

In [None]:
# Dispersion operator for the tracer (the solve below passes only ``advection``).
dispersion = Dispersion(
    dispersivity=0.1,
    # presumably 1e-9 m^2/s converted to per-day units — TODO confirm
    pore_diffusion=Species(tracer=1e-9 * 3600 * 24),
    boundary_condition=Species(tracer=BoundaryCondition()),
)

In [None]:
# Advection operator with a zero-flux boundary condition object for the tracer.
advection = Advection(boundary_condition=Species(tracer=BoundaryCondition()))

In [None]:
import diffrax
from functools import reduce

def rhs(t, state, args):
    """Right-hand side dy/dt of the ODE integrated by diffrax.

    Args:
        t: current time (not used by this function).
        state: current concentrations (a ``Species`` pytree).
        args: ``(system, *processes)`` — the ``System`` followed by any number
            of transport operators exposing ``rate(state, system)``, e.g.
            ``(system, advection)`` or ``(system, advection, dispersion)``.

    Returns:
        The sum of all processes' rates, with the same structure as ``state``.
        With no processes the rate is zero everywhere.
    """
    system, *processes = args
    rates = [process.rate(state, system) for process in processes]
    if not rates:
        return jax.tree.map(jnp.zeros_like, state)
    # Add the per-species rates of all processes leaf by leaf.
    return reduce(lambda a, b: jax.tree.map(jnp.add, a, b), rates)


# First CPU device; the jitted solve function below is pinned to it.
cpu_device = jax.devices(backend="cpu")[0]


def make_solver(
    *,
    t_max,
    t_points,
    rtol=1e-8,
    atol=1e-8,
    solver=None,
    t0=0,
    dt0=None,
    max_steps=1024 * 32 * 64,
):
    """Build a jitted function that integrates ``rhs`` with diffrax.

    Args:
        t_max: final integration time (``t1``).
        t_points: times at which the solution is saved (``SaveAt(ts=...)``).
        rtol: relative tolerance of the adaptive ``PIDController``.
        atol: absolute tolerance of the adaptive ``PIDController``.
        solver: a diffrax solver; defaults to ``diffrax.Tsit5()``. Other
            choices such as ``diffrax.Dopri5()`` or the implicit
            ``diffrax.Kvaerno3()`` can be passed instead.
        t0: initial time.
        dt0: initial step size; ``None`` leaves the choice to the controller.
        max_steps: maximum number of steps diffrax may take.

    Returns:
        ``solve(y0, args)``, which returns the ``diffrax`` solution object;
        ``args`` is forwarded to ``rhs``.
    """
    if solver is None:
        solver = diffrax.Tsit5()

    term = diffrax.ODETerm(rhs)
    stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol)
    t_vals = diffrax.SaveAt(ts=t_points)

    # NOTE(review): confirm the installed equinox accepts ``device=`` in
    # ``filter_jit``; recent releases only document ``donate``.
    @eqx.filter_jit(device=cpu_device)
    def solve(y0: Species, args):
        return diffrax.diffeqsolve(
            term,
            solver,
            t0=t0,
            t1=t_max,
            dt0=dt0,
            y0=y0,
            saveat=t_vals,
            args=args,
            stepsize_controller=stepsize_controller,
            max_steps=max_steps,
        )

    return solve

In [None]:
# Node (cell-face) coordinates of the grid: n_cells + 1 values from 0 to the domain length.
cells.nodes

In [None]:
# Single source of truth for the end time: the save times must lie in [t0, t_max].
t_max = 5000
# 123 equally spaced output times.
t_points = jnp.linspace(0, t_max, 123)
solver = make_solver(t_max=t_max, t_points=t_points, rtol=1e-3, atol=1e-3)

In [None]:
# Initial condition: tracer concentration 10.0 in cells 10-19, zero elsewhere.
val0 = jnp.zeros(cells.n_cells).at[10:20].set(10.0)
state = Species(tracer=val0)

# Integrate with advection as the only process.
solution = solver(state, (system, advection))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Sum over cells at each saved time: a quick mass-balance check (cells are equally wide).
solution.ys.tracer.sum(axis=1)

In [None]:
fig, ax = plt.subplots()
# One line per every 10th saved time.
ax.plot(cells.centers, solution.ys.tracer.T[:, ::10])
ax.set(
    xlabel="x (cell center)",
    ylabel="tracer concentration",
    title="Tracer profiles at every 10th saved time",
);

In [None]:
import numpy as np

In [None]:
# Numerical dispersion coefficient due to upstream weighting, |v| * dx / 2
# (see EnviMod2 script page 91; derived there for a fully implicit scheme).
# dx is read from the (equally spaced) grid instead of being hard-coded.
# NOTE(review): for positive velocity the Advection operator uses central
# weighting inside the domain and the default solver (Tsit5) is explicit, so
# treat this only as a rough reference value.
np.abs(system.velocity) * float(cells.nodes[1] - cells.nodes[0]) / 2

In [None]:
# Mechanical dispersion coefficient |v| * dispersivity, as used in Dispersion.rate, for comparison.
np.abs(system.velocity) * dispersion.dispersivity