# NMODL Python Interface Tutorial

## Visualization Library Setup

In [127]:
import json

from IPython.display import HTML, Javascript, display

In [128]:
%%javascript
require.config({
    paths: {
        d3: "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min"
    }
});


In [129]:
Javascript(filename="tree.js")


In [130]:
HTML(
    """
<style>

.node {
  cursor: pointer;
}

.node circle {
  fill: #d49c9c;
  stroke: #8c6666;
  stroke-width: 1.5px;
}

.node text {
  font-size: 11px !important;
  font-family: sans-serif;
  fill: #4545b5;
}

.link {
  fill: none;
  stroke: #efceed;
  stroke-width: 2px;
}

.templink {
  fill: none;
  stroke: red;
  stroke-width: 2px;
}
</style>
"""
)

## Introduction

NMODL is a code generation framework for the NEURON Modeling Language. It is primarily designed to support optimised code generation backends for morphologically detailed neuron simulators. It provides a high-level Python interface that can be used for model introspection as well as for performing various analyses on the underlying model.

The main goals of the NMODL framework are:

* Support for the full NMODL specification
* Providing modular tools for lexing, parsing and analysis
* Optimised code generation for modern architectures
* Compatibility with existing simulators
* Ability to implement new simulator backends with minimal effort

This tutorial provides an introduction to the Python API with examples.

## Installation
<a id='installation'></a> NMODL can be installed as a CMake project or using Python setuptools. See README.md for detailed installation instructions. For example:

```bash
python3 -m venv myenv
. myenv/bin/activate
pip3 install nmodl_source_directory/
```

With this, NMODL should be installed in the virtual environment.
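To confirm that the package is importable from the active environment, a minimal check with the standard library's `importlib.util.find_spec` could look like this. The helper `is_installed` is our own illustration, not part of NMODL:

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if the top-level module can be located on sys.path."""
    # find_spec returns None when a top-level module cannot be found
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("nmodl available:", is_installed("nmodl"))
```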

## Parsing Model And Constructing AST
<a id='create-ast'></a>

In [131]:
import nmodl.dsl as nmodl

In [132]:
channel = """
NEURON  {
    SUFFIX CaDynamics
    USEION ca READ ica WRITE cai
    RANGE decay, gamma, minCai, depth
}

UNITS   {
    (mV) = (millivolt)
    (mA) = (milliamp)
    FARADAY = (faraday) (coulombs)
    (molar) = (1/liter)
    (mM) = (millimolar)
    (um)    = (micron)
}

PARAMETER   {
    gamma = 0.05 : percent of free calcium (not buffered)
    decay = 80 (ms) : rate of removal of calcium
    depth = 0.1 (um) : depth of shell
    minCai = 1e-4 (mM)
}

ASSIGNED    {ica (mA/cm2)}

INITIAL {
    cai = minCai
}

STATE   {
    cai (mM)
}

BREAKPOINT  { SOLVE states METHOD cnexp }

DERIVATIVE states   {
    cai' = -(10000)*(ica*gamma/(2*FARADAY*depth)) - (cai - minCai)/decay
}

FUNCTION foo() {
    LOCAL temp
    foo = 1.0 + gamma
}
"""

* Create a parser driver with `nmodl.NmodlDriver()`
* then call its `parse_string` method, which parses any valid NMODL construct and returns the AST:

In [146]:
driver = nmodl.NmodlDriver()
modast = driver.parse_string(channel)

If we simply print the AST object, we get its JSON representation:

In [134]:
import json

print("%.100s" % modast)  # only first 100 characters

# JSON form of the AST; the expanded variant is what the tree viewer below consumes
json_data = json.loads(nmodl.to_json(modast, True))
json_data_expand = json.loads(nmodl.to_json(modast, True, True))

{"Program":[{"NeuronBlock":[{"StatementBlock":[{"Suffix":[{"Name":[{"String":[{"name":"SUFFIX"}]}]},
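The compact JSON above appears to nest every node as a one-key object that maps the node type to a list of children, with leaves holding plain attributes such as `{"name":"SUFFIX"}`. Assuming that layout (inferred from the truncated output, not from a specification), a small standard-library helper can walk such a tree and count node types. The function name and the sample string are hypothetical:

```python
import json
from collections import Counter


def count_node_types(tree):
    """Count node types in the compact JSON layout shown above.

    Assumption: a node is a dict {"NodeType": [children...]}; dict entries
    whose value is not a list (e.g. {"name": "SUFFIX"}) are leaf attributes.
    """
    counts = Counter()
    stack = [tree]
    while stack:
        node = stack.pop()
        if isinstance(node, dict):
            for key, value in node.items():
                if isinstance(value, list):
                    counts[key] += 1     # key names a node type
                    stack.extend(value)  # traverse its children
        elif isinstance(node, list):
            stack.extend(node)
    return counts


sample = '{"Program":[{"NeuronBlock":[{"StatementBlock":[{"Suffix":[{"Name":[{"String":[{"name":"SUFFIX"}]}]}]}]}]}]}'
print(count_node_types(json.loads(sample)))
```

In the notebook, the same helper could be applied to `json_data` from the cell above.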


In [135]:
Javascript(
    """(function(element){
                require(['draw_tree'], function(draw) { draw(element.get(0), %s) });
           })(element);"""
    % json.dumps(json_data_expand)
)


## Querying AST object with Visitors

The NMODL visitor interface can also be used from Python.

### The lookup visitor


In [136]:
from nmodl.dsl import ast, visitor

lookup_visitor = visitor.AstLookupVisitor()

Assuming we have created an AST object (as shown [here](#create-ast)), we can search for any NMODL construct in it using `AstLookupVisitor`. For example, to find the `STATE` block in the MOD file, we can simply do:

In [137]:
states = lookup_visitor.lookup(modast, ast.AstNodeType.STATE_BLOCK)
for state in states:
    print(nmodl.to_nmodl(state))

STATE {
    cai (mM)
}
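Conceptually, a lookup is a full traversal that collects every node of the requested type. As a rough stand-in for `AstLookupVisitor.lookup`, the same idea can be sketched over the compact JSON layout shown earlier. The function `lookup_json` and the sample tree are hypothetical, and the real visitor returns AST node objects rather than JSON fragments:

```python
import json


def lookup_json(tree, node_type):
    """Return every subtree whose node type equals `node_type`, in document order."""
    matches = []

    def walk(node):
        if isinstance(node, dict):
            for key, value in node.items():
                if isinstance(value, list):
                    if key == node_type:
                        matches.append({key: value})
                    for child in value:
                        walk(child)
        elif isinstance(node, list):
            for child in node:
                walk(child)

    walk(tree)
    return matches


sample = json.loads(
    '{"Program":[{"StateBlock":[{"AssignedDefinition":[{"name":"cai"}]}]},'
    '{"FunctionBlock":[{"StatementBlock":[]}]}]}'
)
states = lookup_json(sample, "StateBlock")
print(len(states))
```

Calling `lookup_json` again on one of the returned fragments restricts the search to that subtree, which is the idea behind the nested lookup that follows.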


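Lookups can also be nested: here we first find the `FUNCTION` blocks and then search for expression statements only within the first one: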

In [139]:
functions = lookup_visitor.lookup(modast, ast.AstNodeType.FUNCTION_BLOCK)
function = functions[0]  # first function

# expression statements include assignments
new_lookup_visitor = visitor.AstLookupVisitor(ast.AstNodeType.EXPRESSION_STATEMENT)

# using accept method of node
function.accept(new_lookup_visitor)
statements = new_lookup_visitor.get_nodes()

for statement in statements:
    print(nmodl.to_nmodl(statement))

foo = 1+gamma


### Symbol Table Visitor

The symbol table visitor is used to find all variables in the MOD file along with their properties and usage. To use it, first create a visitor object:

In [140]:
from nmodl.dsl import symtab

symv = symtab.SymtabVisitor()

Once the visitor object is created, we run it on the AST to populate the symbol table. The symbol table can then be converted to a string and printed in full:

In [141]:
symv.visit_program(modast)
table = modast.get_symbol_table()
table_s = str(table)

print(table_s)


------------------------------------------------------------------------------------------------------------------------------
|                              NMODL_GLOBAL [Program IN None] POSITION : UNKNOWN SCOPE : GLOBAL                              |
------------------------------------------------------------------------------------------------------------------------------
|   NAME   |                  PROPERTIES                   |  STATUS   |  LOCATION   |   VALUE    |  # READS   |  # WRITES   | 
------------------------------------------------------------------------------------------------------------------------------
| ca       | ion                                           |           |     UNKNOWN |            |     0      |      0      | 
| ica      | dependent_def read_ion                        |           |     UNKNOWN |            |     0      |      0      | 
| cai      | prime_name dependent_def write_ion state_var  |           |     UNKNOWN |            |     0  
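The table above stores, for each name, a set of properties plus read and write counters. A minimal sketch of such a data structure in plain Python might look as follows. The classes `Symbol` and `ToySymbolTable` are hypothetical and only mimic some columns of the printed table; they are not NMODL's internal symbol table implementation:

```python
from dataclasses import dataclass, field


@dataclass
class Symbol:
    name: str
    properties: set = field(default_factory=set)  # e.g. {"state_var", "write_ion"}
    reads: int = 0
    writes: int = 0


class ToySymbolTable:
    """Name -> Symbol mapping with usage counters (illustrative only)."""

    def __init__(self):
        self.symbols = {}

    def insert(self, name, *properties):
        # Merge properties if the same name is declared again in another block
        sym = self.symbols.setdefault(name, Symbol(name))
        sym.properties.update(properties)
        return sym

    def record_read(self, name):
        self.symbols[name].reads += 1

    def record_write(self, name):
        self.symbols[name].writes += 1

    def __str__(self):
        rows = [f"{'NAME':<8} {'PROPERTIES':<45} {'READS':>5} {'WRITES':>6}"]
        for sym in self.symbols.values():
            props = " ".join(sorted(sym.properties))
            rows.append(f"{sym.name:<8} {props:<45} {sym.reads:>5} {sym.writes:>6}")
        return "\n".join(rows)


table = ToySymbolTable()
table.insert("ica", "read_ion")
table.insert("cai", "state_var")
table.insert("cai", "write_ion")  # same symbol, additional property
table.record_read("ica")
table.record_write("cai")
print(table)
```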

### Custom AST visitor

If the predefined visitors are not sufficient, we can implement a new visitor using the `AstVisitor` interface. Let's say we want a visitor that prints every floating point number in the MOD file. Here is how we can do this:

In [142]:
from nmodl.dsl import ast, visitor


class DoubleVisitor(visitor.AstVisitor):
    def visit_double(self, node):
        print(node.eval())  # or, can use nmodl.to_nmodl(node)


d_visitor = DoubleVisitor()
modast.accept(d_visitor)

0.05
0.1
0.0001
10000.0
2.0
1.0


The `AstVisitor` base class provides all the methods necessary to traverse the different AST node types. New visitors inherit from `AstVisitor` and override only those methods where they need different behaviour. In the example above, we want to visit AST nodes of type `Double` and print their values, so we implemented the method associated with the `Double` node, i.e. `visit_double`. When we call the `accept` method on the AST object, the entire tree is traversed by the methods inherited from `AstVisitor`, and whenever a `Double` node is encountered, control is handed to `DoubleVisitor.visit_double`.
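To make this dispatch mechanism concrete, here is a self-contained toy version of the pattern in plain Python. The node classes, the `Visitor` base class and the `visit_*` naming are simplified stand-ins modelled on the description above, not NMODL's generated classes. The base class recurses everywhere, and a subclass overrides only the hook for one node type:

```python
class Node:
    """Base class: `accept` hands the node to the visitor method for its kind."""

    def __init__(self, *children):
        self.children = list(children)

    def accept(self, visitor):
        # e.g. kind "double" -> visitor.visit_double(self)
        getattr(visitor, "visit_" + self.kind)(self)

    def visit_children(self, visitor):
        for child in self.children:
            child.accept(visitor)


class Double(Node):
    kind = "double"

    def __init__(self, value):
        super().__init__()
        self.value = value

    def eval(self):
        return self.value


class BinaryExpression(Node):
    kind = "binary_expression"


class Visitor:
    """Default behaviour: every visit_* method simply recurses into children."""

    def visit_binary_expression(self, node):
        node.visit_children(self)

    def visit_double(self, node):
        node.visit_children(self)


class DoubleCollector(Visitor):
    """Overrides only the Double hook; the traversal itself is inherited."""

    def __init__(self):
        self.values = []

    def visit_double(self, node):
        self.values.append(node.eval())


tree = BinaryExpression(Double(2.0), BinaryExpression(Double(0.05), Double(1.0)))
collector = DoubleCollector()
tree.accept(collector)
print(collector.values)  # [2.0, 0.05, 1.0]
```

Storing the values instead of printing them makes the result reusable by later code; the same change could be applied to `DoubleVisitor`.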

Let's implement a similar visitor that prints the parameters from the `PARAMETER` block together with their values:

In [144]:
class ParameterVisitor(visitor.AstVisitor):
    def __init__(self):
        visitor.AstVisitor.__init__(self)
        self.in_parameter = False

    def visit_param_block(self, node):
        # names and values are only printed while inside the PARAMETER block
        self.in_parameter = True
        node.visit_children(self)
        self.in_parameter = False

    def visit_name(self, node):
        if self.in_parameter:
            print(nmodl.to_nmodl(node) + ": ", end="")

    def visit_double(self, node):
        if self.in_parameter:
            print(node.eval())

    def visit_integer(self, node):
        if self.in_parameter:
            print(node.eval())


param_visitor = ParameterVisitor()
modast.accept(param_visitor)

gamma: 0.05
decay: 80
depth: 0.1
minCai: 0.0001
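`ParameterVisitor` uses a mutable `in_parameter` flag that has to be reset by hand once the block has been visited. An alternative design passes the context down the recursion, so leaving the block resets it automatically. The sketch below illustrates that idea on the compact JSON layout; the function name and the `ParamBlock`/`DerivativeBlock`/`Double` keys in the sample are assumptions made for illustration:

```python
import json


def collect_inside(tree, block_type, node_type):
    """Collect `node_type` fragments that appear anywhere inside a `block_type` node."""
    found = []

    def walk(node, inside):
        if isinstance(node, dict):
            for key, value in node.items():
                if not isinstance(value, list):
                    continue  # leaf attribute, nothing to traverse
                # The context is an argument, so it reverts when recursion returns
                now_inside = inside or key == block_type
                if now_inside and key == node_type:
                    found.append(value)
                for child in value:
                    walk(child, now_inside)
        elif isinstance(node, list):
            for child in node:
                walk(child, inside)

    walk(tree, False)
    return found


sample = json.loads(
    '{"Program":['
    '{"ParamBlock":[{"Double":[{"value":"0.05"}]},{"Double":[{"value":"0.1"}]}]},'
    '{"DerivativeBlock":[{"Double":[{"value":"2.0"}]}]}'
    "]}"
)
values = [frag[0]["value"] for frag in collect_inside(sample, "ParamBlock", "Double")]
print(values)
```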


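As a more complete example, we can write a visitor that translates a simple NMODL `FUNCTION` into Python code. First, we parse the function:

In [145]: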
mfunc_src = """FUNCTION myfunc(x, y) {
     if (x < y) {
          myfunc = x + y
     } else {
          myfunc = y
     }
}
"""
import nmodl.dsl as nmodl
from nmodl.dsl import ast

driver = nmodl.NmodlDriver()
mfunc_ast = driver.parse_string(mfunc_src)

In [124]:
import nmodl.dsl as nmodl
from nmodl.dsl import ast
from nmodl.dsl import visitor

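class PyGenerator(visitor.AstVisitor):
    """Translate a simple NMODL FUNCTION block into Python source code."""

    def __init__(self):
        visitor.AstVisitor.__init__(self)
        self.pycode = ''     # generated Python source
        self.indent = 0      # current indentation level (4 spaces each)
        self.func_name = ""  # assignments to this name become `return`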
    
    def visit_function_block(self, node):
        params = []
        self.func_name = node.get_node_name()
        for p in node.parameters:
            params.append(p.get_node_name())
        params_str = ", ".join(params)
        self.pycode += f"def {node.get_node_name()}({params_str}):\n"
        node.visit_children(self)
    
    def visit_statement_block(self, node):
        self.indent += 1
        node.visit_children(self)
        self.indent -= 1
    
    def visit_expression_statement(self, node):
        self.pycode += " "*4*self.indent
        expr = node.expression
        if type(expr) is ast.BinaryExpression and expr.op.eval() == "=":
            rhs = expr.rhs
            lhsn = expr.lhs.name.get_node_name()
            if lhsn == self.func_name:
                self.pycode += "return "
                rhs.accept(self)
            else:
                node.visit_children(self)
        else:
            node.visit_children(self)
        self.pycode += "\n"

    
    def visit_if_statement(self, node):
        self.pycode += " "*4*self.indent + "if "
        node.condition.accept(self)
        self.pycode += ":\n"
        node.get_statement_block().accept(self)
        for n in node.elseifs:
            n.accept(self)
        if node.elses:
            node.elses.accept(self)

    def visit_else_statement(self, node):
        self.pycode += " "*4*self.indent + "else:\n"
        node.get_statement_block().accept(self)
        
    
    def visit_binary_expression(self, node):
        lhs = node.lhs
        rhs = node.rhs
        op = node.op.eval()
        if op == "^":
            self.pycode += "pow("
            lhs.accept(self)
            self.pycode += ", "
            rhs.accept(self)
            self.pycode += ")"
        else:
            lhs.accept(self)
            self.pycode += f" {op} "
            rhs.accept(self)
            
    def visit_var_name(self, node):
        self.pycode += node.name.get_node_name()
    
    def visit_integer(self, node):
        self.pycode += nmodl.to_nmodl(node)
    
    def visit_double(self, node):
        self.pycode += nmodl.to_nmodl(node)
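`visit_binary_expression` above copies every operator except `^` verbatim. That works for arithmetic and comparisons, but would emit invalid Python for NMODL's logical operators `&&` and `||`. A small translation table is one way to handle this; the names below are hypothetical and the table is not exhaustive (for instance, the unary `!` would need a separate `visit_unary_expression`):

```python
# Hypothetical NMODL -> Python spelling for binary operators that differ.
# Operators not listed (+, -, *, /, <, <=, ==, != ...) are copied unchanged.
NMODL_TO_PY_BINOP = {
    "^": "**",    # exponentiation
    "&&": "and",  # logical and
    "||": "or",   # logical or
}


def translate_binop(lhs: str, op: str, rhs: str) -> str:
    """Render `lhs op rhs` with the operator spelled the way Python expects."""
    return f"{lhs} {NMODL_TO_PY_BINOP.get(op, op)} {rhs}"


print(translate_binop("x", "^", "2"))
print(translate_binop("x < y", "&&", "y < z"))
```

Inside `PyGenerator.visit_binary_expression`, the `else` branch could then emit `NMODL_TO_PY_BINOP.get(op, op)` instead of `op`.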

In [125]:
pygen = PyGenerator()
pygen.visit_program(mfunc_ast)
print(pygen.pycode)

def myfunc(x, y):
    if x < y:
        return x + y
    else:
        return y
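A quick sanity check is to execute the generated source and compare the resulting function against the semantics of the original NMODL `myfunc`. The sketch below hard-codes the string printed above so that it runs on its own; in the notebook one would pass `pygen.pycode` instead. Only `exec` generated code whose source you trust:

```python
generated = """def myfunc(x, y):
    if x < y:
        return x + y
    else:
        return y
"""

namespace = {}
exec(generated, namespace)  # defines myfunc inside `namespace`
myfunc = namespace["myfunc"]

# Original NMODL semantics: x + y if x < y, otherwise y
for x, y in [(1, 2), (3, 2), (2, 2)]:
    expected = x + y if x < y else y
    assert myfunc(x, y) == expected
print(myfunc(1, 2), myfunc(3, 2))
```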

