In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from reactix import (
    Advection,
    Cells,
    Dispersion,
    FixedConcentrationBoundary,
    System,
    make_solver,
    declare_species,
    KineticReaction,
    reaction,
    SpatiallyConst,
    SpatiallyVarying,
    user_system_parameters
)

from dataclasses import dataclass

In [None]:
# Enable 64-bit floats before any JAX array is created. Arrays built while x64 is
# off stay float32 (e.g. the solid density constructed in the parameters cell),
# so the flag must be set in the first code cell, not later in the model setup.
jax.config.update("jax_enable_x64", True)

# Species in the model: a conservative tracer plus a pathogen in two phases.
# `species_is_mobile` is handed to System.build; presumably only species flagged
# True are advected/dispersed (attached_pathogen stays put) — confirm in reactix.
Species = declare_species(["tracer", "mobile_pathogen", "attached_pathogen"])
species_is_mobile = Species(tracer=True, mobile_pathogen=True, attached_pathogen=False)

In [None]:
@reaction
class MobilePathogenDecay(KineticReaction):
    """First-order decay of the mobile (aqueous) pathogen.

    The rate is ``decay_coefficient * mobile_pathogen``; each unit of rate
    removes one unit of ``mobile_pathogen`` and touches no other species.
    """

    decay_coefficient: jax.Array  # first-order rate constant

    def rate(self, time, state, system):
        mobile = state.mobile_pathogen
        return self.decay_coefficient * mobile

    def stoichiometry(self, time, state, system):
        return dict(mobile_pathogen=-1)

@reaction
class AttachedPathogenDecay(KineticReaction):
    """First-order decay of the pathogen attached to the solid phase.

    The rate is ``decay_coefficient * attached_pathogen``; each unit of rate
    removes one unit of ``attached_pathogen`` and touches no other species.
    """

    decay_coefficient: jax.Array  # first-order rate constant

    def rate(self, time, state, system):
        attached = state.attached_pathogen
        return self.decay_coefficient * attached

    def stoichiometry(self, time, state, system):
        return dict(attached_pathogen=-1)

@reaction
class Attachment(KineticReaction):
    """Kinetic transfer of pathogen from the mobile to the attached phase.

    The rate is ``attachment_coefficient * mobile_pathogen``. Each unit of rate
    removes one unit of ``mobile_pathogen`` and adds ``porosity / bulk_density``
    units of ``attached_pathogen``.
    """

    attachment_coefficient: jax.Array  # first-order attachment rate constant

    def rate(self, time, state, system):
        return self.attachment_coefficient * state.mobile_pathogen

    def stoichiometry(self, time, state, system):
        # porosity / bulk_density converts the aqueous-concentration change into
        # an attached-concentration change; presumably mobile_pathogen is per
        # pore-water volume and attached_pathogen per solid mass, so that
        # porosity*dC + bulk_density*dS = 0 (mass conserved) — confirm units.
        rho_b = system.parameters.bulk_density(system)
        solid_per_water = system.porosity / rho_b
        return {"mobile_pathogen": -1, "attached_pathogen": solid_per_water}
    
@reaction
class Detachment(KineticReaction):
    """Kinetic release of attached pathogen back into the mobile phase.

    The rate is ``detachment_coefficient * attached_pathogen``. Each unit of rate
    removes one unit of ``attached_pathogen`` and adds ``bulk_density / porosity``
    units of ``mobile_pathogen`` — the reciprocal of the Attachment factor.
    """

    detachment_coefficient: jax.Array  # first-order detachment rate constant

    def rate(self, time, state, system):
        return self.detachment_coefficient * state.attached_pathogen

    def stoichiometry(self, time, state, system):
        # bulk_density / porosity maps an attached-concentration change back onto
        # the pore-water concentration (presumably per solid mass -> per water
        # volume; confirm units).
        rho_b = system.parameters.bulk_density(system)
        water_per_solid = rho_b / system.porosity
        return {"mobile_pathogen": water_per_solid, "attached_pathogen": -1}

In [None]:
# Number of finite-volume cells along the column.
n_cells = 200

# Kinetic rate constants; presumably in 1/day, since the simulation clock
# appears to run in days — confirm.
attachment = Attachment(attachment_coefficient=0.02)
detachment = Detachment(detachment_coefficient=1e-2)
mobile_decay = MobilePathogenDecay(decay_coefficient=1e-5)
attached_decay = AttachedPathogenDecay(decay_coefficient=1e-4)

# Order kept as: attachment, detachment, mobile decay, attached decay.
reactions = [attachment, detachment, mobile_decay, attached_decay]

In [None]:
@user_system_parameters
class SystemParameters:
    """User parameters reachable from reactions as ``system.parameters``."""

    # Grain (solid) density; the notebook passes a value in g/cm3.
    solid_density: jax.Array

    def bulk_density(self, system):
        """Return the bulk density ``(1 - porosity) * solid_density``.

        ``system.porosity`` is read from the system passed in, so the result
        follows whatever porosity field that system holds.
        """
        solid_fraction = 1 - system.porosity
        return solid_fraction * self.solid_density

In [None]:
# Grain density, wrapped as a spatially constant parameter.
solid_density = SpatiallyConst(jnp.array(2.65))  # g/cm3
system_parameters = SystemParameters(solid_density=solid_density)

