From 2636b4e67e4333b0c92985284917805506923e1c Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Wed, 12 Jun 2024 10:05:34 +0200 Subject: [PATCH] Modernize NMODL -> Python layer. (#1306) The changes are: * Don't forcefully set the PYTHONPATH. * Replace false object oriented code with a function (each). * Replace a global variable, with a function to initialize. * Use a reference for a pointer that's should never be a nullptr. --- src/main.cpp | 4 +- src/pybind/CMakeLists.txt | 3 + src/pybind/pyembed.cpp | 80 +++++++++----- src/pybind/pyembed.hpp | 120 +-------------------- src/pybind/wrapper.cpp | 119 +++++++++----------- src/pybind/wrapper.hpp | 62 +++++++++++ src/visitors/main.cpp | 4 +- src/visitors/sympy_conductance_visitor.cpp | 10 +- src/visitors/sympy_solver_visitor.cpp | 60 +++-------- test/unit/codegen/main.cpp | 4 +- test/unit/visitor/main.cpp | 4 +- 11 files changed, 198 insertions(+), 272 deletions(-) create mode 100644 src/pybind/wrapper.hpp diff --git a/src/main.cpp b/src/main.cpp index 96a8ec964f..f90c7d5879 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -486,7 +486,7 @@ int main(int argc, const char* argv[]) { if (sympy_conductance || sympy_analytic || sparse_solver_exists(*ast)) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() - ->initialize_interpreter(); + .initialize_interpreter(); if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); @@ -507,7 +507,7 @@ int main(int argc, const char* argv[]) { } nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() - ->finalize_interpreter(); + .finalize_interpreter(); } { diff --git a/src/pybind/CMakeLists.txt b/src/pybind/CMakeLists.txt index eb2ce323a5..9439ea4f32 100644 --- a/src/pybind/CMakeLists.txt +++ b/src/pybind/CMakeLists.txt @@ -30,6 +30,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ode_py.hpp.inc ${CMAKE_CURRENT_BINARY add_library(pyembed pyembed.cpp) set_property(TARGET pyembed PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(pyembed PRIVATE util) +target_link_libraries(pyembed PRIVATE fmt::fmt) if(NOT LINK_AGAINST_PYTHON) add_library(pywrapper SHARED ${CMAKE_CURRENT_SOURCE_DIR}/wrapper.cpp) @@ -41,6 +42,8 @@ else() target_compile_definitions(pyembed PRIVATE NMODL_STATIC_PYWRAPPER=1) endif() +target_link_libraries(pywrapper PRIVATE fmt::fmt) + target_include_directories(pyembed PRIVATE ${PYBIND11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS}) target_include_directories(pywrapper PRIVATE ${pybind11_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) target_include_directories(pywrapper PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/src/pybind/pyembed.cpp b/src/pybind/pyembed.cpp index fd20bb101e..7e4bfb8dcc 100644 --- a/src/pybind/pyembed.cpp +++ b/src/pybind/pyembed.cpp @@ -4,13 +4,15 @@ * * SPDX-License-Identifier: Apache-2.0 */ +#include "pybind/pyembed.hpp" #include #include #include +#include + #include "config/config.h" -#include "pybind/pyembed.hpp" #include "utils/logger.hpp" #define STRINGIFY(x) #x @@ -22,15 +24,43 @@ namespace nmodl { namespace pybind_wrappers { +using nmodl_init_pybind_wrapper_api_fpointer = decltype(&nmodl_init_pybind_wrapper_api); + bool EmbeddedPythonLoader::have_wrappers() { #if defined(NMODL_STATIC_PYWRAPPER) - static auto wrapper_api = nmodl::pybind_wrappers::init_pybind_wrap_api(); - wrappers = &wrapper_api; - return true; + auto* init = &nmodl_init_pybind_wrapper_api; #else - wrappers = static_cast(dlsym(RTLD_DEFAULT, "nmodl_wrapper_api")); - return wrappers != nullptr; + auto* init = (nmodl_init_pybind_wrapper_api_fpointer) (dlsym(RTLD_DEFAULT, + "nmodl_init_pybind_wrapper_api")); #endif + + if (init != nullptr) { + wrappers = init(); + } + + return init != nullptr; +} + +void assert_compatible_python_versions() { + // This code is imported and slightly modified from PyBind11 because this + // is primarly in details for internal usage + // License of PyBind11 is BSD-style + + std::string compiled_ver = fmt::format("{}.{}", PY_MAJOR_VERSION, PY_MINOR_VERSION); + auto pPy_GetVersion = (const char* (*) (void) ) dlsym(RTLD_DEFAULT, "Py_GetVersion"); + if (pPy_GetVersion == nullptr) { + throw std::runtime_error("Unable to find the function `Py_GetVersion`"); + } + const char* runtime_ver = pPy_GetVersion(); + std::size_t len = compiled_ver.size(); + if (std::strncmp(runtime_ver, compiled_ver.c_str(), len) != 0 || + (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { + throw std::runtime_error( + fmt::format("Python version mismatch. nmodl has been compiled with python {} and is " + "being run with python {}", + compiled_ver, + runtime_ver)); + } } void EmbeddedPythonLoader::load_libraries() { @@ -49,26 +79,7 @@ void EmbeddedPythonLoader::load_libraries() { throw std::runtime_error("Failed to dlopen"); } - // This code is imported from PyBind11 because this is primarly in details for internal usage - // License of PyBind11 is BSD-style - { - std::string compiled_ver = fmt::format("{}.{}", PY_MAJOR_VERSION, PY_MINOR_VERSION); - const char* (*fun)(void) = (const char* (*) (void) ) dlsym(pylib_handle, "Py_GetVersion"); - if (fun == nullptr) { - logger->critical("Unable to find the function `Py_GetVersion`"); - throw std::runtime_error("Unable to find the function `Py_GetVersion`"); - } - const char* runtime_ver = fun(); - std::size_t len = compiled_ver.size(); - if (std::strncmp(runtime_ver, compiled_ver.c_str(), len) != 0 || - (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { - logger->critical( - "nmodl has been compiled with python {} and is being run with python {}", - compiled_ver, - runtime_ver); - throw std::runtime_error("Python version mismatch between compile-time and runtime."); - } - } + assert_compatible_python_versions(); if (std::getenv("NMODLHOME") == nullptr) { logger->critical("NMODLHOME environment variable must be set to load embedded python"); @@ -91,25 +102,36 @@ void EmbeddedPythonLoader::load_libraries() { } void EmbeddedPythonLoader::populate_symbols() { - wrappers = static_cast(dlsym(pybind_wrapper_handle, "nmodl_wrapper_api")); - if (!wrappers) { +#if defined(NMODL_STATIC_PYWRAPPER) + auto* init = &nmodl_init_pybind_wrapper_api; +#else + // By now it's been dynamically loaded with `RTLD_GLOBAL`. + auto* init = (nmodl_init_pybind_wrapper_api_fpointer) (dlsym(RTLD_DEFAULT, + "nmodl_init_pybind_wrapper_api")); +#endif + + if (!init) { const auto errstr = dlerror(); logger->critical("Tried but failed to load pybind wrapper symbols"); logger->critical(errstr); throw std::runtime_error("Failed to dlsym"); } + + wrappers = init(); } void EmbeddedPythonLoader::unload() { if (pybind_wrapper_handle) { dlclose(pybind_wrapper_handle); + pybind_wrapper_handle = nullptr; } if (pylib_handle) { dlclose(pylib_handle); + pylib_handle = nullptr; } } -const pybind_wrap_api* EmbeddedPythonLoader::api() { +const pybind_wrap_api& EmbeddedPythonLoader::api() { return wrappers; } diff --git a/src/pybind/pyembed.hpp b/src/pybind/pyembed.hpp index 851ff9240c..5c0acf4819 100644 --- a/src/pybind/pyembed.hpp +++ b/src/pybind/pyembed.hpp @@ -7,123 +7,11 @@ #pragma once -#include - -#include -#include -#include -#include +#include "wrapper.hpp" namespace nmodl { namespace pybind_wrappers { - -struct PythonExecutor { - virtual ~PythonExecutor() {} - - virtual void operator()() = 0; -}; - - -struct SolveLinearSystemExecutor: public PythonExecutor { - // input - std::vector eq_system; - std::vector state_vars; - std::set vars; - bool small_system; - bool elimination; - // This is used only if elimination is true. It gives the root for the tmp variables - std::string tmp_unique_prefix; - std::set function_calls; - // output - // returns a vector of solutions, i.e. new statements to add to block: - std::vector solutions; - // and a vector of new local variables that need to be declared in the block: - std::vector new_local_vars; - // may also return a python exception message: - std::string exception_message; - // executor function - void operator()() override; -}; - - -struct SolveNonLinearSystemExecutor: public PythonExecutor { - // input - std::vector eq_system; - std::vector state_vars; - std::set vars; - std::set function_calls; - // output - // returns a vector of solutions, i.e. new statements to add to block: - std::vector solutions; - // may also return a python exception message: - std::string exception_message; - - // executor function - void operator()() override; -}; - - -struct DiffeqSolverExecutor: public PythonExecutor { - // input - std::string node_as_nmodl; - std::string dt_var; - std::set vars; - bool use_pade_approx; - std::set function_calls; - std::string method; - // output - // returns solution, i.e. new statement to add to block: - std::string solution; - // may also return a python exception message: - std::string exception_message; - - // executor function - void operator()() override; -}; - - -struct AnalyticDiffExecutor: public PythonExecutor { - // input - std::vector expressions; - std::set used_names_in_block; - // output - // returns solution, i.e. new statement to add to block: - std::string solution; - // may also return a python exception message: - std::string exception_message; - - // executor function - void operator()() override; -}; - - -SolveLinearSystemExecutor* create_sls_executor_func(); -SolveNonLinearSystemExecutor* create_nsls_executor_func(); -DiffeqSolverExecutor* create_des_executor_func(); -AnalyticDiffExecutor* create_ads_executor_func(); -void destroy_sls_executor_func(SolveLinearSystemExecutor* exec); -void destroy_nsls_executor_func(SolveNonLinearSystemExecutor* exec); -void destroy_des_executor_func(DiffeqSolverExecutor* exec); -void destroy_ads_executor_func(AnalyticDiffExecutor* exec); - -void initialize_interpreter_func(); -void finalize_interpreter_func(); - -struct pybind_wrap_api { - decltype(&initialize_interpreter_func) initialize_interpreter; - decltype(&finalize_interpreter_func) finalize_interpreter; - decltype(&create_sls_executor_func) create_sls_executor; - decltype(&create_nsls_executor_func) create_nsls_executor; - decltype(&create_des_executor_func) create_des_executor; - decltype(&create_ads_executor_func) create_ads_executor; - decltype(&destroy_sls_executor_func) destroy_sls_executor; - decltype(&destroy_nsls_executor_func) destroy_nsls_executor; - decltype(&destroy_des_executor_func) destroy_des_executor; - decltype(&destroy_ads_executor_func) destroy_ads_executor; -}; - - /** * A singleton class handling access to the pybind_wrap_api struct * @@ -154,14 +42,14 @@ class EmbeddedPythonLoader { * Get access to the container struct for the pointers to the functions in the wrapper library. * @return a pybind_wrap_api pointer */ - const pybind_wrap_api* api(); + const pybind_wrap_api& api(); ~EmbeddedPythonLoader() { unload(); } private: - pybind_wrap_api* wrappers = nullptr; + pybind_wrap_api wrappers; void* pylib_handle = nullptr; void* pybind_wrapper_handle = nullptr; @@ -180,7 +68,5 @@ class EmbeddedPythonLoader { }; -pybind_wrap_api init_pybind_wrap_api() noexcept; - } // namespace pybind_wrappers } // namespace nmodl diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 2d21859201..3c9f2661a5 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -5,6 +5,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "wrapper.hpp" + #include "codegen/codegen_naming.hpp" #include "pybind/pyembed.hpp" @@ -22,7 +24,14 @@ using namespace py::literals; namespace nmodl { namespace pybind_wrappers { -void SolveLinearSystemExecutor::operator()() { +std::tuple, std::vector, std::string> +call_solve_linear_system(const std::vector& eq_system, + const std::vector& state_vars, + const std::set& vars, + bool small_system, + bool elimination, + const std::string& tmp_unique_prefix, + const std::set& function_calls) { const auto locals = py::dict("eq_strings"_a = eq_system, "state_vars"_a = state_vars, "vars"_a = vars, @@ -49,15 +58,21 @@ except Exception as e: py::exec(nmodl::pybind_wrappers::ode_py + script, locals); // returns a vector of solutions, i.e. new statements to add to block: - solutions = locals["solutions"].cast>(); + auto solutions = locals["solutions"].cast>(); // and a vector of new local variables that need to be declared in the block: - new_local_vars = locals["new_local_vars"].cast>(); + auto new_local_vars = locals["new_local_vars"].cast>(); // may also return a python exception message: - exception_message = locals["exception_message"].cast(); + auto exception_message = locals["exception_message"].cast(); + + return {std::move(solutions), std::move(new_local_vars), std::move(exception_message)}; } -void SolveNonLinearSystemExecutor::operator()() { +std::tuple, std::string> call_solve_nonlinear_system( + const std::vector& eq_system, + const std::vector& state_vars, + const std::set& vars, + const std::set& function_calls) { const auto locals = py::dict("equation_strings"_a = eq_system, "state_vars"_a = state_vars, "vars"_a = vars, @@ -78,12 +93,20 @@ except Exception as e: py::exec(nmodl::pybind_wrappers::ode_py + script, locals); // returns a vector of solutions, i.e. new statements to add to block: - solutions = locals["solutions"].cast>(); + auto solutions = locals["solutions"].cast>(); // may also return a python exception message: - exception_message = locals["exception_message"].cast(); + auto exception_message = locals["exception_message"].cast(); + + return {std::move(solutions), std::move(exception_message)}; } -void DiffeqSolverExecutor::operator()() { + +std::tuple call_diffeq_solver(const std::string& node_as_nmodl, + const std::string& dt_var, + const std::set& vars, + bool use_pade_approx, + const std::set& function_calls, + const std::string& method) { const auto locals = py::dict("equation_string"_a = node_as_nmodl, "dt_var"_a = dt_var, "vars"_a = vars, @@ -123,13 +146,18 @@ except Exception as e: py::exec(nmodl::pybind_wrappers::ode_py + script, locals); } else { // nothing to do, but the caller should know. - return; + return {}; } - solution = locals["solution"].cast(); - exception_message = locals["exception_message"].cast(); + auto solution = locals["solution"].cast(); + auto exception_message = locals["exception_message"].cast(); + + return {std::move(solution), std::move(exception_message)}; } -void AnalyticDiffExecutor::operator()() { + +std::tuple call_analytic_diff( + const std::vector& expressions, + const std::set& used_names_in_block) { auto locals = py::dict("expressions"_a = expressions, "vars"_a = used_names_in_block); std::string script = R"( exception_message = "" @@ -147,74 +175,33 @@ except Exception as e: )"; py::exec(nmodl::pybind_wrappers::ode_py + script, locals); - solution = locals["solution"].cast(); - exception_message = locals["exception_message"].cast(); -} - -SolveLinearSystemExecutor* create_sls_executor_func() { - return new SolveLinearSystemExecutor(); -} -SolveNonLinearSystemExecutor* create_nsls_executor_func() { - return new SolveNonLinearSystemExecutor(); -} + auto solution = locals["solution"].cast(); + auto exception_message = locals["exception_message"].cast(); -DiffeqSolverExecutor* create_des_executor_func() { - return new DiffeqSolverExecutor(); + return {std::move(solution), std::move(exception_message)}; } -AnalyticDiffExecutor* create_ads_executor_func() { - return new AnalyticDiffExecutor(); -} - -void destroy_sls_executor_func(SolveLinearSystemExecutor* exec) { - delete exec; -} - -void destroy_nsls_executor_func(SolveNonLinearSystemExecutor* exec) { - delete exec; -} - -void destroy_des_executor_func(DiffeqSolverExecutor* exec) { - delete exec; -} - -void destroy_ads_executor_func(AnalyticDiffExecutor* exec) { - delete exec; -} void initialize_interpreter_func() { pybind11::initialize_interpreter(true); - const auto python_path_cstr = std::getenv("PYTHONPATH"); - if (python_path_cstr) { - pybind11::module::import("sys").attr("path").cast().insert( - 0, python_path_cstr); - } } void finalize_interpreter_func() { pybind11::finalize_interpreter(); } -pybind_wrap_api init_pybind_wrap_api() noexcept { - return { - &nmodl::pybind_wrappers::initialize_interpreter_func, - &nmodl::pybind_wrappers::finalize_interpreter_func, - &nmodl::pybind_wrappers::create_sls_executor_func, - &nmodl::pybind_wrappers::create_nsls_executor_func, - &nmodl::pybind_wrappers::create_des_executor_func, - &nmodl::pybind_wrappers::create_ads_executor_func, - &nmodl::pybind_wrappers::destroy_sls_executor_func, - &nmodl::pybind_wrappers::destroy_nsls_executor_func, - &nmodl::pybind_wrappers::destroy_des_executor_func, - &nmodl::pybind_wrappers::destroy_ads_executor_func, - }; +// Prevent mangling for easier `dlsym`. +extern "C" { +__attribute__((visibility("default"))) pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept { + return {&nmodl::pybind_wrappers::initialize_interpreter_func, + &nmodl::pybind_wrappers::finalize_interpreter_func, + &call_solve_nonlinear_system, + &call_solve_linear_system, + &call_diffeq_solver, + &call_analytic_diff}; +} } } // namespace pybind_wrappers } // namespace nmodl - - -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -__attribute__((visibility("default"))) nmodl::pybind_wrappers::pybind_wrap_api nmodl_wrapper_api = - nmodl::pybind_wrappers::init_pybind_wrap_api(); diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp new file mode 100644 index 0000000000..2a51da014e --- /dev/null +++ b/src/pybind/wrapper.hpp @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +namespace nmodl { +namespace pybind_wrappers { + + +void initialize_interpreter_func(); +void finalize_interpreter_func(); + +std::tuple, std::vector, std::string> +call_solve_linear_system(const std::vector& eq_system, + const std::vector& state_vars, + const std::set& vars, + bool small_system, + bool elimination, + const std::string& tmp_unique_prefix, + const std::set& function_calls); + +std::tuple, std::string> call_solve_nonlinear_system( + const std::vector& eq_system, + const std::vector& state_vars, + const std::set& vars, + const std::set& function_calls); + +std::tuple call_diffeq_solver(const std::string& node_as_nmodl, + const std::string& dt_var, + const std::set& vars, + bool use_pade_approx, + const std::set& function_calls, + const std::string& method); + +std::tuple call_analytic_diff( + const std::vector& expressions, + const std::set& used_names_in_block); + +struct pybind_wrap_api { + decltype(&initialize_interpreter_func) initialize_interpreter; + decltype(&finalize_interpreter_func) finalize_interpreter; + decltype(&call_solve_nonlinear_system) solve_nonlinear_system; + decltype(&call_solve_linear_system) solve_linear_system; + decltype(&call_diffeq_solver) diffeq_solver; + decltype(&call_analytic_diff) analytic_diff; +}; + +extern "C" { +__attribute__((visibility("default"))) pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept; +} + + +} // namespace pybind_wrappers +} // namespace nmodl diff --git a/src/visitors/main.cpp b/src/visitors/main.cpp index 42ba3a89d0..9a6b969663 100644 --- a/src/visitors/main.cpp +++ b/src/visitors/main.cpp @@ -106,7 +106,7 @@ int main(int argc, const char* argv[]) { {std::make_shared(), "verbatim", "VerbatimVisitor"}, }; - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api()->initialize_interpreter(); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api().initialize_interpreter(); for (const auto& filename: files) { logger->info("Processing {}", filename.string()); @@ -128,7 +128,7 @@ int main(int argc, const char* argv[]) { } } - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api()->finalize_interpreter(); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api().finalize_interpreter(); return 0; } diff --git a/src/visitors/sympy_conductance_visitor.cpp b/src/visitors/sympy_conductance_visitor.cpp index 30aa9bfd4e..7d161e352f 100644 --- a/src/visitors/sympy_conductance_visitor.cpp +++ b/src/visitors/sympy_conductance_visitor.cpp @@ -74,14 +74,8 @@ std::vector SympyConductanceVisitor::generate_statement_strings( binary_expr_index[lhs_str]) + 1); // differentiate dI/dV - auto analytic_diff = - pywrap::EmbeddedPythonLoader::get_instance().api()->create_ads_executor(); - analytic_diff->expressions = expressions; - analytic_diff->used_names_in_block = used_names_in_block; - (*analytic_diff)(); - auto dIdV = analytic_diff->solution; - auto exception_message = analytic_diff->exception_message; - pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_ads_executor(analytic_diff); + auto analytic_diff = pywrap::EmbeddedPythonLoader::get_instance().api().analytic_diff; + auto [dIdV, exception_message] = analytic_diff(expressions, used_names_in_block); if (!exception_message.empty()) { logger->warn("SympyConductance :: python exception: {}", exception_message); } diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index 040bddbc69..68dfcff28d 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -288,25 +288,16 @@ void SympySolverVisitor::solve_linear_system(const std::vector& pre init_state_vars_vector(); // call sympy linear solver bool small_system = (eq_system.size() <= SMALL_LINEAR_SYSTEM_MAX_STATES); - auto solver = pywrap::EmbeddedPythonLoader::get_instance().api()->create_sls_executor(); - solver->eq_system = eq_system; - solver->state_vars = state_vars; - solver->vars = vars; - solver->small_system = small_system; - solver->elimination = elimination; + auto solver = pywrap::EmbeddedPythonLoader::get_instance().api().solve_linear_system; // this is necessary after we destroy the solver const auto tmp_unique_prefix = suffix_random_string(vars, "tmp"); - solver->tmp_unique_prefix = tmp_unique_prefix; - solver->function_calls = function_calls; - (*solver)(); - // returns a vector of solutions, i.e. new statements to add to block: - auto solutions = solver->solutions; - // and a vector of new local variables that need to be declared in the block: - auto new_local_vars = solver->new_local_vars; + + // returns a vector of solutions, i.e. new statements to add to block; + // and a vector of new local variables that need to be declared in the block; // may also return a python exception message: - auto exception_message = solver->exception_message; - // destroy solver - pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_sls_executor(solver); + auto [solutions, new_local_vars, exception_message] = solver( + eq_system, state_vars, vars, small_system, elimination, tmp_unique_prefix, function_calls); + if (!exception_message.empty()) { logger->warn("SympySolverVisitor :: solve_lin_system python exception: " + exception_message); @@ -344,19 +335,10 @@ void SympySolverVisitor::solve_non_linear_system( const std::vector& pre_solve_statements) { // construct ordered vector of state vars used in non-linear system init_state_vars_vector(); - // call sympy non-linear solver - - auto solver = pywrap::EmbeddedPythonLoader::get_instance().api()->create_nsls_executor(); - solver->eq_system = eq_system; - solver->state_vars = state_vars; - solver->vars = vars; - solver->function_calls = function_calls; - (*solver)(); - // returns a vector of solutions, i.e. new statements to add to block: - auto solutions = solver->solutions; - // may also return a python exception message: - auto exception_message = solver->exception_message; - pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver); + + auto solver = pywrap::EmbeddedPythonLoader::get_instance().api().solve_nonlinear_system; + auto [solutions, exception_message] = solver(eq_system, state_vars, vars, function_calls); + if (!exception_message.empty()) { logger->warn("SympySolverVisitor :: solve_non_lin_system python exception: " + exception_message); @@ -404,19 +386,11 @@ void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { check_expr_statements_in_same_block(); const auto node_as_nmodl = to_nmodl_for_sympy(node); - const auto deleter = [](nmodl::pybind_wrappers::DiffeqSolverExecutor* ptr) { - pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_des_executor(ptr); - }; - std::unique_ptr diffeq_solver{ - pywrap::EmbeddedPythonLoader::get_instance().api()->create_des_executor(), deleter}; - - diffeq_solver->node_as_nmodl = node_as_nmodl; - diffeq_solver->dt_var = codegen::naming::NTHREAD_DT_VARIABLE; - diffeq_solver->vars = vars; - diffeq_solver->use_pade_approx = use_pade_approx; - diffeq_solver->function_calls = function_calls; - diffeq_solver->method = solve_method; - (*diffeq_solver)(); + auto diffeq_solver = pywrap::EmbeddedPythonLoader::get_instance().api().diffeq_solver; + + auto dt_var = codegen::naming::NTHREAD_DT_VARIABLE; + auto [solution, exception_message] = (*diffeq_solver)( + node_as_nmodl, dt_var, vars, use_pade_approx, function_calls, solve_method); if (solve_method == codegen::naming::EULER_METHOD) { // replace x' = f(x) differential equation // with forwards Euler timestep: @@ -449,10 +423,8 @@ void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { } // replace ODE with solution in AST - auto solution = diffeq_solver->solution; logger->debug("SympySolverVisitor :: -> solution: {}", solution); - auto exception_message = diffeq_solver->exception_message; if (!exception_message.empty()) { logger->warn("SympySolverVisitor :: python exception: " + exception_message); return; diff --git a/test/unit/codegen/main.cpp b/test/unit/codegen/main.cpp index 53060dd673..d8de60edbd 100644 --- a/test/unit/codegen/main.cpp +++ b/test/unit/codegen/main.cpp @@ -15,11 +15,11 @@ using namespace nmodl; int main(int argc, char* argv[]) { // initialize python interpreter once for entire catch executable - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api()->initialize_interpreter(); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api().initialize_interpreter(); // enable verbose logger output logger->set_level(spdlog::level::debug); // run all catch tests int result = Catch::Session().run(argc, argv); - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api()->finalize_interpreter(); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api().finalize_interpreter(); return result; } diff --git a/test/unit/visitor/main.cpp b/test/unit/visitor/main.cpp index 53060dd673..d8de60edbd 100644 --- a/test/unit/visitor/main.cpp +++ b/test/unit/visitor/main.cpp @@ -15,11 +15,11 @@ using namespace nmodl; int main(int argc, char* argv[]) { // initialize python interpreter once for entire catch executable - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api()->initialize_interpreter(); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api().initialize_interpreter(); // enable verbose logger output logger->set_level(spdlog::level::debug); // run all catch tests int result = Catch::Session().run(argc, argv); - nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api()->finalize_interpreter(); + nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance().api().finalize_interpreter(); return result; }