Skip to content

Commit

Permalink
Make "sparse" solver check if equations are linear.
Browse files Browse the repository at this point in the history
If the system is linear, then newtons method always converges
in exactly one iteration. When using the sparse solver on
linear systems omit the newtons iteration and solve directly.

This should make the resulting code run marginally faster by
skipping the check for convergence. Currently the check for
convergence is implemented as "error = sqrt(|F|^2)".
  • Loading branch information
ctrl-z-9000-times committed May 4, 2022
1 parent cde5dbf commit d868bd3
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 42 deletions.
16 changes: 15 additions & 1 deletion nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)

linear = _is_linear(eqs, state_vars, sympy_vars)

custom_fcts = _get_custom_functions(function_calls)

jacobian = sp.Matrix(eqs).jacobian(state_vars)
Expand All @@ -291,7 +293,19 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
# interweave
code = _interweave_eqs(vecFcode, vecJcode)

return code
return code, linear


def _is_linear(eqs, state_vars, sympy_vars):
for expr in eqs:
for x in state_vars:
for y in state_vars:
try:
if not sp.Eq(sp.diff(expr, x, y), 0):
return False
except TypeError:
return False
return True


def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):
Expand Down
2 changes: 2 additions & 0 deletions src/pybind/pyembed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor {
// output
// returns a vector of solutions, i.e. new statements to add to block:
std::vector<std::string> solutions;
// returns if the system is linear or not.
bool linear;
// may also return a python exception message:
std::string exception_message;

Expand Down
4 changes: 3 additions & 1 deletion src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,22 @@ void SolveNonLinearSystemExecutor::operator()() {
from nmodl.ode import solve_non_lin_system
exception_message = ""
try:
solutions = solve_non_lin_system(equation_strings,
solutions, linear = solve_non_lin_system(equation_strings,
state_vars,
vars,
function_calls)
except Exception as e:
# if we fail, fail silently and return empty string
solutions = [""]
linear = False
new_local_vars = [""]
exception_message = str(e)
)",
py::globals(),
locals);
// returns a vector of solutions, i.e. new statements to add to block:
solutions = locals["solutions"].cast<std::vector<std::string>>();
linear = locals["linear"].cast<bool>();
// may also return a python exception message:
exception_message = locals["exception_message"].cast<std::string>();
}
Expand Down
10 changes: 8 additions & 2 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system(
(*solver)();
// returns a vector of solutions, i.e. new statements to add to block:
auto solutions = solver->solutions;
bool linear = solver->linear;
// may also return a python exception message:
auto exception_message = solver->exception_message;
pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver);
Expand All @@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system(
exception_message);
return;
}
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
construct_eigen_solver_block(pre_solve_statements, solutions, false);
if (!linear) {
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
}
else {
logger->debug("SympySolverVisitor :: Constructing eigen solve block");
}
construct_eigen_solver_block(pre_solve_statements, solutions, linear);
}

