In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from kinetix import (
    Advection,
    Cells,
    Dispersion,
    FixedConcentrationBoundary,
    System,
    make_solver,
    declare_species,
    KineticReaction,
    reaction,
    SpatiallyConst,
    SpatiallyVarying
)

from dataclasses import dataclass

In [None]:
# State container type with one field per transported species; instances are
# built with keyword arguments and fields are read by name (e.g. state.tracer).
Species = declare_species(
    [
        "tracer",
        "reactive_tracer",
    ]
)

In [None]:
@reaction
class FirstOrderDecay(KineticReaction):
    """First-order decay of ``reactive_tracer``: rate = k * c_reactive_tracer.

    ``decay_coefficient`` is the rate constant k. It may be a scalar or a
    ``SpatiallyVarying`` per-cell array. ``tracer`` does not appear in the
    stoichiometry, so this reaction presumably leaves it unchanged.
    """

    # rate constant k (scalar, or a SpatiallyVarying per-cell array)
    decay_coefficient: jax.Array

    def rate(self, time, state, system):
        # Keep k on the left of the product so any wrapper type on k
        # (e.g. SpatiallyVarying) controls the multiplication.
        concentration = state.reactive_tracer
        return self.decay_coefficient * concentration

    def stoichiometry(self, time, state, system):
        # One unit of reactive_tracer is consumed per unit of reaction rate.
        coefficients = {"reactive_tracer": -1}
        return coefficients

In [None]:
# Sanity check: evaluate the decay reaction at t=0 on scalar concentrations.
# system=None suffices because FirstOrderDecay never reads `system`.
# NOTE(review): relies on the private `_eval_dcdt` method of the reaction API.
y = Species(tracer=1, reactive_tracer=5)
decay_demo = FirstOrderDecay(decay_coefficient=0.3)
decay_demo._eval_dcdt(0, y, None)

In [None]:
# The decay coefficient can also be given per cell by wrapping an array in
# SpatiallyVarying; here a uniform value of 1 for each of n_cells cells.
n_cells = 200
per_cell_decay = jnp.ones(n_cells)
FirstOrderDecay(decay_coefficient=SpatiallyVarying(per_cell_decay))

In [None]:
# Reaction network for the transport run: a single first-order decay of
# reactive_tracer with rate constant 1e-3 (per model time unit).
reactions = [
    FirstOrderDecay(decay_coefficient=1e-3),
]

In [None]:
# Enable 64-bit floats for the transport solve.
# NOTE(review): JAX recommends setting this at startup. Arrays created in
# earlier cells (e.g. the SpatiallyVarying demo) stay 32-bit. Consider moving
# this line to the imports cell.
jax.config.update("jax_enable_x64", True)

# --- Grid ---------------------------------------------------------------------
n_cells = 200
# One interface area per cell boundary, hence n_cells + 1 entries.
interface_areas = jnp.ones(n_cells + 1)
#interface_areas = interface_areas.at[100:].set(2)
cells = Cells.equally_spaced(10, n_cells, interface_area=interface_areas)

# --- Transport ----------------------------------------------------------------
# Pore diffusion coefficient 1e-9 per second converted to per day
# (presumably the model time unit); used for both species.
pore_diffusion_coefficient = 1e-9 * 3600 * 24
dispersion = Dispersion.build(
    cells=cells,
    dispersivity=jnp.array(0.1),
    pore_diffusion=Species(
        tracer=jnp.array(pore_diffusion_coefficient),
        reactive_tracer=jnp.array(pore_diffusion_coefficient),
    ),
)
advection = Advection.build(limiter_type="minmod")

# --- Boundary conditions ------------------------------------------------------
# Time-constant fixed concentrations, applied identically to both species.
BOUNDARY_CONCENTRATIONS = {"left": 10.0, "right": 3.0}


def _fixed_concentration_bc(boundary, species, concentration):
    """Build a time-constant FixedConcentrationBoundary for one species.

    Built in a function rather than as inline lambdas in a comprehension so
    that each lambda captures its own ``species`` / ``concentration`` instead
    of the last loop value (Python closures bind late).
    """
    return FixedConcentrationBoundary(
        boundary=boundary,
        species_selector=lambda s: getattr(s, species),
        fixed_concentration=lambda t: jnp.array(concentration),
    )


# Resulting order: tracer-left, tracer-right, reactive_tracer-left,
# reactive_tracer-right (bcs[0] is the tracer's left boundary).
bcs = [
    _fixed_concentration_bc(boundary, species, concentration)
    for species in ("tracer", "reactive_tracer")
    for boundary, concentration in BOUNDARY_CONCENTRATIONS.items()
]

