# Example 2: Simple 2D cell signaling model – showcasing a mass conservation issue

We model a reaction between the cell interior and cell membrane in a 2D geometry:
- Cyto - 2D cell "volume"
- PM - 1D cell boundary (represents plasma membrane)

Model from [Rangamani et al, 2013, Cell](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3874130/). A cytosolic species, "A", reacts with a species on the PM, "X", to form a new species on the PM, "B" (see full equations and diagram in main example 2).

Here, we show the enforcement of mass conservation in the case of a very small diffusion coefficient for the cytosolic species "A". Each forward reaction event consumes one molecule of A and produces one molecule of B, and each reverse event does the opposite. Therefore the number of molecules of A plus the number of molecules of B should remain constant. Without fixes to the original SMART code, we currently violate this conservation between species A and B. The fix requires us to interpolate "A" onto the plasma membrane domain when expressing certain terms in the variational form for FEniCS (updated on every Newton iteration).

Imports and logger initialization:

In [None]:
from matplotlib import pyplot as plt
import dolfin as d
import sympy as sym
import numpy as np
import pathlib
import logging
import gmsh  # must be imported before pyvista if dolfin is imported first
import time

from smart import config, common, mesh, model, mesh_tools, visualization
from smart.units import unit
from smart.model_assembly import (
    Compartment,
    Parameter,
    Reaction,
    Species,
    SpeciesContainer,
    ParameterContainer,
    CompartmentContainer,
    ReactionContainer,
)
from matplotlib import pyplot as plt
import matplotlib.image as mpimg

# Grab the logger named "smart" and let INFO-level (and higher) messages through.
logger = logging.getLogger(name="smart")
logger.setLevel(level=logging.INFO)

First, we define the various units for use in the model.

In [None]:
# Base units taken from smart.units.unit.
um, molecule, sec = unit.um, unit.molecule, unit.sec
dimensionless = unit.dimensionless

# Derived units used when declaring species and parameters.
D_unit = um**2 / sec              # diffusion coefficients
surf_unit = molecule / um**2      # per-area density (species in the 2D "Cyto" compartment)
flux_unit = molecule / (um * sec)  # per-length rate; not referenced in this section
edge_unit = molecule / um         # per-length density (species on the 1D "PM" compartment)

Next we generate the model by assembling the compartment, species, parameter, and reaction containers (see Example 1 or API documentation for more details).

In [None]:
# =============================================================================================
# Compartments
# =============================================================================================
# Positional arguments: name, topological dimensionality, length-scale unit, marker value.
# The 2D cytosol uses marker 1 and the 1D plasma membrane uses marker 3.
Cyto, PM = Compartment("Cyto", 2, um, 1), Compartment("PM", 1, um, 3)
cc = CompartmentContainer()
cc.add([Cyto, PM])

# =============================================================================================
# Species
# =============================================================================================
# Positional arguments: name, initial concentration, concentration unit,
# diffusion coefficient, diffusion unit, compartment name.
# "A" lives in the cytosol with a deliberately small diffusion coefficient (0.01);
# "X" and "B" live on the membrane, and "B" starts at zero.
A = Species("A", 1.0, surf_unit, 0.01, D_unit, "Cyto")
X = Species("X", 1.0, edge_unit, 1.0, D_unit, "PM")
B = Species("B", 0.0, edge_unit, 1.0, D_unit, "PM")
sc = SpeciesContainer()
sc.add([A, X, B])

# =============================================================================================
# Parameters and Reactions
# =============================================================================================

# Cyto-PM reaction A + X -> B. param_map binds the reaction's "on"/"off" slots to
# kon/koff (presumably forward/reverse mass-action rates -- confirm in SMART docs).
kon = Parameter("kon", 1.0, 1 / (surf_unit * sec))
koff = Parameter("koff", 1.0, 1 / sec)
r1 = Reaction(
    "r1",
    ["A", "X"],
    ["B"],
    param_map=dict(on="kon", off="koff"),
    species_map=dict(A="A", X="X", B="B"),
)

pc = ParameterContainer()
pc.add([kon, koff])
rc = ReactionContainer()
rc.add([r1])

Now we create a circular mesh (mesh built using gmsh in `smart.mesh_tools`), along with marker functions `mf2` and `mf1`.

