In [None]:
import jax
import jax.numpy as jnp

from kinetix.transport import (
    Advection,
    Cells,
    Dispersion,
    FixedConcentrationBoundary,
    Species,
    System,
    make_solver,
)

In [None]:
# Enable 64-bit floats in JAX so arrays created after this point default to float64.
# NOTE(review): this runs after `kinetix.transport` has been imported; if that module
# creates JAX arrays at import time they keep 32-bit precision — consider moving this
# call directly after `import jax`.
jax.config.update("jax_enable_x64", True)

In [None]:
# One-dimensional grid of equally spaced cells.
# NOTE(review): the positional arguments are presumably (domain length, number of
# cells) — TODO confirm against the Cells.equally_spaced signature.
cells = Cells.equally_spaced(10, 200)

In [None]:
# Dispersion parameters: a longitudinal dispersivity plus a per-species pore
# diffusion coefficient.
# NOTE(review): the diffusion value 1e-9 is scaled by 3600 * 24, i.e. converted from
# per-second to per-day, which suggests the model time unit is days — confirm.
dispersion = Dispersion(
    pore_diffusion=Species(tracer=jnp.array(1e-9 * 3600 * 24)),
    dispersivity=jnp.array(0.1),
)

In [None]:
# Advection term with a "minmod" limiter (presumably a TVD / MUSCL-style limited
# reconstruction inside kinetix — confirm against the Advection implementation).
advection = Advection(limiter_type="minmod")

In [None]:
# Boundary conditions: a tracer pulse injected at the left boundary.
# The boundary holds the tracer at a fixed concentration of 10.0 while t < 1500;
# afterwards `is_active` returns False and the boundary is reported inactive.
bcs = [
    FixedConcentrationBoundary(
        # `system` is accepted but unused here; presumably the caller passes
        # (t, system) — keep the parameter so the callback signature still matches.
        is_active=lambda t, system: t < 1500,
        left=True,
        # Direct attribute access; equivalent to getattr(s, "tracer").
        species_selector=lambda s: s.tracer,
        fixed_concentration=lambda t: jnp.array(10.0),
    ),
]

In [None]:
# Assemble the transport problem: grid, porous medium, flow, and transport processes.
# NOTE(review): a velocity of 1 / 365 reads as one length unit per year expressed per
# day, consistent with the day-based diffusion scaling above — confirm units.
system = System(
    cells=cells,
    porosity=jnp.array(0.3),
    velocity=jnp.array(1 / 365),
    advection=advection,
    dispersion=dispersion,
    bcs=bcs,
)

In [None]:
# 123 evenly spaced output times on [0, 3000], and a solver configured to save there.
# NOTE(review): t_max (5000) exceeds the last output time (3000). If t_max is the
# integration end time, the final 2000 time units are integrated but never saved —
# confirm against make_solver.
t_points = jnp.linspace(0, 3000, num=123)
solver = make_solver(
    t_points=t_points,
    t_max=5000,
    atol=1e-3,
    rtol=1e-3,
)

In [None]:
# Initial condition: a tracer-free domain (zero concentration in every cell).
# The only tracer source set up in this notebook is the left boundary condition.
val0 = jnp.zeros(cells.n_cells)
state = Species(tracer=val0)

# Integrate the transport system from the initial state.
solution = solver(state, system)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Total tracer summed over all cells at each saved time. Rows of ys.tracer are output
# times and columns are cells, so this gives one value per time. For uniform cells and
# porosity it is proportional to the stored mass, giving a rough mass-balance check.
solution.ys.tracer.sum(axis=1)

In [None]:
# Concentration profiles along the domain at every 10th saved output time.
# ys.tracer is indexed (time, cell); transposing puts cells on the first axis, so
# each selected time becomes one curve plotted against the cell centres.
fig, ax = plt.subplots()
ax.plot(cells.centers, solution.ys.tracer.T[:, ::10])
ax.set(
    xlabel="cell centre position",
    ylabel="tracer concentration",
    title="Tracer profiles (every 10th output time)",
);

In [None]:
import numpy as np

In [None]:
# numerical dispersion coefficient due to the upstream weighting (see EnviMod2 script page 91) (this is for a fully implicit scheme)
# Numerical dispersion coefficient caused by upstream weighting:
#     D_num ≈ |v| * dx / 2
# (see EnviMod2 script, page 91; the formula there is for a fully implicit scheme).
# The grid spacing is taken from the actual cell centres rather than a hard-coded
# value, so the estimate follows the grid built by Cells.equally_spaced above.
# NOTE(review): advection uses a minmod limiter, which presumably falls back to
# upwinding only near steep fronts/extrema, so this is likely an upper bound.
cell_spacing = float(np.diff(np.asarray(cells.centers)).mean())
np.abs(system.velocity) * cell_spacing / 2

In [None]:
# Mechanical dispersion coefficient from the model parameters, dispersivity * |v|,
# for comparison with the numerical-dispersion estimate above.
dispersion.dispersivity * np.abs(system.velocity)