### NMODL SymPy Visitor Examples

Some examples of how SymPy is used within NMODL to:
- solve differential equations to generate solutions for the CNEXP solver (`SympySolverVisitor`)
- differentiate current expressions to generate CONDUCTANCE statements (`SympyConductanceVisitor`)

Please see the [tutorial notebook](nmodl-python-tutorial.ipynb) for a more general tutorial on using the NMODL Python interface, including installation instructions.

In [1]:
import nmodl.dsl as nmodl
from nmodl.dsl import ast, symtab, visitor

Starting with the `hh.mod` input file:

In [2]:
channel = """
NEURON {
    SUFFIX hh
    USEION na READ ena WRITE ina
    USEION k READ ek WRITE ik
    NONSPECIFIC_CURRENT il
    RANGE gnabar, gkbar, gl, el, gna, gk
    RANGE minf, hinf, ninf, mtau, htau, ntau
}
 
UNITS {
    (mV) = (millivolt)
    (S) = (siemens)
}
 
PARAMETER {
    gnabar = .12 (S/cm2)
    gkbar = .036 (S/cm2)
    gl = .0003 (S/cm2)
    el = -54.3 (mV)
    celsius
}
 
STATE {
    m h n
}
 
ASSIGNED {
    v (mV)
 
    gna (S/cm2)
    gk (S/cm2)
    minf
    hinf
    ninf
    mtau (ms)
    htau (ms)
    ntau (ms)
}
 
BREAKPOINT {
    SOLVE states METHOD cnexp
    gna = gnabar*m*m*m*h
    ina = gna*(v - ena)
    gk = gkbar*n*n*n*n
    ik = gk*(v - ek)
    il = gl*(v - el)
}
 
INITIAL {
    rates(v, celsius)
    m = minf
    h = hinf
    n = ninf
}
 
DERIVATIVE states {
    rates(v, celsius)
    m' = (minf-m)/mtau
    h' = (hinf-h)/htau
    n' = (ninf-n)/ntau
}
 
PROCEDURE rates(v, celsius)
{
    LOCAL  alpha, beta, sum, q10
 
    q10 = 3^((celsius - 6.3)/10)
 
    :"m" sodium activation system
    alpha = .1 * vtrap(-(v+40),10)
    beta =  4 * exp(-(v+65)/18)
    sum = alpha + beta
    mtau = 1/(q10*sum)
    minf = alpha/sum
 
    :"h" sodium inactivation system
    alpha = .07 * exp(-(v+65)/20)
    beta = 1 / (exp(-(v+35)/10) + 1)
    sum = alpha + beta
    htau = 1/(q10*sum)
    hinf = alpha/sum
 
    :"n" potassium activation system
    alpha = .01*vtrap(-(v+55),10)
    beta = .125*exp(-(v+65)/80)
    sum = alpha + beta
    ntau = 1/(q10*sum)
    ninf = alpha/sum
}
 
FUNCTION vtrap(x,y) {
    : use built in exprelr(z) = z/(exp(z)-1), which handles the z=0 case correctly
    vtrap = y*exprelr(x/y)
}
"""

### ODE Solver example

In [3]:
def parse_mod_to_ast(mod_string):
    # parse NMODL file (supplied as a string)
    driver = nmodl.Driver()
    driver.parse_string(mod_string)
    modast = driver.ast()
    # run SymtabVisitor to generate Symbol Table
    symv = symtab.SymtabVisitor()
    symv.visit_program(modast)
    # return AST
    return modast

lookup_visitor = visitor.AstLookupVisitor()

In [4]:
# print solve method
modast = parse_mod_to_ast(channel)
print(nmodl.to_nmodl(lookup_visitor.lookup(modast, ast.AstNodeType.SOLVE_BLOCK)[0]))

SOLVE states METHOD cnexp


In [5]:
# print DERIVATIVE block
print(
    nmodl.to_nmodl(lookup_visitor.lookup(modast, ast.AstNodeType.DERIVATIVE_BLOCK)[0])
)

DERIVATIVE states {
    rates(v, celsius)
    m' = (minf-m)/mtau
    h' = (hinf-h)/htau
    n' = (ninf-n)/ntau
}


If we run `SympySolverVisitor`, it does the following:

* If the solver method is "cnexp":
    * Get list of all global scope variables from the Symbol Table, as well as any local variables in DERIVATIVE block
    * For each differential equation in DERIVATIVE block:
        * Parse equation into SymPy, giving it the list of variables
        * This gives us a differential equation of the form:
            * $\frac{dm}{dt} = f(m, \dots)$
            * where the function $f$ depends on $m$, as well as possibly other variables represented by $\dots$, which we assume do not depend on $m$ or $t$
        * Solve equation analytically using [sympy.dsolve](https://docs.sympy.org/latest/modules/solvers/ode.html) to give a solution of the form:
            * $m(t+dt) = g(m(t), dt, \dots)$
            * where $g$ is some function that depends on the value of $m$ at time t, the timestep $dt$, and the other variables ($\dots$).
        * Return solution from SymPy as C code using [sympy.printing.ccode](https://docs.sympy.org/latest/_modules/sympy/printing/ccode.html)
        * If we failed to find a solution then revert to existing CNEXP solver routine (same as mod2c or nocmodl)
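
The steps above can be sketched directly in SymPy for the `m` gating equation. This is an illustrative sketch only, not the visitor's actual implementation: the symbol `m0` (the value of `m` at the start of the timestep) is a name introduced here, and `minf` and `mtau` are assumed to be constant over the step.

```python
import sympy as sp

# Symbols standing in for the NMODL variables (assumed constant over one timestep).
# m0 is a hypothetical name for the value of m at the start of the step.
t, dt, mtau, minf, m0 = sp.symbols("t dt mtau minf m0")
m = sp.Function("m")

# The differential equation m' = (minf - m)/mtau
ode = sp.Eq(m(t).diff(t), (minf - m(t)) / mtau)

# Solve analytically with the initial condition m(0) = m0
solution = sp.dsolve(ode, m(t), ics={m(0): m0})

# Advance by one timestep: m(t + dt) expressed through the current value m0
update = solution.rhs.subs(t, dt)

# Closed form that appears in the transformed DERIVATIVE block
expected = minf + (m0 - minf) * sp.exp(-dt / mtau)

# Emit the update expression as C code
c_code = sp.ccode(update)
print(c_code)
```

Substituting numerical values into `update - expected` is a quick way to confirm that the SymPy result agrees with the closed form, whatever arrangement of terms SymPy happens to return.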

In [6]:
sympy_solver_visitor = visitor.SympySolverVisitor()
sympy_solver_visitor.visit_program(modast)

If we print the DERIVATIVE block again we see the results:

In [7]:
print(
    nmodl.to_nmodl(lookup_visitor.lookup(modast, ast.AstNodeType.DERIVATIVE_BLOCK)[0])
)

DERIVATIVE states {
    rates(v, celsius)
    m = minf+(m-minf)*exp(-dt/mtau)
    h = hinf+(h-hinf)*exp(-dt/htau)
    n = ninf+(n-ninf)*exp(-dt/ntau)
}


The `use_pade_approx` option, if enabled, adds the following extra step:

* Given the analytic solution $f(t)$:
    * Expand the solution in a Taylor series in `dt`, extract the coefficients $a_i$
        * $f(t + dt) = f(t) + dt f'(t) + dt^2 f''(t) / 2 + \dots = a_0 + a_1 dt + a_2 dt^2 + \dots$
    * Construct the (1,1) Pade approximant to the solution using these Taylor coefficients
        * $f_{PADE}(t+dt) = (a_0 a_1 + (a_1^2 - a_0 a_2) dt)/(a_1 - a_2 dt)$
    * Return this approximate solution (correct to second order in $dt$) as C code

Replacing the exponential with a Pade approximant here was suggested in sec. 5.2 of https://www.eccomas2016.org/proceedings/pdf/7366.pdf. Since the overall numerical integration scheme in NEURON is only correct to first or second order in $dt$, it is valid to expand the analytic solution to the same order and so avoid evaluating the exponential function.
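
As a rough check of the formula above, the following sketch builds the (1,1) Pade approximant for the analytic `m` update by hand in SymPy. It is an assumption-laden illustration rather than the code the visitor runs: the symbols are plain SymPy symbols, and the numerical values used for comparison are arbitrary.

```python
import sympy as sp

dt, mtau, minf, m = sp.symbols("dt mtau minf m")

# Analytic cnexp update for m' = (minf - m)/mtau
exact = minf + (m - minf) * sp.exp(-dt / mtau)

# Taylor coefficients a_0, a_1, a_2 of the update in powers of dt
a0 = exact.subs(dt, 0)
a1 = exact.diff(dt).subs(dt, 0)
a2 = exact.diff(dt, 2).subs(dt, 0) / 2

# (1,1) Pade approximant constructed from those coefficients
pade = sp.simplify((a0 * a1 + (a1**2 - a0 * a2) * dt) / (a1 - a2 * dt))

# Form that appears in the transformed DERIVATIVE block
expected = (-dt * m + 2 * dt * minf + 2 * m * mtau) / (dt + 2 * mtau)

# Compare the approximant with the exact update for arbitrary sample values
vals = {dt: 0.025, mtau: 0.3, minf: 0.6, m: 0.1}
print(float(exact.subs(vals)), float(pade.subs(vals)))
```

For a timestep that is small compared with `mtau`, the two printed values should differ only at third order in `dt`.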

If we now run `SympySolverVisitor` with `use_pade_approx=True`, and print the DERIVATIVE block again, we see the results:

In [8]:
modast = parse_mod_to_ast(channel)
sympy_solver_visitor = visitor.SympySolverVisitor(use_pade_approx=True)
sympy_solver_visitor.visit_program(modast)
# print DERIVATIVE block
print(
    nmodl.to_nmodl(lookup_visitor.lookup(modast, ast.AstNodeType.DERIVATIVE_BLOCK)[0])
)

DERIVATIVE states {
    rates(v, celsius)
    m = (-dt*m+2*dt*minf+2*m*mtau)/(dt+2*mtau)
    h = (-dt*h+2*dt*hinf+2*h*htau)/(dt+2*htau)
    n = (-dt*n+2*dt*ninf+2*n*ntau)/(dt+2*ntau)
}


### CONDUCTANCE example

The CONDUCTANCE keyword has been introduced to NEURON as well as CoreNEURON. If the i/v relation in the BREAKPOINT block is ohmic, the CONDUCTANCE keyword can be used for efficiency.

In [9]:
modast = parse_mod_to_ast(channel)
# print USEION and NONSPECIFIC current statements
for node in lookup_visitor.lookup(modast, ast.AstNodeType.USEION):
    print(nmodl.to_nmodl(node))
for node in lookup_visitor.lookup(modast, ast.AstNodeType.NONSPECIFIC):
    print(nmodl.to_nmodl(node))

USEION na READ ena WRITE ina
USEION k READ ek WRITE ik
NONSPECIFIC_CURRENT il


In [10]:
# print BREAKPOINT
print(
    nmodl.to_nmodl(lookup_visitor.lookup(modast, ast.AstNodeType.BREAKPOINT_BLOCK)[0])
)

BREAKPOINT {
    SOLVE states METHOD cnexp
    gna = gnabar*m*m*m*h
    ina = gna*(v-ena)
    gk = gkbar*n*n*n*n
    ik = gk*(v-ek)
    il = gl*(v-el)
}


If we run `SympyConductanceVisitor`, it does the following:

* For each ion write statement $i = \dots$ in the BREAKPOINT block
    * Differentiate to find the conductance $g_i=di/dv$
    * If this $g_i$ coincides with an existing variable, e.g. $g$, add to BREAKPOINT the statement:
        * CONDUCTANCE g USEION ion_name
    * If not, we also need to declare and assign a variable for the calculated conductance:
        * LOCAL g_i_0
        * CONDUCTANCE g_i_0 USEION ion_name
        * g_i_0 = ...
    * But if there is an existing CONDUCTANCE statement, then do not modify it
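
The differentiation step can be illustrated with plain SymPy. This is a minimal sketch under the same assumption the visitor makes (no variable other than `v` depends on `v`); the second current, written without an intermediate conductance variable, is a hypothetical variant of the `hh` model used only to show the case where no existing variable matches.

```python
import sympy as sp

v, ena, gna, gnabar, m, h = sp.symbols("v ena gna gnabar m h")

# Case 1: ohmic current written with an existing conductance variable
ina = gna * (v - ena)
g_existing = sp.diff(ina, v)

# Case 2 (hypothetical): the same current written without the variable gna,
# so the derivative is an expression that would need a new LOCAL variable
ina_inline = gnabar * m**3 * h * (v - ena)
g_new = sp.diff(ina_inline, v)

print(g_existing)
print(sp.ccode(g_new))
```

In the first case the derivative is exactly the symbol `gna`, which is why a `CONDUCTANCE gna USEION na` statement suffices; in the second case the derivative is a product of several symbols that would have to be assigned to a new variable.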


* NOTE: currently we just differentiate the equation, assuming that the variables do not depend on v
    * Ideally we should do something like:
        * check after inlining which variables are written to in BREAKPOINT block
        * get rhs of each such expression, and substitute it (recursively) for the lhs in SymPy
        * this should give a (complicated) expression $i = ...$ where all v dependence is explicit
        * then differentiate this w.r.t v
        * then simplify and try to write the result in terms of an existing variable
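
The substitution idea in this note can be sketched as follows. Everything here is hypothetical: the toy assignments (a gating variable `s` that depends on `v` through a dimensionless sigmoid) are invented for illustration and do not come from `hh.mod`, and the loop is just one possible way to make the `v` dependence explicit before differentiating.

```python
import sympy as sp

v, e, gbar, g, s, i = sp.symbols("v e gbar g s i")

# Hypothetical BREAKPOINT assignments, in order of execution:
# s depends on v, g depends on s, and i depends on g and v
assignments = [
    (s, 1 / (1 + sp.exp(-v))),
    (g, gbar * s),
    (i, g * (v - e)),
]

# Substitute each right-hand side for its left-hand side, last assignment first,
# so that all v dependence becomes explicit in the current
current = i
for lhs, rhs in reversed(assignments):
    current = current.subs(lhs, rhs)

# Derivative that treats g as independent of v (the current behaviour)
naive = sp.diff(g * (v - e), v)

# Derivative of the fully substituted expression
full = sp.diff(current, v)

print(naive)
print(sp.simplify(full))
```

The two derivatives differ by the term that comes from the `v` dependence of `s`, which is the contribution the simple differentiation ignores.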

Here is the BREAKPOINT block after running `SympyConductanceVisitor`:

In [11]:
sympy_conductance_visitor = visitor.SympyConductanceVisitor()
sympy_conductance_visitor.visit_program(modast)
# print BREAKPOINT block
print(
    nmodl.to_nmodl(lookup_visitor.lookup(modast, ast.AstNodeType.BREAKPOINT_BLOCK)[0])
)

BREAKPOINT {
    CONDUCTANCE gna USEION na
    CONDUCTANCE gl
    CONDUCTANCE gk USEION k
    SOLVE states METHOD cnexp
    gna = gnabar*m*m*m*h
    ina = gna*(v-ena)
    gk = gkbar*n*n*n*n
    ik = gk*(v-ek)
    il = gl*(v-el)
}