# --- System -------------------------------------------------------------------
# Porosity 0.3 in the first half of the cells, 0.1 in the second half.
porosity = (jnp.ones(n_cells) * 0.3).at[n_cells // 2:].set(0.1)
system = System.build(
    porosity=porosity,
    # Alternative time-varying forcing (NOTE(review): uses the keyword
    # `velocity`, which may no longer match System.build):
    # velocity=lambda t: jnp.array(1 / 365) * jnp.sin(np.pi * 2 * 1 / 5000 * t),
    # Constant discharge: the lambda ignores t and returns a 0-d array.
    discharge=lambda t: jnp.array(1 / 365) * 0.3,
    cells=cells,
    advection=advection,
    dispersion=dispersion,
    bcs=bcs,
    reactions=reactions,
)

In [None]:
# Simulation horizon and saved output times. One constant feeds both the save
# grid and the solver's t_max so the two cannot drift apart.
T_MAX = 8000
N_SAVED_TIMES = 123
t_points = jnp.linspace(0, T_MAX, N_SAVED_TIMES)
solver = make_solver(t_max=T_MAX, t_points=t_points, rtol=1e-3, atol=1e-3)

# Initial condition: zero concentration of both species in every cell.
val0 = jnp.zeros(cells.n_cells)
# Alternative initial condition: concentration 10 in cells 10-19.
#val0 = val0.at[slice(10,20)].set(10.0)

state = Species(
    tracer=val0,
    reactive_tracer=val0,
)


In [None]:
solution = solver(state, system)
%timeit solution = solver(state, system)

In [None]:
fig, ax = plt.subplots()
# ys.tracer is indexed (time, cell): transpose so each column is one profile,
# then keep every 10th saved time to avoid overplotting.
ax.plot(cells.centers, solution.ys.tracer.T[:, ::10])
ax.set(
    xlabel="position (cell centre)",
    ylabel="tracer concentration",
    title="tracer profiles, every 10th saved time",
)
plt.show()

In [None]:
fig, ax = plt.subplots()
# ys.reactive_tracer is indexed (time, cell): transpose so each column is one
# profile, then keep every 10th saved time to avoid overplotting.
ax.plot(cells.centers, solution.ys.reactive_tracer.T[:, ::10])
ax.set(
    xlabel="position (cell centre)",
    ylabel="reactive_tracer concentration",
    title="reactive_tracer profiles, every 10th saved time",
)
plt.show()

In [None]:
%matplotlib widget
from matplotlib import animation, collections

collections.Collection()
fig, ax = plt.subplots()

artists = []
for data in zip(solution.ys.tracer, solution.ys.reactive_tracer):
    containers = [ax.plot(cells.centers, y, color=f"C{i}") for i, y in enumerate(data)]
    artist = []
    for container in containers:
        artist.extend(container)
    artists.append(artist)


ani = animation.ArtistAnimation(fig=fig, artists=artists, interval=40)
plt.show()

In [None]:
# Tracer mass balance check.
#
# mass_in: cumulative tracer mass advected in through the left boundary
# (bcs[0]), i.e. the time integral of discharge(t) * c_left(t). Evaluated with
# the trapezoidal rule so it stays valid when discharge or boundary
# concentration vary in time. For constant values and ts starting at 0 it
# equals discharge * c_left * t (up to rounding).
# NOTE(review): assumes unit area at the left interface, and ignores dispersive
# flux at the left boundary and any outflow at the right boundary, so mass_in
# and mass_in_system need not agree exactly.
inflow_rate = jnp.broadcast_to(
    system.discharge(solution.ts) * bcs[0].fixed_concentration(solution.ts),
    jnp.shape(solution.ts),
)
mass_in = jnp.concatenate([
    jnp.zeros(1, dtype=inflow_rate.dtype),
    jnp.cumsum(0.5 * (inflow_rate[1:] + inflow_rate[:-1]) * jnp.diff(solution.ts)),
])
# mass_in_system: tracer mass in the domain at each saved time, i.e.
# concentration * area * length * porosity summed over cells.
# NOTE(review): confirm cells.cell_area * cells.face_distances is the per-cell volume.
mass_in_system = (solution.ys.tracer * cells.cell_area * cells.face_distances * system.porosity).sum(axis=1)

In [None]:
fig, ax = plt.subplots()
ax.plot(solution.ts, mass_in, label="cumulative inflow (left boundary)")
ax.plot(solution.ts, mass_in_system, label="tracer mass in domain")
ax.set(xlabel="time", ylabel="tracer mass", title="tracer mass balance check")
ax.legend()
plt.show()