In [None]:
# Model setup: grid, transport operators, boundary conditions and the system.
# Units are presumably days and metres (diffusion below is 1e-9 m2/s expressed
# per day) — TODO confirm against the reactix documentation.

# Enable float64 before creating arrays here; setting it again is harmless.
jax.config.update("jax_enable_x64", True)

POROSITY = 0.3
PORE_VELOCITY = 1 / 365              # presumably m/day
PORE_DIFFUSION = 1e-9 * 3600 * 24    # 1e-9 m2/s converted to m2/day
INLET_CONCENTRATION = 10.0

interface_areas = jnp.ones(n_cells + 1)
# Alternative geometry: larger interface area in the second half of the column.
# interface_areas = interface_areas.at[100:].set(2)
cells = Cells.equally_spaced(10, n_cells, interface_area=interface_areas)

# Same molecular pore diffusion for every species (JAX arrays are immutable,
# so sharing one array is safe).
pore_diffusion = jnp.array(PORE_DIFFUSION)
dispersion = Dispersion.build(
    cells=cells,
    dispersivity=jnp.array(0.1),
    pore_diffusion=Species(
        tracer=pore_diffusion,
        mobile_pathogen=pore_diffusion,
        attached_pathogen=pore_diffusion,
    ),
)
advection = Advection.build(
    limiter_type="minmod",
)


def _fixed_concentration_bc(boundary, species_name, concentration):
    """Build a boundary holding ``species_name`` at ``concentration`` on ``boundary``.

    The selector/concentration lambdas close over this function's arguments,
    so each boundary condition keeps its own values (no late-binding surprises).
    """
    return FixedConcentrationBoundary(
        boundary=boundary,
        species_selector=lambda s: getattr(s, species_name),
        fixed_concentration=lambda t: jnp.array(concentration),
    )


# Inlet concentration on the left (presumably the inflow side for positive
# discharge), zero on the right, for both mobile species. attached_pathogen gets
# no boundary condition here (it is flagged immobile).
# NOTE(review): a fixed zero concentration at an outflow boundary can act as an
# artificial dispersive sink; consider an outflow/zero-gradient condition if
# reactix offers one.
bcs = [
    _fixed_concentration_bc(boundary, species_name, concentration)
    for species_name in ("tracer", "mobile_pathogen")
    for boundary, concentration in (("left", INLET_CONCENTRATION), ("right", 0.0))
]

porosity = jnp.full(n_cells, POROSITY)
system = System.build(
    porosity=porosity,
    # Alternative: oscillating flow, e.g.
    # velocity=lambda t: jnp.array(1 / 365) * jnp.sin(np.pi * 2 * 1 / 5000 * t),
    # Discharge = pore velocity * porosity; both come from the constants above
    # so changing POROSITY cannot desynchronise the flux.
    discharge=lambda t: jnp.array(PORE_VELOCITY) * POROSITY,
    cells=cells,
    advection=advection,
    dispersion=dispersion,
    species_is_mobile=species_is_mobile,
    bcs=bcs,
    reactions=reactions,
    parameters=system_parameters,
)

In [None]:
# Simulation horizon and output grid. T_MAX is shared by both calls so the
# saved time points can never run past the solver's end time.
T_MAX = 8000
N_OUTPUT_TIMES = 123
t_points = jnp.linspace(0, T_MAX, N_OUTPUT_TIMES)
solver = make_solver(t_max=T_MAX, t_points=t_points, rtol=1e-3, atol=1e-3)

# Initial condition: zero concentration of every species in every cell.
val0 = jnp.zeros(cells.n_cells)
# Alternative initial condition: a slug of concentration 10 in cells 10-19.
# val0 = val0.at[slice(10, 20)].set(10.0)

state = Species(
    tracer=val0,
    mobile_pathogen=val0,
    attached_pathogen=val0,
)


In [None]:
# Run the simulation from the initial state. The first call may also include
# compilation time. To benchmark, use:  %timeit solver(state, system)
solution = solver(state, system)

In [None]:
# Tracer profiles along the column at every 10th saved time
# (solution.ys.tracer is indexed [time, cell]; transpose so each column is a profile).
fig, ax = plt.subplots()
ax.plot(cells.centers, solution.ys.tracer.T[:, ::10])
ax.set(
    xlabel="position",
    ylabel="tracer concentration",
    title="Tracer profiles (every 10th output time)",
)
plt.show()

In [None]:
# Attached-pathogen profiles along the column at every 10th saved time
# (solution.ys.attached_pathogen is indexed [time, cell]; transpose so each
# column is a profile).
fig, ax = plt.subplots()
ax.plot(cells.centers, solution.ys.attached_pathogen.T[:, ::10])
ax.set(
    xlabel="position",
    ylabel="attached pathogen concentration",
    title="Attached pathogen profiles (every 10th output time)",
)
plt.show()

In [None]:
%matplotlib widget
from matplotlib import animation, collections

collections.Collection()
fig, ax = plt.subplots()

artists = []
for data in zip(solution.ys.tracer, solution.ys.mobile_pathogen, solution.ys.attached_pathogen):
    containers = [ax.plot(cells.centers, y, color=f"C{i}") for i, y in enumerate(data)]
    artist = []
    for container in containers:
        artist.extend(container)
    artists.append(artist)


ani = animation.ArtistAnimation(fig=fig, artists=artists, interval=40)
plt.show()