In [None]:
# Build a circular mesh: an "ellipse" whose two radii are both curRad.
# The tag values (1, 3) equal the marker values of the Cyto and PM compartments.
hEdge, curRad = 0.05, 1.0
surf_tag, edge_tag = 1, 3
ellipse_mesh, mf1, mf2 = mesh_tools.create_ellipses(
    curRad,
    curRad,
    hEdge=hEdge,
    outer_tag=surf_tag,
    outer_marker=edge_tag,
)
visualization.plot_dolfin_mesh(ellipse_mesh, mf2, view_xy=True)

Write mesh and meshfunctions to file, then create `mesh.ParentMesh` object.

In [None]:
# Save the mesh and both marker functions to an HDF5 file, then build
# SMART's parent mesh from that file.
mesh_folder = pathlib.Path("ellipse_mesh_AR1")
mesh_folder.mkdir(exist_ok=True)
mesh_file = mesh_folder.joinpath("ellipse_mesh.h5")
mesh_tools.write_mesh(ellipse_mesh, mf1, mf2, mesh_file)

parent_mesh = mesh.ParentMesh(
    name="parent_mesh",
    mesh_filetype="hdf5",
    mesh_filename=str(mesh_file),
)

Solve the system with `enforce_mass_conservation` on vs. off and plot the results for the total number of molecules A + molecules B over time.

In [None]:
# Set loglevel to warning in order not to pollute notebook output
logger.setLevel(logging.WARNING)
enforce_mass_conservation = [True, False]
for enforce in enforce_mass_conservation:
    # Fresh configuration per run; only the mass-conservation flag differs.
    config_cur = config.Config()
    config_cur.flags.update({
        "allow_unused_components": True,
        "enforce_mass_conservation": enforce})
    config_cur.solver.update(
        {
            "final_t": 5.0,
            "initial_dt": 0.05,
            "time_precision": 6,
        }
    )
    model_cur = model.Model(pc, sc, cc, rc, config_cur, parent_mesh)
    model_cur.couple_odes = True

    model_cur.initialize()
    model_cur.to_pickle('model_cur.pkl')

    # One XDMF file per species, in a results folder named after the flag value.
    results = dict()
    result_folder = pathlib.Path(
        f"resultsConservTest_MassConserv{enforce}")
    result_folder.mkdir(exist_ok=True)
    for species_name, species in model_cur.sc.items:
        results[species_name] = d.XDMFFile(
            model_cur.mpi_comm_world, str(result_folder / f"{species_name}.xdmf")
        )
        results[species_name].parameters["flush_output"] = True
        results[species_name].write(model_cur.sc[species_name].u["u"], model_cur.t)

    tvec = [0]
    cytoMesh = model_cur.cc["Cyto"].dolfin_mesh
    dx_cyto = d.Measure("dx", domain=cytoMesh)
    volume = d.assemble(1.0*dx_cyto)  # measure of the 2D Cyto domain
    dx_pm = d.Measure("dx", domain=model_cur.cc["PM"].dolfin_mesh)
    sa = d.assemble(1.0*dx_pm)  # measure of the 1D PM domain

    # Total A + B molecules at t=0: (uniform initial density) x (domain measure).
    Atot_vec = [sc["A"].initial_condition * volume + sc["B"].initial_condition * sa]

    start_time = time.perf_counter()
    try:
        while True:
            # Solve the system
            model_cur.monolithic_solve()
            # Save results for post processing
            for species_name, species in model_cur.sc.items:
                results[species_name].write(model_cur.sc[species_name].u["u"], model_cur.t)
            # Integrate the solved fields of *this* model's species container so the
            # totals track what monolithic_solve actually updated.
            int_val_A = d.assemble(model_cur.sc["A"].u["u"]*dx_cyto)
            int_val_B = d.assemble(model_cur.sc["B"].u["u"]*dx_pm)
            Atot_vec.append(int_val_A + int_val_B)
            tvec.append(model_cur.t)
            # End if we've passed the final time
            if model_cur.t >= model_cur.final_t:
                break
        end_time = time.perf_counter()
    finally:
        # Release the XDMF file handles, also when the solver raises.
        for xdmf_file in results.values():
            xdmf_file.close()

    print(f"Mass conservation = {enforce} time elapsed: {end_time - start_time}")
    # Plot total A + B as a percentage of its initial value (100% = conserved).
    plt.plot(tvec, 100*np.array(Atot_vec)/Atot_vec[0],
             label=f"Mass conservation = {enforce}")
plt.xlabel('Time (s)')
plt.ylabel('Normalized molecule count (%)')
plt.legend()