Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable formatting Python files. #1207

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .bbp-project.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ tools:
match:
- ext/.*
- src/language/templates/*
Black:
1uc marked this conversation as resolved.
Show resolved Hide resolved
enable: True
version: ~=24.2.0
include:
match:
- .*\.py$
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
(master_doc, "nmodl.tex", "nmodl Documentation", "BlueBrain HPC team", "manual")
]

imgmath_image_format = 'svg'
imgmath_image_format = "svg"
imgmath_embed = True
imgmath_font_size = 14

Expand Down
1 change: 1 addition & 0 deletions python/nmodl/ast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module for vizualization of NMODL abstract syntax trees (ASTs).
"""

import getpass
import json
import os
Expand Down
79 changes: 61 additions & 18 deletions python/nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# accessed through regular imports
major, minor = (int(v) for v in sp.__version__.split(".")[:2])
if major >= 1 and minor >= 7:
known_functions = import_module('sympy.printing.c').known_functions_C99
known_functions = import_module("sympy.printing.c").known_functions_C99
else:
known_functions = import_module('sympy.printing.ccode').known_functions_C99
known_functions.pop('Abs')
known_functions['abs'] = 'fabs'
known_functions = import_module("sympy.printing.ccode").known_functions_C99
known_functions.pop("Abs")
known_functions["abs"] = "fabs"


if not ((major >= 1) and (minor >= 2)):
Expand All @@ -29,7 +29,18 @@
# Some functions are protected inside sympy, if user has declared such a function, it will fail
# because sympy will try to use its own internal one.
# Rename it before and after to a single name
forbidden_var = ["beta", "gamma", "uppergamma", "lowergamma", "polygamma", "loggamma", "digamma", "trigamma"]
forbidden_var = [
"beta",
"gamma",
"uppergamma",
"lowergamma",
"polygamma",
"loggamma",
"digamma",
"trigamma",
]


def search_and_replace_protected_functions_to_sympy(eqs, function_calls):
for c in function_calls:
if c in forbidden_var:
Expand All @@ -38,13 +49,15 @@ def search_and_replace_protected_functions_to_sympy(eqs, function_calls):
eqs = [re.sub(r, f, x) for x in eqs]
return eqs


def search_and_replace_protected_functions_from_sympy(eqs, function_calls):
for c in function_calls:
if c in forbidden_var:
r = f"_sympy_{c}_fun"
eqs = [re.sub(r, f"{c}", x) for x in eqs]
return eqs


def _get_custom_functions(fcts):
custom_functions = {}
for f in fcts:
Expand Down Expand Up @@ -143,13 +156,16 @@ def _sympify_eqs(eq_strings, state_vars, vars):
for state_var in state_vars:
sympy_state_vars.append(sp.sympify(state_var, locals=sympy_vars))
eqs = [
(sp.sympify(eq.split("=", 1)[1], locals=sympy_vars)
- sp.sympify(eq.split("=", 1)[0], locals=sympy_vars)).expand()
(
sp.sympify(eq.split("=", 1)[1], locals=sympy_vars)
- sp.sympify(eq.split("=", 1)[0], locals=sympy_vars)
).expand()
for eq in eq_strings
]

return eqs, sympy_state_vars, sympy_vars


def _interweave_eqs(F, J):
"""Interweave F and J equations so that they are printed in code
rowwise from the equation J x = F. For example:
Expand Down Expand Up @@ -199,13 +215,21 @@ def _interweave_eqs(F, J):
n = len(F)
for i, expr in enumerate(F):
code.append(expr)
for j in range(i * n, (i+1) * n):
for j in range(i * n, (i + 1) * n):
code.append(J[j])

return code


def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_prefix, small_system=False, do_cse=False):
def solve_lin_system(
eq_strings,
vars,
constants,
function_calls,
tmp_unique_prefix,
small_system=False,
do_cse=False,
):
"""Solve linear system of equations, return solution as C code.

If system is small (small_system=True, typically N<=3):
Expand Down Expand Up @@ -233,7 +257,9 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
vars: list of strings containing new local variables
"""

eq_strings = search_and_replace_protected_functions_to_sympy(eq_strings, function_calls)
eq_strings = search_and_replace_protected_functions_to_sympy(
eq_strings, function_calls
)

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)
custom_fcts = _get_custom_functions(function_calls)
Expand All @@ -246,7 +272,9 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
solution_vector = sp.linsolve(eqs, state_vars).args[0]
if do_cse:
# generate prefix for new local vars that avoids clashes
my_symbols = sp.utilities.iterables.numbered_symbols(prefix=tmp_unique_prefix + '_')
my_symbols = sp.utilities.iterables.numbered_symbols(
prefix=tmp_unique_prefix + "_"
)
sub_exprs, simplified_solution_vector = sp.cse(
solution_vector,
symbols=my_symbols,
Expand All @@ -255,10 +283,14 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
)
for var, expr in sub_exprs:
new_local_vars.append(sp.ccode(var))
code.append(f"{var} = {sp.ccode(expr.evalf(), user_functions=custom_fcts)}")
code.append(
f"{var} = {sp.ccode(expr.evalf(), user_functions=custom_fcts)}"
)
solution_vector = simplified_solution_vector[0]
for var, expr in zip(state_vars, solution_vector):
code.append(f"{sp.ccode(var)} = {sp.ccode(expr.evalf(), contract=False, user_functions=custom_fcts)}")
code.append(
f"{sp.ccode(var)} = {sp.ccode(expr.evalf(), contract=False, user_functions=custom_fcts)}"
)
else:
# large linear system: construct and return matrix J, vector F such that
# J X = F is the linear system to be solved for X by e.g. LU factorization
Expand All @@ -267,13 +299,17 @@ def solve_lin_system(eq_strings, vars, constants, function_calls, tmp_unique_pre
# construct vector F
vecFcode = []
for i, expr in enumerate(vecF):
vecFcode.append(f"F[{i}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}")
vecFcode.append(
f"F[{i}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}"
)
# construct matrix J
vecJcode = []
for i, expr in enumerate(matJ):
# todo: fix indexing to be ascending order
flat_index = matJ.rows * (i % matJ.rows) + (i // matJ.rows)
vecJcode.append(f"J[{flat_index}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}")
vecJcode.append(
f"J[{flat_index}] = {sp.ccode(expr.simplify().evalf(), user_functions=custom_fcts)}"
)
# interweave
code = _interweave_eqs(vecFcode, vecJcode)

Expand All @@ -299,7 +335,9 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
List of strings containing assignment statements
"""

eq_strings = search_and_replace_protected_functions_to_sympy(eq_strings, function_calls)
eq_strings = search_and_replace_protected_functions_to_sympy(
eq_strings, function_calls
)

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)
custom_fcts = _get_custom_functions(function_calls)
Expand All @@ -310,13 +348,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):

vecFcode = []
for i, eq in enumerate(eqs):
vecFcode.append(f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}")
vecFcode.append(
f"F[{i}] = {sp.ccode(eq.simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)}"
)

vecJcode = []
for i, j in itertools.product(range(jacobian.rows), range(jacobian.cols)):
flat_index = i + jacobian.rows * j

rhs = sp.ccode(jacobian[i,j].simplify().subs(X_vec_map).evalf(), user_functions=custom_fcts)
rhs = sp.ccode(
jacobian[i, j].simplify().subs(X_vec_map).evalf(),
user_functions=custom_fcts,
)
vecJcode.append(f"J[{flat_index}] = {rhs}")

# interweave
Expand Down
1 change: 1 addition & 0 deletions src/language/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: Apache-2.0


class Argument:
"""Utility class for holding all arguments for node classes"""

Expand Down
22 changes: 16 additions & 6 deletions src/language/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def jinja_template(self, path):
return self.jinja_env.get_template(name)

def _cmake_deps_task(self, tasks):
""""Construct the JinjaTask generating the CMake file exporting all dependencies
""" "Construct the JinjaTask generating the CMake file exporting all dependencies

Args:
tasks: list of JinjaTask objects
Expand Down Expand Up @@ -196,12 +196,18 @@ def workload(self):
task = JinjaTask(
app=self,
input=filepath,
output=self.base_dir / sub_dir / "pynode_{}.cpp".format(chunk_k),
output=self.base_dir
/ sub_dir
/ "pynode_{}.cpp".format(chunk_k),
context=dict(
nodes=self.nodes[
chunk_k * chunk_length : (chunk_k + 1) * chunk_length
chunk_k
* chunk_length : (chunk_k + 1)
* chunk_length
],
setup_pybind_method="init_pybind_classes_{}".format(chunk_k),
setup_pybind_method="init_pybind_classes_{}".format(
chunk_k
),
),
extradeps=extradeps[filepath],
)
Expand All @@ -212,7 +218,11 @@ def workload(self):
app=self,
input=filepath,
output=self.base_dir / sub_dir / filepath.name,
context=dict(nodes=self.nodes, node_info=node_info, **extracontext[filepath]),
context=dict(
nodes=self.nodes,
node_info=node_info,
**extracontext[filepath],
),
extradeps=extradeps[filepath],
)
tasks.append(task)
Expand All @@ -235,7 +245,7 @@ class JinjaTask(
"""

def execute(self):
""""Perform the Jinja task
""" "Perform the Jinja task

Execute Jinja renderer if the output file is out-of-date.

Expand Down
Loading
Loading