# Calcium dynamics in a dendritic spine

Here, we implement the model presented in [Bell et al. 2019, Journal of General Physiology](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6683673/), in which calcium enters a dendritic spine through the plasma membrane and is also taken up into a specialized type of endoplasmic reticulum (ER) known as the spine apparatus (SA).

The geometry in this model is divided into four domains, two volumes and two surfaces:
- Plasma membrane (PM), a surface
- Cytosol, a volume
- Spine apparatus membrane (SAm), a surface
- Spine apparatus lumen (SA), the volume inside the SA

In this model, the plasma membrane domain is further divided into two regions:
- Postsynaptic density (PSD): a region rich in receptors and channels associated with neuronal activation
- "Passive" PM: the remainder of the membrane, which still contains other channels and receptors; for ease of use, we label this region PM

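This model has four dynamic species:
- $\text{Ca}^{2+}$ in the cytosol
- fixed buffers ($B_f$) attached to the plasma membrane
- mobile buffers ($B_m$) throughout the cytosol
- $\text{Ca}^{2+}$ in the spine apparatus

Two additional stationary species on the PM, `NMDAR` and `VSCC`, only mark where calcium influx can occur.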

In [None]:
import dolfin as d
import sympy as sym
import numpy as np
import logging
import argparse

from smart import config, mesh, model, mesh_tools, visualization
from smart.units import unit
from smart.model_assembly import (
    Compartment,
    Parameter,
    Reaction,
    Species,
    SpeciesContainer,
    CompartmentContainer,
    sbmodel_from_locals,
)

from matplotlib import pyplot as plt
from ca2_parser_args import add_run_dendritic_spine_arguments

logger = logging.getLogger("smart")
logger.setLevel(logging.DEBUG)

In [None]:
parser = argparse.ArgumentParser()
add_run_dendritic_spine_arguments(parser)
args = vars(parser.parse_args())

First, we define the units used for the model inputs.

In [None]:
# Aliases - base units
uM = unit.uM
um = unit.um
molecule = unit.molecule
sec = unit.sec
dimensionless = unit.dimensionless
# Aliases - units used in model
D_unit = um**2 / sec
flux_unit = uM * um / sec
vol_unit = uM
surf_unit = molecule / um**2
# electrical units
voltage_unit = unit.millivolt
current_unit = unit.picoampere
conductance_unit = unit.picosiemens
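The model mixes surface fluxes in molecules/μm² (NMDAR and VSCC influx) with fluxes in μM·μm/s (`flux_unit`). SMART is expected to reconcile these through its unit registry; the standalone sketch below just does the conversion by hand to show the scale involved. It is independent of SMART, and `AVOGADRO` and `UM3_PER_LITER` are names introduced only for this illustration.

```python
# Hand conversion between concentration and molecule-count units (sketch, independent of SMART)
AVOGADRO = 6.022e23      # molecules per mole
UM3_PER_LITER = 1e15     # 1 L = 1e-3 m^3 = 1e15 um^3


def uM_to_molecules_per_um3(conc_uM):
    """Convert a concentration in micromolar to molecules per cubic micrometer."""
    mol_per_liter = conc_uM * 1e-6
    return mol_per_liter * AVOGADRO / UM3_PER_LITER


per_uM = uM_to_molecules_per_um3(1.0)
# Consequently 1 uM*um (a volume concentration times a length) equals per_uM molecules/um^2
print(f"1 uM = {per_uM:.1f} molecules/um^3, so 1 uM*um = {per_uM:.1f} molecules/um^2")
```

For example, the initial cytosolic calcium of 0.1 μM corresponds to roughly 60 ions per μm³.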

## Create and load the mesh

Here, we consider an "ellipsoid-in-an-ellipsoid" geometry. The inner ellipsoid represents the SA and the volume between the SA boundary and the boundary of the outer ellipsoid represents the cytosol.
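The code cell below stretches a sphere of radius `spine_rad` into an ellipsoid with axis ratios 27:15:7. The factor `ar_1` is chosen so that `ar_1 * ar_2 * ar_3 = 1`, which keeps the ellipsoid volume $\frac{4}{3}\pi abc$ equal to that of the original sphere. A minimal standalone check of that reasoning, reusing the same construction:

```python
import math

spine_rad = 0.237          # um, radius of the reference sphere (as in the notebook)
ar_list = [27, 15, 7]      # desired axis ratios

# Same construction as in the notebook cell
ar_1 = (1 / ((ar_list[1] / ar_list[0]) * (ar_list[2] / ar_list[0]))) ** (1 / 3)
ar_2 = ar_1 * ar_list[1] / ar_list[0]
ar_3 = ar_1 * ar_list[2] / ar_list[0]

semi_axes = [spine_rad * ar for ar in (ar_1, ar_2, ar_3)]
ellipsoid_volume = 4 / 3 * math.pi * math.prod(semi_axes)
sphere_volume = 4 / 3 * math.pi * spine_rad**3

print("semi-axes (um):", [round(s, 4) for s in semi_axes])
print("volume ratio ellipsoid/sphere:", ellipsoid_volume / sphere_volume)
```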

In [None]:
spine_rad = 0.237
SA_rad = 0.08
ar_list = [27, 15, 7]
ar_1 = (1 / ((ar_list[1] / ar_list[0]) * (ar_list[2] / ar_list[0]))) ** (1 / 3)
ar_2 = ar_1 * ar_list[1] / ar_list[0]
ar_3 = ar_1 * ar_list[2] / ar_list[0]
z_PSD = 0.8 * spine_rad * ar_3
# spine_mesh, facet_markers, cell_markers = mesh_tools.create_ellipsoids(
#     (spine_rad * ar_1, spine_rad * ar_2, spine_rad * ar_3),
#     (SA_rad * ar_1, SA_rad * ar_2, SA_rad * ar_3),
#     hEdge=0.02,
# )
# Load mesh
parent_mesh = mesh.ParentMesh(
    mesh_filename=str(args["mesh_file"]),
    mesh_filetype="hdf5",
    name="parent_mesh",
)
geo = mesh_tools.load_mesh(filename=args["mesh_file"], mesh=parent_mesh.dolfin_mesh)
# facet_array = geo.mf_facet.array()[:]
# for i in range(len(facet_array)):
#     if (
#         facet_array[i] == 11
#     ):  # this indicates PSD; in this case, set to 10 to indicate it is a part of the PM
#         facet_array[i] = 10
mesh = geo.mesh

mf_cell = geo.mf_cell
mf_facet = geo.mf_facet

if args["num_refinements"] > 0:
    print(
        f"Original mesh has {mesh.num_cells()} cells, "
        f"{mesh.num_facets()} facets and "
        f"{mesh.num_vertices()} vertices"
    )
    d.parameters["refinement_algorithm"] = "plaza_with_parent_facets"
    for _ in range(args["num_refinements"]):
        mesh = d.adapt(mesh)
        mf_cell = d.adapt(mf_cell, mesh)
        mf_facet = d.adapt(mf_facet, mesh)
    print(
        f"Refined mesh has {mesh.num_cells()} cells, "
        f"{mesh.num_facets()} facets and "
        f"{mesh.num_vertices()} vertices"
    )
    # Note: parent_mesh (passed to the model later) was loaded before this
    # refinement and is not updated here; the refined mesh is only used for
    # A_PSD and the plot in this cell.

integrateDomain = mf_facet
ds = d.Measure("ds", domain=mesh, subdomain_data=integrateDomain)
A_PSD = d.assemble(1.0 * ds(11))
visualization.plot_dolfin_mesh(mesh, mf_cell)

## Model generation

For each step of model generation, refer to Example 3 or the API documentation for further details.

We first define the compartments and the compartment container. Note that we can specify nonadjacency for surfaces in the model; this is not required, but it can speed up the solution process.

In [None]:
Cyto = Compartment("Cyto", 3, um, 1)
PM = Compartment("PM", 2, um, 10)
SA = Compartment("SA", 3, um, 2)
SAm = Compartment("SAm", 2, um, 12)
PM.specify_nonadjacency(["SAm", "SA"])
SAm.specify_nonadjacency(["PM"])

cc = CompartmentContainer()
cc.add([Cyto, PM, SA, SAm])

Define species and place them in a species container. Note that `NMDAR` and `VSCC` are stationary PM surface variables (zero diffusion coefficient); they only serve to restrict NMDAR calcium influx to the PSD and VSCC influx to the spine (not the dendritic shaft).
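Both localization species in the cell below use the expression `(1 + sign(z - z0))/2`, which acts as a Heaviside step in the z coordinate: it evaluates to 1 above the threshold `z0`, to 0 below it, and to 1/2 exactly on it. A plain-Python illustration; here `sign` is written by hand to follow the convention `sign(0) = 0` (as SymPy's `sign` does), and `step_indicator` and `z_psd_check` are names introduced only for this sketch.

```python
def sign(x):
    """Sign function with sign(0) = 0."""
    return (x > 0) - (x < 0)


def step_indicator(z, z_thresh):
    """Evaluate (1 + sign(z - z_thresh)) / 2, the localization expression used for NMDAR and VSCC."""
    return (1 + sign(z - z_thresh)) / 2


# Recompute z_PSD with the notebook's formula: 0.8 * spine_rad * ar_3
spine_rad = 0.237
ar_3 = (27 * 27 / (15 * 7)) ** (1 / 3) * 7 / 27
z_psd_check = 0.8 * spine_rad * ar_3

for z_val in (0.0, z_psd_check, 0.11):
    print(f"z = {z_val:.3f} um -> NMDAR indicator = {step_indicator(z_val, z_psd_check)}")
```

With `VSCC_zThresh = -10`, every PM point with z > -10 μm gets a VSCC indicator of 1.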

In [None]:
Ca = Species("Ca", 0.1, vol_unit, 220.0, D_unit, "Cyto")
n_PMr = 0.1011  # vol to surf area ratio for a realistic dendritic spine
NMDAR_loc = f"(1 + sign(z - {z_PSD}))/2"
NMDAR = Species(
    "NMDAR", NMDAR_loc, dimensionless, 0.0, D_unit, "PM"
)  # specify species to localize NMDAR calcium influx to PSD
VSCC_zThresh = -10  # far below the spine, so VSCCs cover the whole PM (alternatives: -0.25 for a single spine, 0.3 for two spines)
VSCC_loc = f"(1 + sign(z - {VSCC_zThresh}))/2"
VSCC = Species(
    "VSCC", VSCC_loc, dimensionless, 0.0, D_unit, "PM"
)  # specify species to localize VSCC calcium influx to spine body and neck

Bf = Species("Bf", 78.7 * n_PMr, vol_unit * um, 0.0, D_unit, "PM")
Bm = Species("Bm", 20.0, vol_unit, 20.0, D_unit, "Cyto")
CaSA = Species(
    "CaSA", 60.0, vol_unit, 6.27, D_unit, "SA"
)  # effective D due to buffering
sc = SpeciesContainer()
sc.add([Ca, NMDAR, VSCC, Bf, Bm, CaSA])

Define parameters and reactions at the plasma membrane:
* a1: influx of calcium through N-methyl-D-aspartate receptors (NMDARs), localized to the PSD
* a2: calcium entry through voltage-sensitive calcium channels (VSCCs) throughout the PM
* a3: calcium efflux through the plasma membrane Ca²⁺-ATPase (PMCA), throughout the PM
* a4: calcium efflux through the sodium-calcium exchanger (NCX), throughout the PM
* a5: calcium binding to immobilized buffers ($B_f$) on the PM

Calcium entry through NMDARs and VSCCs is given by time-dependent functions. Each depends on the membrane voltage over time, which is specified to match the expected dynamics of the back-propagating action potential (BPAP) and the excitatory postsynaptic potential (EPSP). Both contributions are switched on at $t = 0$ by the step function $H(t) = (1 + \text{sign}(t))/2$:

$$
\begin{aligned}
V_m(t) &= V_{rest} + BPAP(t) + EPSP(t)\\
&= V_{rest} + [BPAP]_{max} \left( I_{bsf} e^{-(t-t_{delay,bp})/t_{bsf}} + I_{bss} e^{-(t-t_{delay,bp})/t_{bss}}\right) H(t) + s_{term} \left( e^{-(t-t_{delay})/t_{ep1}} - e^{-(t-t_{delay})/t_{ep2}}\right) H(t)
\end{aligned}
$$
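A stdlib-only sketch that evaluates the same voltage waveform, with parameter values copied from the next code cell, can help check its shape before running the full model. The constant and function names here are introduced only for this illustration.

```python
import math

# Parameter values as defined in the notebook (voltages in mV, times in s)
V_REST = -65.0
BP_MAX = 38.0
I_BSF, T_BSF, I_BSS, T_BSS = 0.75, 0.003, 0.25, 0.025
T_DELAY_BP, T_DELAY = 0.002, 0.0
S_TERM = 25.0
T_EP1, T_EP2 = 0.05, 0.005


def heaviside(t):
    """H(t) = (1 + sign(t)) / 2, with H(0) = 1/2."""
    return (1 + ((t > 0) - (t < 0))) / 2


def membrane_voltage(t):
    """Vm(t) = Vrest + BPAP(t) + EPSP(t), both stimuli switched on at t = 0."""
    bpap = BP_MAX * (
        I_BSF * math.exp(-(t - T_DELAY_BP) / T_BSF)
        + I_BSS * math.exp(-(t - T_DELAY_BP) / T_BSS)
    )
    epsp = S_TERM * (math.exp(-(t - T_DELAY) / T_EP1) - math.exp(-(t - T_DELAY) / T_EP2))
    return V_REST + (bpap + epsp) * heaviside(t)


for t_s in (-0.001, 0.001, 0.01, 0.05):
    print(f"t = {t_s * 1e3:5.1f} ms  Vm = {membrane_voltage(t_s):7.2f} mV")
```

With these values the voltage rises from -65 mV to roughly -11 mV at t = 1 ms and then relaxes back toward rest over tens of milliseconds.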


In [None]:
# Both NMDAR and VSCC fluxes depend on the voltage over time
Vrest_expr = -65
Vrest = Parameter("Vrest", Vrest_expr, voltage_unit)
bpmax = 38
Ibsf, tbsf, Ibss, tbss = 0.75, 0.003, 0.25, 0.025
tdelaybp, tdelay = 0.002, 0.0
sterm = 25
tep1, tep2 = 0.05, 0.005
t = sym.symbols("t")
BPAP = (
    bpmax
    * (Ibsf * sym.exp(-(t - tdelaybp) / tbsf) + Ibss * sym.exp(-(t - tdelaybp) / tbss))
    * (1 + sym.sign(t))
    / 2
)
EPSP = (
    sterm
    * (sym.exp(-(t - tdelay) / tep1) - sym.exp(-(t - tdelay) / tep2))
    * (1 + sym.sign(t))
    / 2
)
Vm_expr = Vrest_expr + BPAP + EPSP
Vm = Parameter.from_expression(
    "Vm",
    Vm_expr,
    voltage_unit,
    use_preintegration=False,
)
# Define known constants
N_A = 6.022e23  # molecules per mole
F = 96485.332  # Faraday's constant (Coulombs per mole)
Q = 1.602e-19  # Coulombs per elementary charge

n_PMr = Parameter("n_PMr", 0.1011, um)
# NMDAR calcium influx
P0 = 0.5
CaEC = 2  # mM
h = 11.3  # pS/mM
G0 = 65.6  # pS
r = 0.5
ginf = 15.2  # pS
convert_factor = 1e-15  # A/mV per pS
zeta_i = (G0 + r * CaEC * h) / (
    1 + r * CaEC * h / ginf
)  # single channel conductance in pS
G_NMDARVal = convert_factor * zeta_i / (2 * Q)
G_NMDAR = Parameter("G_NMDAR", G_NMDARVal, molecule / (voltage_unit * sec))
If, tau_f, Is, tau_s = 0.5, 0.05, 0.5, 0.2
Km = 0.092  # (1/mV)
Mg, MgScale = 1, 3.57  # mM
B_V = 1 / (1 + sym.exp(-Km * Vm_expr * Mg / MgScale))
# Time course of NMDAR opening (no (1 + sign(t))/2 step factor is applied here)
gamma_i_scale = P0 * (If * sym.exp(-t / tau_f) + Is * sym.exp(-t / tau_s)) * B_V
beta_NMDAR = 85
# A_PSD = 2*np.pi*spine_rad*(spine_rad - z_PSD)
J0_NMDAR_expr = gamma_i_scale / (beta_NMDAR * A_PSD)
J0_NMDAR = Parameter.from_expression(
    "J0_NMDAR",
    J0_NMDAR_expr,
    1 / um**2,
    use_preintegration=False,
)
a1 = Reaction(
    "a1",
    [],
    ["Ca"],
    species_map={"NMDAR": "NMDAR"},
    param_map={"J0": "J0_NMDAR", "G_NMDAR": "G_NMDAR", "Vm": "Vm", "Vrest": "Vrest"},
    eqn_f_str="J0*NMDAR*G_NMDAR*(Vm - Vrest)",
    explicit_restriction_to_domain="PM",
)
# VSCC calcium influx
gamma = 3.72  # pS
k_Ca = (
    -convert_factor
    * gamma
    * Vm_expr
    * N_A
    * (0.393 - sym.exp(-Vm_expr / 80.36))
    / (2 * F * (1 - sym.exp(Vm_expr / 80.36)))
)
alpha4, beta4 = 34700, 3680
VSCC_biexp = (sym.exp(-alpha4 * t) - sym.exp(-beta4 * t)) * (1 + sym.sign(t)) / 2
VSCCNum = 2  # molecules/um^2
J_VSCC = Parameter.from_expression(
    "J_VSCC",
    VSCCNum * k_Ca * VSCC_biexp,
    molecule / (um**2 * sec),
    use_preintegration=False,
)
a2 = Reaction(
    "a2",
    [],
    ["Ca"],
    species_map={"VSCC": "VSCC"},
    param_map={"J": "J_VSCC"},
    eqn_f_str="J*VSCC",
    explicit_restriction_to_domain="PM",
)
# PMCA
Prtote = Parameter("Prtote", 191, vol_unit)
Kme = Parameter("Kme", 2.43, vol_unit)
Kmx = Parameter("Kmx", 0.139, vol_unit)
Vmax_lr23 = Parameter("Vmax_lr23", 0.113, vol_unit / sec)
Km_lr23 = Parameter("Km_lr23", 0.442, vol_unit)
Vmax_hr23 = Parameter("Vmax_hr23", 0.59, vol_unit / sec)
Km_hr23 = Parameter("Km_hr23", 0.442, vol_unit)
beta_PMCA = 100
beta_i_str = (
    "(1 + Prtote*Kme/(Kme+c)**2 + Prtote*Kmx/(Kmx+c)**2)**(-1)"  # buffering term
)
PMCA_str = "Vmax_lr23*c**2/(Km_lr23**2 + c**2) + Vmax_hr23*c**5/(Km_hr23**5 + c**5)"
a3 = Reaction(
    "a3",
    ["Ca"],
    [],
    {
        "Prtote": "Prtote",
        "Kme": "Kme",
        "Kmx": "Kmx",
        "Vmax_lr23": "Vmax_lr23",
        "Km_lr23": "Km_lr23",
        "Vmax_hr23": "Vmax_hr23",
        "Km_hr23": "Km_hr23",
        "n_PMr": "n_PMr",
    },
    {"c": "Ca"},
    eqn_f_str=f"{beta_PMCA}*({beta_i_str})*({PMCA_str})*n_PMr",
    explicit_restriction_to_domain="PM",
)
# NCX
Vmax_r22 = Parameter("Vmax_r22", 0.1, vol_unit / sec)
Km_r22 = Parameter("Km_r22", 1, vol_unit)  # uM
beta_NCX = 1000
NCX_str = "Vmax_r22*c/(Km_r22 + c)"
a4 = Reaction(
    "a4",
    ["Ca"],
    [],
    {
        "Prtote": "Prtote",
        "Kme": "Kme",
        "Kmx": "Kmx",
        "Vmax_r22": "Vmax_r22",
        "Km_r22": "Km_r22",
        "n_PMr": "n_PMr",
    },
    {"c": "Ca"},
    eqn_f_str=f"{beta_NCX}*({beta_i_str})*({NCX_str})*n_PMr",
    explicit_restriction_to_domain="PM",
)
# Immobilized buffers
kBf_on = Parameter("kBf_on", 1, 1 / (uM * sec))
kBf_off = Parameter("kBf_off", 2, 1 / sec)
Bf_tot = Parameter("Bf_tot", 78.7 * n_PMr.value, vol_unit * um)
a5 = Reaction(
    "a5",
    ["Ca", "Bf"],
    [],
    {"kon": "kBf_on", "koff": "kBf_off", "Bf_tot": "Bf_tot"},
    eqn_f_str="kon*Ca*Bf - koff*(Bf_tot - Bf)",
    explicit_restriction_to_domain="PM",
)
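The string `beta_i_str`, shared by the PMCA, NCX and SERCA reactions, has the form of a rapid-buffering factor for two cytosolic buffers with total concentration `Prtote` and dissociation constants `Kme` and `Kmx`:
$\beta_i = \left(1 + \frac{P K_{me}}{(K_{me}+c)^2} + \frac{P K_{mx}}{(K_{mx}+c)^2}\right)^{-1}$.
Evaluating it at a few calcium levels shows how strongly it scales the pump fluxes. This is a plain-Python sketch with parameter values copied from the cell above; `buffering_factor` is a name introduced only here.

```python
def buffering_factor(c, prtote=191.0, kme=2.43, kmx=0.139):
    """Rapid-buffering factor beta_i(c) as written in beta_i_str (all concentrations in uM)."""
    return 1.0 / (1.0 + prtote * kme / (kme + c) ** 2 + prtote * kmx / (kmx + c) ** 2)


for c_val in (0.1, 1.0, 10.0):
    print(f"c = {c_val:5.1f} uM  ->  beta_i = {buffering_factor(c_val):.5f}")
```

At the resting level of 0.1 μM the factor is roughly 0.0019, so each pump flux is scaled down by about three orders of magnitude before the `beta_PMCA`, `beta_NCX` and `beta_SERCA` prefactors are applied.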

In [None]:
# Create a results folder
result_folder = args["outdir"]
result_folder.mkdir(parents=True, exist_ok=True)

We can plot the membrane voltage and the time-dependent stimuli for reactions a1 (NMDAR) and a2 (VSCC) using SymPy's `lambdify`.

In [None]:
%matplotlib inline
from sympy.utilities.lambdify import lambdify

Vm_func = lambdify(t, Vm_expr, "numpy")  # returns a numpy-ready function
tArray = np.linspace(-0.01, 0.05, 600)

fig, ax = plt.subplots(3, 1)
fig.set_size_inches(15, 10)
ax[0].plot(tArray, Vm_func(tArray))
ax[0].set(xlabel="Time (s)", ylabel="Membrane voltage\n(mV)")

# Multiplying the surface flux by A_PSD gives the total NMDAR influx over the PSD
NMDAR_func = lambdify(
    t, A_PSD * J0_NMDAR_expr * G_NMDARVal * (Vm_expr - Vrest_expr), "numpy"
)
ax[1].plot(tArray, NMDAR_func(tArray))
ax[1].set(xlabel="Time (s)", ylabel="NMDAR influx\n(molecules/s)")

VSCC_func = lambdify(t, 0.8057 * VSCCNum * k_Ca * VSCC_biexp, "numpy")
# VSCC_func = lambdify(t, VSCC_biexp, 'numpy')
ax[2].plot(tArray, VSCC_func(tArray))
ax[2].set(xlabel="Time (s)", ylabel="VSCC flux\n(molecules/(um^2*s))")
fig.savefig(result_folder / "time_dependent.png")
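The NMDAR flux plotted above is scaled by `G_NMDAR`, which converts a single-channel conductance into an ion flux: `zeta_i` (pS) is turned into A/mV with `convert_factor` and divided by the charge carried per Ca²⁺ ion, 2Q. Repeating that arithmetic by hand is a quick sanity check. The function and variable names below are introduced only for this sketch and deliberately avoid the notebook's own names, so running it does not overwrite the SMART `G_NMDAR` parameter.

```python
def nmdar_conductance(G0=65.6, r=0.5, ca_ec=2.0, h=11.3, ginf=15.2):
    """Return (zeta_i in pS, G in molecules/(mV*s)) following the notebook's formulas."""
    elementary_charge = 1.602e-19   # C per elementary charge
    pS_to_A_per_mV = 1e-15          # 1 pS = 1e-12 A/V = 1e-15 A/mV
    # Calcium-dependent single-channel conductance
    zeta = (G0 + r * ca_ec * h) / (1 + r * ca_ec * h / ginf)
    # Current per mV divided by the charge per Ca2+ ion gives ions per (mV * s)
    g = pS_to_A_per_mV * zeta / (2 * elementary_charge)
    return zeta, g


zeta_check, g_check = nmdar_conductance()
print(f"zeta_i  = {zeta_check:.2f} pS")
print(f"G_NMDAR = {g_check:.3e} molecules/(mV*s)")
```

Multiplying by `J0_NMDAR` (units 1/μm²) and the voltage difference `Vm - Vrest` then yields the surface flux in molecules/(μm²·s) used in reaction a1.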

Now we define the cytosolic reactions. Here, there is only one reaction: mobile buffer binding calcium (b1). We assume that the free and calcium-bound forms of the buffer have the same diffusion coefficient; together with a spatially uniform initial total, this means the total amount of buffer $B_{m,tot}$ does not change over time or space, so we can write $[CaB_m] = B_{m,tot} - [B_m]$.
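Writing out the conservation argument makes its assumptions explicit. With the b1 rate $R = k_{on}[Ca][B_m] - k_{off}[CaB_m]$ and a common diffusion coefficient $D_B$ for free and bound buffer:

$$
\begin{aligned}
\frac{\partial [B_m]}{\partial t} &= D_B \nabla^2 [B_m] - R\\
\frac{\partial [CaB_m]}{\partial t} &= D_B \nabla^2 [CaB_m] + R\\
\Rightarrow \quad \frac{\partial \left([B_m] + [CaB_m]\right)}{\partial t} &= D_B \nabla^2 \left([B_m] + [CaB_m]\right)
\end{aligned}
$$

The reaction terms cancel, so the total buffer obeys a pure diffusion equation. If the buffer cannot cross the boundary (no membrane reaction in this model involves `Bm`, so we assume no-flux conditions) and the total starts out uniform (here `Bm` starts at 20 μM, all unbound, matching `Bm_tot`), the total remains $B_{m,tot}$ everywhere.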

In [None]:
# calcium buffering in the cytosol
kBm_on = Parameter("kBm_on", 1, 1 / (uM * sec))
kBm_off = Parameter("kBm_off", 1, 1 / sec)
Bm_tot = Parameter("Bm_tot", 20, vol_unit)
b1 = Reaction(
    "b1",
    ["Ca", "Bm"],
    [],
    param_map={"kon": "kBm_on", "koff": "kBm_off", "Bm_tot": "Bm_tot"},
    eqn_f_str="kon*Ca*Bm - koff*(Bm_tot - Bm)",
)
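Setting the b1 rate to zero gives the equilibrium free buffer, $[B_m] = B_{m,tot}\,k_{off}/(k_{off} + k_{on}[Ca])$. With the parameter values above and a resting calcium of 0.1 μM, roughly 91% of the buffer is free at equilibrium. Note that the model starts with all buffer free (`Bm` = 20 μM), so the initial state is not exactly at this equilibrium. A plain-Python check (`free_buffer_at_equilibrium` is a name introduced only for this sketch):

```python
def free_buffer_at_equilibrium(ca, bm_tot=20.0, kon=1.0, koff=1.0):
    """Solve kon*Ca*Bm = koff*(Bm_tot - Bm) for Bm (concentrations in uM, rates in 1/(uM*s) and 1/s)."""
    return bm_tot * koff / (koff + kon * ca)


bm_free = free_buffer_at_equilibrium(0.1)
print(f"free Bm = {bm_free:.3f} uM, bound CaBm = {20.0 - bm_free:.3f} uM")
```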

Finally, we define the reactions associated with the spine apparatus:
* c1: calcium pumping into the SA through SERCA
* c2: calcium leak from the SA into the cytosol
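Ignoring the buffering factor and the `beta_SERCA` and `n_SAr` prefactors, the two SA rate laws reduce to a Hill-type pump term with coefficient 2 and a linear leak driven by the luminal-cytosolic gradient. A quick look at their shapes with the parameter values from the cell below; this is a plain-Python sketch, not SMART code, and the function names are introduced only here.

```python
def serca_activation(c, kp=0.2):
    """Fraction of Vmax_r19 reached at cytosolic Ca c (uM): c^2 / (KP^2 + c^2)."""
    return c**2 / (kp**2 + c**2)


def leak_rate(c_sa, c, k_leak=0.1608):
    """Leak term k_leak * (cSA - c) in uM/s before the n_SAr scaling; positive means SA -> cytosol."""
    return k_leak * (c_sa - c)


for c_val in (0.1, 0.2, 1.0):
    print(f"c = {c_val:.1f} uM: SERCA at {100 * serca_activation(c_val):.1f}% of Vmax")
print(f"leak at rest (cSA = 60 uM, c = 0.1 uM): {leak_rate(60.0, 0.1):.3f} uM/s")
```

The pump reaches half of its maximal rate at c = `KP_r19` = 0.2 μM, while the leak is nearly constant as long as cytosolic calcium stays far below the 60 μM luminal level.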

In [None]:
# SERCA flux
n_SAr = Parameter("n_SAr", 0.0113, um)
Vmax_r19 = Parameter("Vmax_r19", 113, vol_unit / sec)
KP_r19 = Parameter("KP_r19", 0.2, vol_unit)
beta_SERCA = 1000
VmaxSERCA_str = "Vmax_r19*c**2/(KP_r19**2 + c**2)"
c1 = Reaction(
    "c1",
    ["Ca"],
    ["CaSA"],
    {
        "Prtote": "Prtote",
        "Kme": "Kme",
        "Kmx": "Kmx",
        "Vmax_r19": "Vmax_r19",
        "KP_r19": "KP_r19",
        "n_SAr": "n_SAr",
    },
    {"c": "Ca"},
    eqn_f_str=f"{beta_SERCA}*({beta_i_str})*({VmaxSERCA_str})*n_SAr",
    explicit_restriction_to_domain="SAm",
)
# calcium leak out of the SA
k_leak = Parameter("k_leak", 0.1608, 1 / sec)
c2 = Reaction(
    "c2",
    ["CaSA"],
    ["Ca"],
    {"k_leak": "k_leak", "n_SAr": "n_SAr"},
    {"c": "Ca", "cSA": "CaSA"},
    eqn_f_str="k_leak*(cSA - c)*n_SAr",
    explicit_restriction_to_domain="SAm",
)

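xi = 0.0227272727  # (= 1/44) scaling factor to account for rapid buffering in the SA
# Apply the scaling to the CaSA side of both SA reactions, then re-run
# __post_init__() after changing flux_scaling
for c in [c1, c2]:
    c.flux_scaling = {"CaSA": xi}
    c.__post_init__()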

Now we gather all parameters, species, compartments, and reactions into their SMART containers.

In [None]:
# pc =ParameterContainer()
# pc.add([n_PMr])
# rc = ReactionContainer()
# rc.add([a1, a2, a3, a4, a5, b1, c1, c2])
pc, sc, cc, rc = sbmodel_from_locals(locals().values())
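`sbmodel_from_locals` is passed every object in the notebook namespace and returns parameter, species, compartment and reaction containers, so it presumably picks out the SMART objects by type. The sketch below illustrates that filter-by-type idea; `collect_by_type`, `FakeParameter` and `FakeSpecies` are hypothetical stand-ins and do not reflect SMART's actual implementation.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for SMART's model classes (illustration only)
@dataclass
class FakeParameter:
    name: str


@dataclass
class FakeSpecies:
    name: str


def collect_by_type(objects, kinds):
    """Hypothetical helper: group objects by their class, ignoring everything else."""
    groups = {kind: [] for kind in kinds}
    for obj in objects:
        for kind in kinds:
            if isinstance(obj, kind):
                groups[kind].append(obj)
    return groups


namespace = {"Ca": FakeSpecies("Ca"), "kBm_on": FakeParameter("kBm_on"), "xi": 0.0227, "uM": "unit"}
groups = collect_by_type(namespace.values(), (FakeParameter, FakeSpecies))
print([p.name for p in groups[FakeParameter]], [s.name for s in groups[FakeSpecies]])
```

Because `sc` is rebound by the `sbmodel_from_locals` call, the earlier `sc.add(...)` in the species cell is effectively superseded.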

Initialize model and solver.

In [None]:
configCur = config.Config()
configCur.flags.update({"allow_unused_components": True})
model_cur = model.Model(pc, sc, cc, rc, configCur, parent_mesh)
configCur.solver.update(
    {
        "final_t": 0.025,
        "initial_dt": args["time_step"],
        "time_precision": 8,
        "use_snes": True,
        # "print_assembly": False,
    }
)
import json
# Dump config to results folder
(result_folder / "config.json").write_text(
    json.dumps(
        {
            "solver": configCur.solver.__dict__,
            "reaction_database": configCur.reaction_database,
            "mesh_file": str(args["mesh_file"]),
            "outdir": str(args["outdir"]),
            "num_refinements": args["num_refinements"],
            "time_step": args["time_step"],
        }
    )
)


model_cur.initialize(initialize_solver=False)

# # set initial condition vector for NMDAR
# sp = model_cur.sc['NMDAR']
# u = model_cur.cc[sp.compartment_name].u["u"]
# indices = sp.dof_map
# uvec    = u.vector()
# values  = uvec.get_local()
# dof_coord = model_cur.cc[sp.compartment_name].V.tabulate_dof_coordinates().reshape((-1,3)) # coordinates for current compartment
# dof_coord = dof_coord[indices, :] # pull out coordinates for the current dof
# vertex_coords = [] # save all vertex coordinates that need to be corrected
# idx = 0
# for facet in d.facets(spine_mesh):
#     cur_marker = facet_markers_orig.array()[idx]
#     if cur_marker == 10:
#         for vertex in d.vertices(facet):
#             cur_vertex = list(vertex.point().array())
#             if len(vertex_coords)==0:
#                 vertex_coords.append(cur_vertex)
#             else:
#                 compare_vertices = np.sum((np.array(vertex_coords)-np.array(cur_vertex))**2,1)
#                 if not any(np.isclose(compare_vertices,0)): # unique point not seen yet
#                     vertex_coords.append(cur_vertex)
#     idx = idx + 1
# for i in range(len(dof_coord)):
#     compare_vertices = np.sum((np.array(vertex_coords)-dof_coord[i,:])**2,1)
#     if np.any(np.isclose(compare_vertices, 0)):
#         values[indices[i]] = 0
#     print(f'Completed value {i} of {len(dof_coord)}')

# uvec.set_local(values)
# uvec.apply("insert")
# nvec = model_cur.cc[sp.compartment_name].u["n"].vector()
# nvec.set_local(values)
# nvec.apply("insert")

model_cur.initialize_discrete_variational_problem_and_solver()

Initialize XDMF files for saving results and save the model information to a .pkl file. Then solve the system until `model_cur.t >= model_cur.final_t`, recording the volume-averaged cytosolic calcium at each time step.

In [None]:
# Write initial condition(s) to file
results = dict()
for species_name, species in model_cur.sc.items:
    results[species_name] = d.XDMFFile(
        model_cur.mpi_comm_world, str(result_folder / f"{species_name}.xdmf")
    )
    results[species_name].parameters["flush_output"] = True
    results[species_name].write(model_cur.sc[species_name].u["u"], model_cur.t)
model_cur.to_pickle(result_folder / "model_cur.pkl")

# Set loglevel to warning in order not to pollute notebook output
# logger.setLevel(logging.WARNING)

concVec = np.array([model_cur.sc["Ca"].initial_condition])
tvec = np.array([0.0])
# Solve
# The cytosolic volume does not change, so build the measure and compute the volume once
cytoMesh = model_cur.cc["Cyto"].dolfin_mesh
dx = d.Measure("dx", domain=cytoMesh)
volume = d.assemble(1.0 * dx)
displayed = False
while True:
    # Solve the system
    model_cur.monolithic_solve()
    # Save results for post processing
    for species_name, species in model_cur.sc.items:
        results[species_name].write(model_cur.sc[species_name].u["u"], model_cur.t)
    # Volume-averaged cytosolic calcium
    int_val = d.assemble(model_cur.sc["Ca"].u["u"] * dx)
    curConc = np.array([int_val / volume])
    concVec = np.concatenate((concVec, curConc))
    tvec = np.concatenate((tvec, np.array([float(model_cur.t)])))
    np.savetxt(result_folder / "tvec.txt", np.array(model_cur.tvec).astype(np.float32))
    # Display once, at the first step with t >= 0.025 s (this equals final_t here)
    if model_cur.t >= 0.025 and not displayed:
        visualization.plot(model_cur.sc["Ca"].u["u"])
        displayed = True
    # End if we've passed the final time
    if model_cur.t >= model_cur.final_t:
        break
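Inside the solve loop, the average cytosolic calcium is computed as $\bar{c} = \int_{\Omega} c\,dx \,/\, \int_{\Omega} dx$. For a field that is constant on each cell (a simplification of the finite-element field used in the model), this reduces to a volume-weighted mean of the cell values. A small illustration with hypothetical cell volumes and concentrations:

```python
def volume_average(values, volumes):
    """Volume-weighted mean: sum(c_i * V_i) / sum(V_i) for a cell-wise constant field."""
    total_volume = sum(volumes)
    return sum(c * v for c, v in zip(values, volumes)) / total_volume


# Illustrative values only: three cells with different volumes (um^3) and Ca levels (uM)
cell_ca = [0.1, 0.5, 2.0]
cell_vol = [0.02, 0.01, 0.001]
print(f"volume-averaged Ca = {volume_average(cell_ca, cell_vol):.4f} uM")
```

Small cells with high calcium (for example near the PSD) contribute little to the average, which is why the averaged trace can look much flatter than local peak concentrations.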

Plot the volume-averaged cytosolic calcium over time. The corresponding figure in the original paper uses a spherical geometry, whereas we use an ellipsoidal one here, so we expect the dynamics to be only qualitatively similar.

In [None]:
fig, ax = plt.subplots()
ax.plot(tvec, concVec)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Cytosolic calcium (μM)")
fig.savefig(result_folder / "ca2+-example.png")

Save the results.

In [None]:
np.save(result_folder / "tvec.npy", tvec)
np.save(result_folder / "concVec.npy", concVec)

Calculate the area under the curve (AUC) of the average cytosolic calcium.

In [None]:
auc_cur = np.trapz(concVec, tvec)
print(auc_cur)
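`np.trapz` applies the trapezoidal rule to the samples in `tvec`, which need not be evenly spaced, so the AUC has units of μM·s. Written out by hand (with a hypothetical `trapezoid_auc` helper and toy data), the rule is:

```python
def trapezoid_auc(y, x):
    """Trapezoidal rule: sum of 0.5 * (y[i] + y[i+1]) * (x[i+1] - x[i]); works for uneven spacing."""
    if len(y) != len(x):
        raise ValueError("y and x must have the same length")
    return sum(0.5 * (y[i] + y[i + 1]) * (x[i + 1] - x[i]) for i in range(len(x) - 1))


# Toy time series (s) and concentrations (uM): a triangle with base 0.02 s and height 1 uM
t_toy = [0.0, 0.005, 0.01, 0.02]
c_toy = [0.0, 0.5, 1.0, 0.0]
print(f"AUC = {trapezoid_auc(c_toy, t_toy):.4f} uM*s")
```

In NumPy 2.0 and later, `np.trapz` is deprecated in favor of `np.trapezoid`, which computes the same quantity.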

In [None]:
timings = d.timings(
    d.TimingClear.keep,
    [d.TimingType.wall, d.TimingType.user, d.TimingType.system],
).str(True)

In [None]:
print(timings)

In [None]:
(result_folder / "timings.txt").write_text(timings)