In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from kinetix.transport import (
    Advection,
    Cells,
    Dispersion,
    FixedConcentrationBoundary,
    Species,
    System,
    make_solver,
)

In [None]:
# Enable 64-bit types in JAX. This runs right after the imports and before any
# array is created, so everything built in the cells below uses the same setting.
jax.config.update("jax_enable_x64", True)

In [None]:
# Spatial grid used by the system, by the initial state (via `cells.n_cells`)
# and as the x-axis of the plots (via `cells.centers`).
# NOTE(review): the positional arguments (10, 200) are presumably the domain
# length and the number of cells — confirm against Cells.equally_spaced.
cells = Cells.equally_spaced(10, 200)

In [None]:
# Pore diffusion coefficient for the tracer. 3600 * 24 is the number of seconds
# in a day, so this presumably converts a per-second value of 1e-9 to a per-day
# value — TODO confirm the units Dispersion expects.
tracer_pore_diffusion = jnp.array(1e-9 * 3600 * 24)

# Dispersion term: a scalar dispersivity plus the per-species pore diffusion.
dispersion = Dispersion(
    dispersivity=jnp.array(0.1),
    pore_diffusion=Species(tracer=tracer_pore_diffusion),
)

In [None]:
# Advection term configured with the "minmod" limiter.
advection = Advection(limiter_type="minmod")

In [None]:
def _always_active(t, system):
    """Activation predicate: the boundary condition is on regardless of time or system."""
    return True


def _select_tracer(s):
    """Return the `tracer` entry of a species container."""
    return s.tracer


# Two always-active fixed-concentration boundaries acting on the tracer:
# a value of 10.0 on the `left=True` side and 3.0 on the `left=False` side.
bcs = [
    FixedConcentrationBoundary(
        is_active=_always_active,
        left=True,
        species_selector=_select_tracer,
        fixed_concentration=lambda t: jnp.array(10.0),
    ),
    FixedConcentrationBoundary(
        is_active=_always_active,
        left=False,
        species_selector=_select_tracer,
        fixed_concentration=lambda t: jnp.array(3.0),
    ),
]

In [None]:
# Parameters of the sinusoidal velocity signal below.
# NOTE(review): 1 / 365 looks like a per-year to per-day conversion — confirm units.
VELOCITY_AMPLITUDE = 1 / 365
VELOCITY_PERIOD = 5000  # sin(2*pi*t / 5000) repeats every 5000 units of t


def _oscillating_velocity(t):
    """Time-dependent velocity: a sine wave in `t` that changes sign every half period."""
    return jnp.array(VELOCITY_AMPLITUDE) * jnp.sin(np.pi * 2 * 1 / VELOCITY_PERIOD * t)


# Bundle grid, transport terms and boundary conditions into the model definition.
system = System(
    porosity=jnp.array(0.3),
    velocity=_oscillating_velocity,
    cells=cells,
    advection=advection,
    dispersion=dispersion,
    bcs=bcs,
)

In [None]:
# Single source of truth for the time span so the grid of time points and the
# solver's end time cannot drift apart.
T_MAX = 8000
N_TIME_POINTS = 123

# Equally spaced time points on [0, T_MAX] handed to the solver
# (presumably the times at which the solution is stored — confirm in make_solver).
t_points = jnp.linspace(0, T_MAX, N_TIME_POINTS)
solver = make_solver(t_max=T_MAX, t_points=t_points, rtol=1e-3, atol=1e-3)

In [None]:
# Initial condition: zero tracer in every cell.
# For a block-shaped initial pulse instead, overwrite a range of cells,
# e.g. `val0 = val0.at[10:20].set(10.0)`.
val0 = jnp.zeros(cells.n_cells)

state = Species(tracer=val0)


In [None]:
# Run the solver on the initial `state` for the configured `system`.
# The plotting cells below read the tracer result from `solution.ys.tracer`.
solution = solver(state, system)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Static overview: tracer profile along the grid for every 10th stored snapshot.
# Each row of `solution.ys.tracer` is one profile over the cells, so the array is
# transposed to put the cell axis first and every plotted column is one snapshot.
fig, ax = plt.subplots()
ax.plot(cells.centers, solution.ys.tracer.T[:, 0::10])
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set(
    xlabel="Cell center position",
    ylabel="Tracer concentration",
    title="Tracer concentration profiles (every 10th stored snapshot)",
);

In [None]:
%matplotlib widget
from matplotlib import animation


fig, ax = plt.subplots()

artists = []
for data in solution.ys.tracer:
    container = ax.plot(cells.centers, data, color="C0")
    artists.append(container)


ani = animation.ArtistAnimation(fig=fig, artists=artists, interval=40)
plt.show()