void SympySolverVisitor::visit_var_name(ast::VarName& node) {
Expand Down
60 changes: 22 additions & 38 deletions test/unit/visitor/sympy_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[1]{
EIGEN_LINEAR_SOLVE[1]{
LOCAL old_m
}{
IF (mInf == 1) {
Expand All @@ -628,7 +628,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
old_m = m
}{
nmodl_eigen_x[0] = m
}{
nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*mInf+mTau*(-nmodl_eigen_x[0]+old_m))/mTau
nmodl_eigen_j[0] = -(dt+mTau)/mTau
}{
Expand Down Expand Up @@ -659,15 +658,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL a, b, old_y, old_x
}{
old_y = y
old_x = x
}{
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_y
nmodl_eigen_j[0] = 0
nmodl_eigen_j[2] = -1.0
Expand Down Expand Up @@ -703,15 +701,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL a, b, old_M_1, old_M_0
}{
old_M_1 = M[1]
old_M_0 = M[0]
}{
nmodl_eigen_x[0] = M[0]
nmodl_eigen_x[1] = M[1]
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_M_1
nmodl_eigen_j[0] = 0
nmodl_eigen_j[2] = -1.0
Expand Down Expand Up @@ -748,15 +745,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL a, b, old_x, old_y
}{
old_x = x
old_y = y
}{
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[2] = 0
Expand Down Expand Up @@ -825,15 +821,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
DERIVATIVE states {
LOCAL a, b
IF (a == 1) {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL old_x, old_y
}{
old_x = x
old_y = y
}{
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[2] = 0
Expand Down Expand Up @@ -875,15 +870,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL a, b, old_x, old_y
}{
old_x = x
old_y = y
}{
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[2] = a*dt
Expand All @@ -901,15 +895,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_result_cse = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL a, b, old_x, old_y
}{
old_x = x
old_y = y
}{
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[2] = a*dt
Expand Down Expand Up @@ -954,7 +947,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[3]{
EIGEN_LINEAR_SOLVE[3]{
LOCAL a, b, c, d, h, old_x, old_y, old_z
}{
old_x = x
Expand All @@ -964,7 +957,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
nmodl_eigen_x[2] = z
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[3] = 0
Expand All @@ -986,7 +978,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_cse_result = R"(
DERIVATIVE states {
EIGEN_NEWTON_SOLVE[3]{
EIGEN_LINEAR_SOLVE[3]{
LOCAL a, b, c, d, h, old_x, old_y, old_z
}{
old_x = x
Expand All @@ -996,7 +988,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
nmodl_eigen_x[0] = x
nmodl_eigen_x[1] = y
nmodl_eigen_x[2] = z
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[3] = 0
Expand Down Expand Up @@ -1042,15 +1033,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE scheme1 {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL old_mc, old_m
}{
old_mc = mc
old_m = m
}{
nmodl_eigen_x[0] = mc
nmodl_eigen_x[1] = m
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc
nmodl_eigen_j[0] = -a*dt-1.0
nmodl_eigen_j[2] = b*dt
Expand Down Expand Up @@ -1086,14 +1076,13 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE scheme1 {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL old_mc
}{
old_mc = mc
}{
nmodl_eigen_x[0] = mc
nmodl_eigen_x[1] = m
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc
nmodl_eigen_j[0] = -a*dt-1.0
nmodl_eigen_j[2] = b*dt
Expand Down Expand Up @@ -1131,15 +1120,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE scheme1 {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL old_mc, old_m
}{
old_mc = mc
old_m = m
}{
nmodl_eigen_x[0] = mc
nmodl_eigen_x[1] = m
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc
nmodl_eigen_j[0] = -a*dt-1.0
nmodl_eigen_j[2] = b*dt
Expand Down Expand Up @@ -1180,7 +1168,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
})";
std::string expected_result = R"(
DERIVATIVE ihkin {
EIGEN_NEWTON_SOLVE[5]{
EIGEN_LINEAR_SOLVE[5]{
LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
}{
evaluate_fct(v, cai)
Expand All @@ -1193,7 +1181,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
nmodl_eigen_x[2] = o2
nmodl_eigen_x[3] = p0
nmodl_eigen_x[4] = p1
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*alpha*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*beta*dt+old_c1
nmodl_eigen_j[0] = -alpha*dt-1.0
nmodl_eigen_j[5] = beta*dt
Expand Down Expand Up @@ -1260,13 +1247,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE scheme1 {
EIGEN_NEWTON_SOLVE[1]{
EIGEN_LINEAR_SOLVE[1]{
LOCAL old_W_0
}{
old_W_0 = W[0]
}{
nmodl_eigen_x[0] = W[0]
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0
nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0
}{
Expand Down Expand Up @@ -1300,15 +1286,14 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE scheme1 {
EIGEN_NEWTON_SOLVE[2]{
EIGEN_LINEAR_SOLVE[2]{
LOCAL old_M_0, old_M_1
}{
old_M_0 = M[0]
old_M_1 = M[1]
}{
nmodl_eigen_x[0] = M[0]
nmodl_eigen_x[1] = M[1]
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*B[0]+old_M_0
nmodl_eigen_j[0] = -dt*A[0]-1.0
nmodl_eigen_j[2] = dt*B[0]
Expand Down Expand Up @@ -1346,13 +1331,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
)";
std::string expected_result = R"(
DERIVATIVE scheme1 {
EIGEN_NEWTON_SOLVE[1]{
EIGEN_LINEAR_SOLVE[1]{
LOCAL old_W_0
}{
old_W_0 = W[0]
}{
nmodl_eigen_x[0] = W[0]
}{
nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0
nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0
}{
Expand Down Expand Up @@ -2053,7 +2037,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
x
}
NONLINEAR nonlin {
~ x = 5
~ x * x * x = 5
})";
std::string expected_text = R"(
NONLINEAR nonlin {
Expand All @@ -2062,8 +2046,8 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
}{
nmodl_eigen_x[0] = x
}{
nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
nmodl_eigen_j[0] = -1.0
nmodl_eigen_f[0] = 5.0-pow(nmodl_eigen_x[0], 3)
nmodl_eigen_j[0] = -3.0 * pow(nmodl_eigen_x[0], 2)
}{
x = nmodl_eigen_x[0]
}{
Expand All @@ -2084,7 +2068,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
NONLINEAR nonlin {
~ s[0] = 1
~ s[1] = 3
~ s[2] + s[1] = s[0]
~ s[2] + s[1] = s[0] * s[0]
})";
std::string expected_text = R"(
NONLINEAR nonlin {
Expand All @@ -2097,14 +2081,14 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s
}{
nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
nmodl_eigen_f[2] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]-nmodl_eigen_x[2]
nmodl_eigen_j[0] = -1.0
nmodl_eigen_j[3] = 0
nmodl_eigen_j[6] = 0
nmodl_eigen_j[1] = 0
nmodl_eigen_j[4] = -1.0
nmodl_eigen_j[7] = 0
nmodl_eigen_j[2] = 1.0
nmodl_eigen_j[2] = 2.0 * nmodl_eigen_x[0]
nmodl_eigen_j[5] = -1.0
nmodl_eigen_j[8] = -1.0
}{
Expand Down

0 comments on commit d868bd3

Please sign in to comment.