diff --git a/.bbp-project.yaml b/.bbp-project.yaml index ee8dd4949..d233e4e16 100644 --- a/.bbp-project.yaml +++ b/.bbp-project.yaml @@ -19,3 +19,9 @@ tools: match: - ext/.* - src/language/templates/* + Black: + enable: True + version: ~=24.2.0 + include: + match: + - .*\.py$ diff --git a/docs/conf.py b/docs/conf.py index ff9c2afbc..2fd3e5230 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -170,7 +170,7 @@ (master_doc, "nmodl.tex", "nmodl Documentation", "BlueBrain HPC team", "manual") ] -imgmath_image_format = 'svg' +imgmath_image_format = "svg" imgmath_embed = True imgmath_font_size = 14 diff --git a/python/nmodl/ast.py b/python/nmodl/ast.py index b8f990f59..1963e1ae2 100644 --- a/python/nmodl/ast.py +++ b/python/nmodl/ast.py @@ -1,6 +1,7 @@ """ Module for vizualization of NMODL abstract syntax trees (ASTs). """ + import getpass import json import os diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index e25da9337..cbb848839 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -16,11 +16,11 @@ # accessed through regular imports major, minor = (int(v) for v in sp.__version__.split(".")[:2]) if major >= 1 and minor >= 7: - known_functions = import_module('sympy.printing.c').known_functions_C99 + known_functions = import_module("sympy.printing.c").known_functions_C99 else: - known_functions = import_module('sympy.printing.ccode').known_functions_C99 -known_functions.pop('Abs') -known_functions['abs'] = 'fabs' + known_functions = import_module("sympy.printing.ccode").known_functions_C99 +known_functions.pop("Abs") +known_functions["abs"] = "fabs" if not ((major >= 1) and (minor >= 2)): @@ -29,7 +29,18 @@ # Some functions are protected inside sympy, if user has declared such a function, it will fail # because sympy will try to use its own internal one. # Rename it before and after to a single name -forbidden_var = ["beta", "gamma", "uppergamma", "lowergamma", "polygamma", "loggamma", "digamma", "trigamma"] +forbidden_var = [ + "beta", + "gamma", + "uppergamma", + "lowergamma", + "polygamma", + "loggamma", + "digamma", + "trigamma", +] + + def search_and_replace_protected_functions_to_sympy(eqs, function_calls): for c in function_calls: if c in forbidden_var: @@ -38,6 +49,7 @@ def search_and_replace_protected_functions_to_sympy(eqs, function_calls): eqs = [re.sub(r, f, x) for x in eqs] return eqs + def search_and_replace_protected_functions_from_sympy(eqs, function_calls): for c in function_calls: if c in forbidden_var: @@ -45,6 +57,7 @@ def search_and_replace_protected_functions_from_sympy(eqs, function_calls): eqs = [re.sub(r, f"{c}", x) for x in eqs] return eqs + def _get_custom_functions(fcts): custom_functions = {} for f in fcts: @@ -143,13 +156,16 @@ def _sympify_eqs(eq_strings, state_vars, vars): for state_var in state_vars: sympy_state_vars.append(sp.sympify(state_var, locals=sympy_vars)) eqs = [ - (sp.sympify(eq.split("=", 1)[1], locals=sympy_vars) - - sp.sympify(eq.split("=", 1)[0], locals=sympy_vars)).expand() + ( + sp.sympify(eq.split("=", 1)[1], locals=sympy_vars) + - sp.sympify(eq.split("=", 1)[0], locals=sympy_vars) + ).expand() for eq in eq_strings ] return eqs, sympy_state_vars, sympy_vars + def _interweave_eqs(F, J): """Interweave F and J equations so that they are printed in code rowwise from the equation J x = F. For example: @@ -199,13 +215,21 @@ def _interweave_eqs(F, J): n = len(F) for i, expr in enumerate(F): code.append(expr) - for j in range(i * n, (i+1) * n): + for j in range(i * n, (i + 1) * n): code.append(J[j]) return code -def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_prefix, small_system=False, do_cse=False): +def solve_lin_system( + eq_strings, + vars, + constants, + function_calls, + tmp_unique_prefix, + small_system=False, + do_cse=False, +): """Solve linear system of equations, return solution as C code. If system is small (small_system=True, typically N<=3): @@ -233,7 +257,9 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre vars: list of strings containing new local variables """ - eq_strings = search_and_replace_protected_functions_to_sympy(eq_strings, function_calls) + eq_strings = search_and_replace_protected_functions_to_sympy( + eq_strings, function_calls + ) eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants) custom_fcts = _get_custom_functions(function_calls) @@ -246,7 +272,9 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre solution_vector = sp.linsolve(eqs, state_vars).args[0] if do_cse: # generate prefix for new local vars that avoids clashes - my_symbols = sp.utilities.iterables.numbered_symbols(prefix=tmp_unique_prefix + '_') + my_symbols = sp.utilities.iterables.numbered_symbols( + prefix=tmp_unique_prefix + "_" + ) sub_exprs, simplified_solution_vector = sp.cse( solution_vector, symbols=my_symbols, @@ -255,10 +283,14 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre ) for var, expr in sub_exprs: new_local_vars.append(sp.ccode(var)) - code.append(f"{var} = {sp.ccode(expr.evalf(), user_functions=custom_fcts)}") + code.append( + f"{var} = {sp.ccode(expr.evalf(), user_functions=custom_fcts)}" + ) solution_vector = simplified_solution_vector[0] for var, expr in zip(state_vars, solution_vector): - code.append(f"{sp.ccode(var)} = {sp.ccode(expr.evalf(), contract=False, user_functions=custom_fcts)}") + code.append( + f"{sp.ccode(var)} = {sp.ccode(expr.evalf(), contract=False, user_functions=custom_fcts)}" + ) else: # large linear system: construct and return matrix J, vector F such that # J X = F is the linear system to be solved for X by e.g. LU factorization @@ -267,13 +299,17 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre # construct vector F vecFcode = [] for i, expr in enumerate(vecF): - vecFcode.append(f"F[{i}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}") + vecFcode.append( + f"F[{i}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}" + ) # construct matrix J vecJcode = [] for i, expr in enumerate(matJ): # todo: fix indexing to be ascending order flat_index = matJ.rows * (i % matJ.rows) + (i // matJ.rows) - vecJcode.append(f"J[{flat_index}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}") + vecJcode.append( + f"J[{flat_index}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}" + ) # interweave code = _interweave_eqs(vecFcode, vecJcode) @@ -299,7 +335,9 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): List of strings containing assignment statements """ - eq_strings = search_and_replace_protected_functions_to_sympy(eq_strings, function_calls) + eq_strings = search_and_replace_protected_functions_to_sympy( + eq_strings, function_calls + ) eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants) custom_fcts = _get_custom_functions(function_calls) @@ -310,13 +348,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): vecFcode = [] for i, eq in enumerate(eqs): - vecFcode.append(f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}") + vecFcode.append( + f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}" + ) vecJcode = [] for i, j in itertools.product(range(jacobian.rows), range(jacobian.cols)): flat_index = i + jacobian.rows * j - rhs = sp.ccode(jacobian[i,j].simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts) + rhs = sp.ccode( + jacobian[i, j].simplify().subs(X_vec_map).evalf(), + user_functions=custom_fcts, + ) vecJcode.append(f"J[{flat_index}] = {rhs}") # interweave diff --git a/src/language/argument.py b/src/language/argument.py index 8b8f7ad59..5745b193d 100644 --- a/src/language/argument.py +++ b/src/language/argument.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 + class Argument: """Utility class for holding all arguments for node classes""" diff --git a/src/language/code_generator.py b/src/language/code_generator.py index f2d09ae49..90113d012 100644 --- a/src/language/code_generator.py +++ b/src/language/code_generator.py @@ -111,7 +111,7 @@ def jinja_template(self, path): return self.jinja_env.get_template(name) def _cmake_deps_task(self, tasks): - """"Construct the JinjaTask generating the CMake file exporting all dependencies + """ "Construct the JinjaTask generating the CMake file exporting all dependencies Args: tasks: list of JinjaTask objects @@ -196,12 +196,18 @@ def workload(self): task = JinjaTask( app=self, input=filepath, - output=self.base_dir / sub_dir / "pynode_{}.cpp".format(chunk_k), + output=self.base_dir + / sub_dir + / "pynode_{}.cpp".format(chunk_k), context=dict( nodes=self.nodes[ - chunk_k * chunk_length : (chunk_k + 1) * chunk_length + chunk_k + * chunk_length : (chunk_k + 1) + * chunk_length ], - setup_pybind_method="init_pybind_classes_{}".format(chunk_k), + setup_pybind_method="init_pybind_classes_{}".format( + chunk_k + ), ), extradeps=extradeps[filepath], ) @@ -212,7 +218,11 @@ def workload(self): app=self, input=filepath, output=self.base_dir / sub_dir / filepath.name, - context=dict(nodes=self.nodes, node_info=node_info, **extracontext[filepath]), + context=dict( + nodes=self.nodes, + node_info=node_info, + **extracontext[filepath], + ), extradeps=extradeps[filepath], ) tasks.append(task) @@ -235,7 +245,7 @@ class JinjaTask( """ def execute(self): - """"Perform the Jinja task + """ "Perform the Jinja task Execute Jinja renderer if the output file is out-of-date. diff --git a/src/language/language_parser.py b/src/language/language_parser.py index df7d49aff..5f8e1e386 100644 --- a/src/language/language_parser.py +++ b/src/language/language_parser.py @@ -26,7 +26,6 @@ def __init__(self, filename, debug=False): self.filename = filename self.debug = debug - def parse_child_rule(self, child): """parse child specification and return argument as properties @@ -40,72 +39,72 @@ def parse_child_rule(self, child): args.varname = varname # type i.e. class of the variable - args.class_name = properties['type'] + args.class_name = properties["type"] if self.debug: - print('Child {}, {}'.format(args.varname, args.class_name)) + print("Child {}, {}".format(args.varname, args.class_name)) # if there is add method for member in the class - if 'add' in properties: - args.add_method = properties['add'] + if "add" in properties: + args.add_method = properties["add"] # if variable is an optional member - if 'optional' in properties: - args.is_optional = properties['optional'] + if "optional" in properties: + args.is_optional = properties["optional"] # if variable is vector, the separator to use while # printing back to nmodl - if 'separator' in properties: - args.separator = properties['separator'] + if "separator" in properties: + args.separator = properties["separator"] # if variable is public member - if 'public' in properties: - args.is_public = properties['public'] + if "public" in properties: + args.is_public = properties["public"] # if variable if of vector type - if 'vector' in properties: - args.is_vector = properties['vector'] + if "vector" in properties: + args.is_vector = properties["vector"] # if get_node_name method required - if 'node_name' in properties: - args.get_node_name = properties['node_name'] + if "node_name" in properties: + args.get_node_name = properties["node_name"] # brief description of member variable - if 'brief' in properties: - args.brief = properties['brief'] + if "brief" in properties: + args.brief = properties["brief"] # description of member variable - if 'description' in properties: - args.description = properties['description'] + if "description" in properties: + args.description = properties["description"] # if getter method required - if 'getter' in properties: - if 'name' in properties['getter']: - args.getter_method = properties['getter']['name'] - if 'override' in properties['getter']: - args.getter_override = properties['getter']['override'] + if "getter" in properties: + if "name" in properties["getter"]: + args.getter_method = properties["getter"]["name"] + if "override" in properties["getter"]: + args.getter_override = properties["getter"]["override"] # if there is nmodl name - if 'nmodl' in properties: - args.nmodl_name = properties['nmodl'] + if "nmodl" in properties: + args.nmodl_name = properties["nmodl"] # prefix while printing back to NMODL - if 'prefix' in properties: - args.prefix = properties['prefix']['value'] + if "prefix" in properties: + args.prefix = properties["prefix"]["value"] # if prefix is compulsory to print in NMODL then make suffix empty - if 'force' in properties['prefix']: - if properties['prefix']['force']: + if "force" in properties["prefix"]: + if properties["prefix"]["force"]: args.force_prefix = args.prefix args.prefix = "" # suffix while printing back to NMODL - if 'suffix' in properties: - args.suffix = properties['suffix']['value'] + if "suffix" in properties: + args.suffix = properties["suffix"]["value"] # if suffix is compulsory to print in NMODL then make suffix empty - if 'force' in properties['suffix']: - if properties['suffix']['force']: + if "force" in properties["suffix"]: + if properties["suffix"]["force"]: args.force_suffix = args.suffix args.suffix = "" @@ -124,15 +123,17 @@ def parse_yaml_rules(self, nodelist, base_class=None): continue args = Argument() - args.url = properties.get('url', None) + args.url = properties.get("url", None) args.class_name = class_name - args.brief = properties.get('brief', '') - args.description = properties.get('description', '') + args.brief = properties.get("brief", "") + args.description = properties.get("description", "") # yaml file has abstract classes and their subclasses with children as a property - if 'children' in properties: + if "children" in properties: # recursively parse all sub-classes of current abstract class - child_abstract_nodes, child_nodes = self.parse_yaml_rules(properties['children'], class_name) + child_abstract_nodes, child_nodes = self.parse_yaml_rules( + properties["children"], class_name + ) # append all parsed subclasses abstract_nodes.extend(child_abstract_nodes) @@ -146,26 +147,26 @@ def parse_yaml_rules(self, nodelist, base_class=None): abstract_nodes.append(node) nodes.insert(0, node) if self.debug: - print('Abstract {}'.format(node)) + print("Abstract {}".format(node)) else: - args.base_class = base_class if base_class else 'Ast' + args.base_class = base_class if base_class else "Ast" # store token in every node args.has_token = True # name of the node while printing back to NMODL - args.nmodl_name = properties['nmodl'] if 'nmodl' in properties else None + args.nmodl_name = properties["nmodl"] if "nmodl" in properties else None # create tree node and add to the list node = Node(args) nodes.append(node) if self.debug: - print('Class {}'.format(node)) + print("Class {}".format(node)) # now process all children specification - if 'members' in properties: - for child in properties['members']: + if "members" in properties: + for child in properties["members"]: args = self.parse_child_rule(child) node.add_child(args) @@ -179,15 +180,18 @@ def parse_yaml_rules(self, nodelist, base_class=None): return abstract_nodes, nodes def parse_file(self): - """ parse nmodl YAML specification file for AST creation """ + """parse nmodl YAML specification file for AST creation""" - with open(self.filename, 'r') as stream: + with open(self.filename, "r") as stream: try: rules = yaml.safe_load(stream) _, nodes = self.parse_yaml_rules(rules) except yaml.YAMLError as e: - print("Error while parsing YAML definition file {0} : {1}".format( - self.filename, e.strerror)) + print( + "Error while parsing YAML definition file {0} : {1}".format( + self.filename, e.strerror + ) + ) sys.exit(1) return nodes diff --git a/src/language/nodes.py b/src/language/nodes.py index 9e37a40c5..c1644e4a2 100644 --- a/src/language/nodes.py +++ b/src/language/nodes.py @@ -17,7 +17,7 @@ class BaseNode: - """base class for all node types (parent + child) """ + """base class for all node types (parent + child)""" def __init__(self, args): self.class_name = args.class_name @@ -35,7 +35,7 @@ def __lt__(self, other): return self.class_name < other.class_name def get_data_type_name(self): - """ return type name for the node """ + """return type name for the node""" return node_info.DATA_TYPES[self.class_name] @property @@ -145,8 +145,9 @@ def is_enum_node(self): @property def is_pointer_node(self): - return not (self.class_name in node_info.PTR_EXCLUDE_TYPES or - self.is_base_type_node) + return not ( + self.class_name in node_info.PTR_EXCLUDE_TYPES or self.is_base_type_node + ) @property def is_ptr_excluded_node(self): @@ -154,9 +155,11 @@ def is_ptr_excluded_node(self): @property def requires_default_constructor(self): - return (self.class_name in node_info.LEXER_DATA_TYPES or - self.is_program_node or - self.is_ptr_excluded_node) + return ( + self.class_name in node_info.LEXER_DATA_TYPES + or self.is_program_node + or self.is_ptr_excluded_node + ) @property def has_template_methods(self): @@ -205,7 +208,11 @@ def get_rvalue_typename(self): typename = self.get_typename() if self.is_vector: return "const " + typename + "&" - if self.is_base_type_node and not self.is_integral_type_node or self.is_ptr_excluded_node: + if ( + self.is_base_type_node + and not self.is_integral_type_node + or self.is_ptr_excluded_node + ): return "const " + typename + "&" return typename @@ -242,18 +249,23 @@ def member_typename(self): @property def _is_member_type_wrapped_as_shared_pointer(self): - return not (self.is_vector or self.is_base_type_node or self.is_ptr_excluded_node) + return not ( + self.is_vector or self.is_base_type_node or self.is_ptr_excluded_node + ) @property def member_rvalue_typename(self): """returns rvalue reference type when used as returned or parameter type""" typename = self.member_typename - if not self.is_integral_type_node and not self._is_member_type_wrapped_as_shared_pointer: + if ( + not self.is_integral_type_node + and not self._is_member_type_wrapped_as_shared_pointer + ): return "const " + typename + "&" return typename def get_add_methods_declaration(self): - s = '' + s = "" if self.add_method: method = f""" /** @@ -304,11 +316,11 @@ def get_add_methods_declaration(self): */ void reset_{to_snake_case(self.class_name)}({self.class_name}Vector::const_iterator position, std::shared_ptr<{self.class_name}> n); """ - s = textwrap.indent(textwrap.dedent(method), ' ') + s = textwrap.indent(textwrap.dedent(method), " ") return s def get_add_methods_definition(self, parent): - s = '' + s = "" if self.add_method: set_parent = "n->set_parent(this); " if self.optional: @@ -407,7 +419,7 @@ def get_add_methods_definition(self, parent): return s def get_add_methods_inline_definition(self, parent): - s = '' + s = "" if self.add_method: set_parent = "n->set_parent(this); " if self.optional: @@ -435,7 +447,7 @@ def get_add_methods_inline_definition(self, parent): return s def get_node_name_method_declaration(self): - s = '' + s = "" if self.get_node_name: # string node should be evaluated and hence eval() method method = f""" @@ -451,11 +463,11 @@ def get_node_name_method_declaration(self): * \\sa Ast::get_node_type_name */ std::string get_node_name() const override;""" - s = textwrap.indent(textwrap.dedent(method), ' ') + s = textwrap.indent(textwrap.dedent(method), " ") return s def get_node_name_method_definition(self, parent): - s = '' + s = "" if self.get_node_name: # string node should be evaluated and hence eval() method method_name = "eval" if self.is_string_node else "get_node_name" @@ -466,31 +478,47 @@ def get_node_name_method_definition(self, parent): return s def get_getter_method(self, class_name): - getter_method = self.getter_method if self.getter_method else "get_" + to_snake_case(self.varname) + getter_method = ( + self.getter_method + if self.getter_method + else "get_" + to_snake_case(self.varname) + ) getter_override = "override" if self.getter_override else "" return_type = self.member_rvalue_typename - return textwrap.indent(textwrap.dedent(f""" + return textwrap.indent( + textwrap.dedent( + f""" /** * \\brief Getter for member variable \\ref {class_name}.{self.varname} */ {return_type} {getter_method}() const noexcept {getter_override} {{ return {self.varname}; }} - """), ' ') + """ + ), + " ", + ) def get_setter_method_declaration(self, class_name): setter_method = "set_" + to_snake_case(self.varname) setter_type = self.member_typename reference = "" if self.is_base_type_node else "&&" if self.is_base_type_node: - return textwrap.indent(textwrap.dedent(f""" + return textwrap.indent( + textwrap.dedent( + f""" /** * \\brief Setter for member variable \\ref {class_name}.{self.varname} */ void {setter_method}({setter_type} {self.varname}); - """), ' ') + """ + ), + " ", + ) else: - return textwrap.indent(textwrap.dedent(f""" + return textwrap.indent( + textwrap.dedent( + f""" /** * \\brief Setter for member variable \\ref {class_name}.{self.varname} (rvalue reference) */ @@ -500,14 +528,16 @@ def get_setter_method_declaration(self, class_name): * \\brief Setter for member variable \\ref {class_name}.{self.varname} */ void {setter_method}(const {setter_type}& {self.varname}); - """), ' ') + """ + ), + " ", + ) def get_setter_method_definition(self, class_name): setter_method = "set_" + to_snake_case(self.varname) setter_type = self.member_typename reference = "" if self.is_base_type_node else "&&" - if self.is_base_type_node: return f""" void {class_name}::{setter_method}({setter_type} {self.varname}) {{ @@ -563,12 +593,10 @@ def get_setter_method_definition(self, class_name): }} """ - - - def __repr__(self): return "ChildNode(class_name='{}', nmodl_name='{}')".format( - self.class_name, self.nmodl_name) + self.class_name, self.nmodl_name + ) __str__ = __repr__ @@ -595,7 +623,9 @@ def cpp_header_deps(self): for child in self.children: if child.is_ptr_excluded_node or child.is_vector: dependent_classes.add(child.class_name) - return sorted(["ast/{}.hpp".format(to_snake_case(clazz)) for clazz in dependent_classes]) + return sorted( + ["ast/{}.hpp".format(to_snake_case(clazz)) for clazz in dependent_classes] + ) @property def ast_enum_name(self): @@ -631,12 +661,14 @@ def has_parent_block_node(self): @property def has_setters(self): """returns True if the class has at least one setter member method""" - return any([ - self.is_name_node, - self.has_token, - self.is_symtab_needed, - self.is_data_type_node and not self.is_enum_node - ]) + return any( + [ + self.is_name_node, + self.has_token, + self.is_symtab_needed, + self.is_data_type_node and not self.is_enum_node, + ] + ) @property def is_base_block_node(self): @@ -669,13 +701,13 @@ def is_symtab_method_required(self): :return: True if need to print visit method for node in symtabjsonvisitor otherwise False """ - return (self.has_children() and - (self.is_symbol_var_node or - self.is_symbol_block_node or - self.is_symbol_helper_node or - self.is_program_node or - self.has_parent_block_node() - )) + return self.has_children() and ( + self.is_symbol_var_node + or self.is_symbol_block_node + or self.is_symbol_helper_node + or self.is_program_node + or self.has_parent_block_node() + ) @property def is_base_class_number_node(self): @@ -685,12 +717,12 @@ def is_base_class_number_node(self): return self.base_class == node_info.NUMBER_NODE def ctor_declaration(self): - args = [f'{c.get_rvalue_typename()} {c.varname}' for c in self.children] + args = [f"{c.get_rvalue_typename()} {c.varname}" for c in self.children] return f"explicit {self.class_name}({', '.join(args)});" def ctor_definition(self): - args = [f'{c.get_rvalue_typename()} {c.varname}' for c in self.children] - initlist = [f'{c.varname}({c.varname})' for c in self.children] + args = [f"{c.get_rvalue_typename()} {c.varname}" for c in self.children] + initlist = [f"{c.varname}({c.varname})" for c in self.children] s = f"""{self.class_name}::{self.class_name}({', '.join(args)}) : {', '.join(initlist)} {{ set_parent_in_children(); }} @@ -698,12 +730,12 @@ def ctor_definition(self): return textwrap.dedent(s) def ctor_shrptr_declaration(self): - args = [f'{c.member_rvalue_typename} {c.varname}' for c in self.children] + args = [f"{c.member_rvalue_typename} {c.varname}" for c in self.children] return f"explicit {self.class_name}({', '.join(args)});" def ctor_shrptr_definition(self): - args = [f'{c.member_rvalue_typename} {c.varname}' for c in self.children] - initlist = [f'{c.varname}({c.varname})' for c in self.children] + args = [f"{c.member_rvalue_typename} {c.varname}" for c in self.children] + initlist = [f"{c.varname}({c.varname})" for c in self.children] s = f"""{self.class_name}::{self.class_name}({', '.join(args)}) : {', '.join(initlist)} {{ set_parent_in_children(); }} @@ -711,16 +743,20 @@ def ctor_shrptr_definition(self): return textwrap.dedent(s) def has_ptr_children(self): - return any(not (c.is_vector or c.is_base_type_node or c.is_ptr_excluded_node) - for c in self.children) + return any( + not (c.is_vector or c.is_base_type_node or c.is_ptr_excluded_node) + for c in self.children + ) def public_members(self): """ Return public members of the node """ - members = [[child.member_typename, child.varname, None, child.brief] - for child in self.children - if child.is_public] + members = [ + [child.member_typename, child.varname, None, child.brief] + for child in self.children + if child.is_public + ] return members @@ -728,18 +764,41 @@ def private_members(self): """ Return private members of the node """ - members = [[child.member_typename, child.varname, None, child.brief] - for child in self.children - if not child.is_public] + members = [ + [child.member_typename, child.varname, None, child.brief] + for child in self.children + if not child.is_public + ] if self.has_token: - members.append(["std::shared_ptr", "token", None, "token with location information"]) + members.append( + [ + "std::shared_ptr", + "token", + None, + "token with location information", + ] + ) if self.is_symtab_needed: - members.append(["symtab::SymbolTable*", "symtab", "nullptr", "symbol table for a block"]) + members.append( + [ + "symtab::SymbolTable*", + "symtab", + "nullptr", + "symbol table for a block", + ] + ) if self.is_program_node: - members.append(["symtab::ModelSymbolTable", "model_symtab", None, "global symbol table for model"]) + members.append( + [ + "symtab::ModelSymbolTable", + "model_symtab", + None, + "global symbol table for model", + ] + ) return members @@ -747,12 +806,28 @@ def properties(self): """ Return private members of the node destined to be pybind properties """ - members = [[child.member_typename, child.varname, child.is_base_type_node, None, child.brief] - for child in self.children - if not child.is_public] + members = [ + [ + child.member_typename, + child.varname, + child.is_base_type_node, + None, + child.brief, + ] + for child in self.children + if not child.is_public + ] if self.has_token: - members.append(["std::shared_ptr", "token", True, None, "token with location information"]) + members.append( + [ + "std::shared_ptr", + "token", + True, + None, + "token with location information", + ] + ) return members @@ -764,18 +839,19 @@ def get_description(self): """ Return description for the node in doxygen form """ - lines = self.description.split('\n') + lines = self.description.split("\n") description = "" for i, line in enumerate(lines): if i == 0: - description = ' ' + line + '\n' + description = " " + line + "\n" else: - description += ' * ' + line + '\n' + description += " * " + line + "\n" return description def __repr__(self): return "Node(class_name='{}', base_class='{}', nmodl_name='{}')".format( - self.class_name, self.base_class, self.nmodl_name) + self.class_name, self.base_class, self.nmodl_name + ) def __eq__(self, other): """ diff --git a/src/language/utils.py b/src/language/utils.py index 9aeb04d6c..82dd3a033 100644 --- a/src/language/utils.py +++ b/src/language/utils.py @@ -8,8 +8,8 @@ def camel_case_to_underscore(name): """convert string from 'AaaBbbbCccDdd' -> 'Aaa_Bbbb_Ccc_Ddd'""" - s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - typename = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1) + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + typename = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1) return typename diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index ec92958da..387cfb801 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -67,7 +67,8 @@ def test_differentiate2c(): # multiple prev_eqs to substitute # (these statements should be in the same order as in the mod file) assert _equivalent( - differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x", "a = -2*x*x"]), "-6*x*x+2" + differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x", "a = -2*x*x"]), + "-6*x*x+2", ) assert _equivalent( differentiate2c("a*x + b", "x", {"a", "b"}, ["b = 2*x*x", "a = -2*x"]), "0" @@ -113,7 +114,7 @@ def test_integrate2c(): ("a*x", "x*exp(a*dt)"), ("a*x+b", "(-b + (a*x + b)*exp(a*dt))/a"), ] - for (eq, sol) in test_cases: + for eq, sol in test_cases: assert _equivalent( integrate2c(f"x'={eq}", "dt", var_list, use_pade_approx=False), f"x = {sol}" ) @@ -125,7 +126,7 @@ def test_integrate2c(): ("a*x", "-x*(a*dt+2)/(a*dt-2)"), ("a*x+b", "-(a*dt*x+2*b*dt+2*x)/(a*dt-2)"), ] - for (eq, sol) in pade_test_cases: + for eq, sol in pade_test_cases: assert _equivalent( integrate2c(f"x'={eq}", "dt", var_list, use_pade_approx=True), f"x = {sol}" ) diff --git a/test/unit/pybind/test_ast.py b/test/unit/pybind/test_ast.py index b0b353ebe..e387895b6 100644 --- a/test/unit/pybind/test_ast.py +++ b/test/unit/pybind/test_ast.py @@ -7,36 +7,37 @@ import nmodl.dsl as nmodl import pytest + class TestAst(object): def test_empty_program(self): pnode = ast.Program() - assert str(pnode) == '' + assert str(pnode) == "" def test_ast_construction(self): string = ast.String("tau") name = ast.Name(string) - assert nmodl.to_nmodl(name) == 'tau' + assert nmodl.to_nmodl(name) == "tau" int_macro = nmodl.ast.Integer(1, ast.Name(ast.String("x"))) - assert nmodl.to_nmodl(int_macro) == 'x' + assert nmodl.to_nmodl(int_macro) == "x" statements = [] block = ast.StatementBlock(statements) neuron_block = ast.NeuronBlock(block) - assert nmodl.to_nmodl(neuron_block) == 'NEURON {\n}' + assert nmodl.to_nmodl(neuron_block) == "NEURON {\n}" def test_get_parent(self): x_name = ast.Name(ast.String("x")) int_macro = nmodl.ast.Integer(1, x_name) - assert x_name.parent == int_macro # getting the parent + assert x_name.parent == int_macro # getting the parent def test_set_parent(self): x_name = ast.Name(ast.String("x")) y_name = ast.Name(ast.String("y")) int_macro = nmodl.ast.Integer(1, x_name) - y_name.parent = int_macro # setting the parent + y_name.parent = int_macro # setting the parent int_macro.macro = y_name - assert nmodl.to_nmodl(int_macro) == 'y' + assert nmodl.to_nmodl(int_macro) == "y" def test_ast_node_repr(self): string = ast.String("tau") diff --git a/test/unit/pybind/test_symtab.py b/test/unit/pybind/test_symtab.py index f79e444f2..f44485483 100644 --- a/test/unit/pybind/test_symtab.py +++ b/test/unit/pybind/test_symtab.py @@ -6,16 +6,20 @@ import io from nmodl.dsl import ast, visitor, symtab + def test_symtab(ch_ast): v = symtab.SymtabVisitor() v.visit_program(ch_ast) s = ch_ast.get_symbol_table() - m = s.lookup('m') + m = s.lookup("m") assert m is not None assert m.get_name() == "m" - assert m.has_all_properties(symtab.NmodlType.state_var | symtab.NmodlType.prime_name) is True + assert ( + m.has_all_properties(symtab.NmodlType.state_var | symtab.NmodlType.prime_name) + is True + ) - mInf = s.lookup('mInf') + mInf = s.lookup("mInf") assert mInf is not None assert mInf.get_name() == "mInf" assert mInf.has_any_property(symtab.NmodlType.range_var) is True diff --git a/test/unit/pybind/test_visitor.py b/test/unit/pybind/test_visitor.py index 14329ba63..2ba566666 100644 --- a/test/unit/pybind/test_visitor.py +++ b/test/unit/pybind/test_visitor.py @@ -40,20 +40,29 @@ def test_json_visitor(ch_ast): # test compact json prime_str = nmodl.dsl.to_nmodl(primes[0]) prime_json = nmodl.dsl.to_json(primes[0], True) - assert prime_json == '{"PrimeName":[{"String":[{"name":"m"}]},{"Integer":[{"name":"1"}]}]}' + assert ( + prime_json + == '{"PrimeName":[{"String":[{"name":"m"}]},{"Integer":[{"name":"1"}]}]}' + ) # test json with expanded keys result_json = nmodl.dsl.to_json(primes[0], compact=True, expand=True) - expected_json = ('{"children":[{"children":[{"name":"m"}],' - '"name":"String"},{"children":[{"name":"1"}],' - '"name":"Integer"}],"name":"PrimeName"}') + expected_json = ( + '{"children":[{"children":[{"name":"m"}],' + '"name":"String"},{"children":[{"name":"1"}],' + '"name":"Integer"}],"name":"PrimeName"}' + ) assert result_json == expected_json # test json with nmodl embedded - result_json = nmodl.dsl.to_json(primes[0], compact=True, expand=True, add_nmodl=True) - expected_json = ('{"children":[{"children":[{"name":"m"}],"name":"String","nmodl":"m"},' - '{"children":[{"name":"1"}],"name":"Integer","nmodl":"1"}],' - '"name":"PrimeName","nmodl":"m\'"}') + result_json = nmodl.dsl.to_json( + primes[0], compact=True, expand=True, add_nmodl=True + ) + expected_json = ( + '{"children":[{"children":[{"name":"m"}],"name":"String","nmodl":"m"},' + '{"children":[{"name":"1"}],"name":"Integer","nmodl":"1"}],' + '"name":"PrimeName","nmodl":"m\'"}' + ) assert result_json == expected_json @@ -88,6 +97,7 @@ def test_modify_ast(): RANGE x } """ + class ModifyVisitor(visitor.AstVisitor): def __init__(self, old_name, new_name): visitor.AstVisitor.__init__(self) diff --git a/test/usecases/cnexp_array/simulate.py b/test/usecases/cnexp_array/simulate.py index fdcd1ea75..04f4fa031 100644 --- a/test/usecases/cnexp_array/simulate.py +++ b/test/usecases/cnexp_array/simulate.py @@ -20,7 +20,7 @@ t = np.array(t_hoc.as_numpy()) rate = (0.1 - 1.0) * (0.7 * 0.8 * 0.9) -x_exact = 42.0 * np.exp(rate*t) +x_exact = 42.0 * np.exp(rate * t) rel_err = np.abs(x - x_exact) / x_exact assert np.all(rel_err < 1e-12) diff --git a/test/usecases/cnexp_non_trivial/demo.py b/test/usecases/cnexp_non_trivial/demo.py new file mode 100644 index 000000000..ca316ce4a --- /dev/null +++ b/test/usecases/cnexp_non_trivial/demo.py @@ -0,0 +1,24 @@ +import numpy as np +import matplotlib.pyplot as plt + + +t_end = 4.0 + +n = 10000 +x_approx = np.empty(n) +x_approx[0] = 0.1 + +x_exact = np.empty(n) +x_exact[0] = 0.1 + +dt = t_end / n +t = np.linspace(0.0, t_end, n) + +for i in range(1, n): + x_approx[i] = 1.4142135623730951 * np.sqrt(dt + 0.5 * x_approx[i - 1] ** 2.0) + x_exact[i] = x_exact[i - 1] + dt * 1 / x_exact[i - 1] + +plt.plot(t, x_approx) +plt.plot(t, x_exact) + +plt.show() diff --git a/test/usecases/cnexp_non_trivial/simulate.py b/test/usecases/cnexp_non_trivial/simulate.py new file mode 100644 index 000000000..3245670ca --- /dev/null +++ b/test/usecases/cnexp_non_trivial/simulate.py @@ -0,0 +1,27 @@ +import numpy as np + +from neuron import h, gui +from neuron.units import ms + +nseg = 1 + +s = h.Section() +s.insert("leonhard") +s.nseg = nseg + +x_hoc = h.Vector().record(s(0.5)._ref_x_leonhard) +t_hoc = h.Vector().record(h._ref_t) + +h.stdinit() +h.tstop = 5.0 * ms +h.run() + +x = np.array(x_hoc.as_numpy()) +t = np.array(t_hoc.as_numpy()) + +x0 = 42.0 +x_exact = 42.0 * np.exp(-t) +rel_err = np.abs(x - x_exact) / x_exact + +assert np.all(rel_err < 1e-12) +print("leonhard: success") diff --git a/test/usecases/comments/simulate.py b/test/usecases/comments/simulate.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/test/usecases/comments/simulate.py @@ -0,0 +1 @@ + diff --git a/test/usecases/func_proc/simulate.py b/test/usecases/func_proc/simulate.py index ca1ffc73b..69162fef4 100644 --- a/test/usecases/func_proc/simulate.py +++ b/test/usecases/func_proc/simulate.py @@ -7,7 +7,7 @@ s.insert("test_func_proc") coords = [(0.5 + k) * 1.0 / nseg for k in range(nseg)] -values = [ 0.1 + k for k in range(nseg)] +values = [0.1 + k for k in range(nseg)] for x in coords: s(x).test_func_proc.set_x_42() diff --git a/test/usecases/global_breakpoint/simulate.py b/test/usecases/global_breakpoint/simulate.py index 1aabec09d..d7fc2bd05 100644 --- a/test/usecases/global_breakpoint/simulate.py +++ b/test/usecases/global_breakpoint/simulate.py @@ -20,7 +20,7 @@ t = np.array(t_hoc.as_numpy()) x_exact = 2.0 * np.ones_like(t) -x_exact[0] = 42; +x_exact[0] = 42 abs_err = np.abs(x - x_exact) assert np.all(abs_err < 1e-12), f"{abs_err=}" diff --git a/test/usecases/nonspecific_current/simulate.py b/test/usecases/nonspecific_current/simulate.py index ffa404832..6a192ef71 100644 --- a/test/usecases/nonspecific_current/simulate.py +++ b/test/usecases/nonspecific_current/simulate.py @@ -22,7 +22,7 @@ erev = 1.5 rate = 0.005 / 1e-3 v0 = -65.0 -v_exact = erev + (v0 - erev)*np.exp(-rate*t) +v_exact = erev + (v0 - erev) * np.exp(-rate * t) rel_err = np.abs(v - v_exact) / np.max(np.abs(v_exact)) assert np.all(rel_err < 1e-1), f"rel_err = {rel_err}" diff --git a/test/usecases/parameter/simulate.py b/test/usecases/parameter/simulate.py index 7fe9ad6f1..9dadf17f2 100644 --- a/test/usecases/parameter/simulate.py +++ b/test/usecases/parameter/simulate.py @@ -4,9 +4,8 @@ test_parameter_pp = h.test_parameter(s(0.5)) -assert test_parameter_pp.x == 42. +assert test_parameter_pp.x == 42.0 test_parameter_pp.x = 42.1 assert test_parameter_pp.x == 42.1 -