diff --git a/src/codegen/codegen_cpp_visitor.cpp b/src/codegen/codegen_cpp_visitor.cpp index 946fca68f..f99bc8b15 100644 --- a/src/codegen/codegen_cpp_visitor.cpp +++ b/src/codegen/codegen_cpp_visitor.cpp @@ -4060,7 +4060,7 @@ void CodegenCppVisitor::visit_for_netcon(const ast::ForNetcon& node) { // to the next netcon. const auto& args = node.get_parameters(); RenameVisitor v; - auto& statement_block = node.get_statement_block(); + const auto& statement_block = node.get_statement_block(); for (size_t i_arg = 0; i_arg < args.size(); ++i_arg) { // sanitize node_name since we want to substitute names like (*w) as they are auto old_name = diff --git a/src/language/nodes.py b/src/language/nodes.py index ff1037f69..7c6d77f14 100644 --- a/src/language/nodes.py +++ b/src/language/nodes.py @@ -234,11 +234,15 @@ def member_typename(self): return type_name + @property + def _is_member_type_wrapped_as_shared_pointer(self): + return not (self.is_vector or self.is_base_type_node or self.is_ptr_excluded_node) + @property def member_rvalue_typename(self): """returns rvalue reference type when used as returned or parameter type""" typename = self.member_typename - if not self.is_integral_type_node: + if not self.is_integral_type_node and not self._is_member_type_wrapped_as_shared_pointer: return "const " + typename + "&" return typename diff --git a/src/language/templates/ast/ast.cpp b/src/language/templates/ast/ast.cpp index 87f539954..04ff47538 100644 --- a/src/language/templates/ast/ast.cpp +++ b/src/language/templates/ast/ast.cpp @@ -30,7 +30,7 @@ std::string Ast::get_node_name() const { throw std::logic_error("get_node_name() not implemented"); } -const std::shared_ptr& Ast::get_statement_block() const { +std::shared_ptr Ast::get_statement_block() const { throw std::runtime_error("get_statement_block not implemented"); } diff --git a/src/language/templates/ast/ast.hpp b/src/language/templates/ast/ast.hpp index 084580794..6e9e4b18a 100644 --- a/src/language/templates/ast/ast.hpp +++ b/src/language/templates/ast/ast.hpp @@ -273,7 +273,7 @@ struct Ast: public std::enable_shared_from_this { * * \sa ast::StatementBlock */ - virtual const std::shared_ptr& get_statement_block() const; + virtual std::shared_ptr get_statement_block() const; /** * \brief Set symbol table for the AST node diff --git a/src/language/templates/pybind/pyast.hpp b/src/language/templates/pybind/pyast.hpp index f2fa13eb1..e70ffe3ac 100644 --- a/src/language/templates/pybind/pyast.hpp +++ b/src/language/templates/pybind/pyast.hpp @@ -114,8 +114,8 @@ struct PyAst: public Ast { PYBIND11_OVERRIDE(symtab::SymbolTable*, Ast, get_symbol_table, ); } - const std::shared_ptr& get_statement_block() const override { - PYBIND11_OVERRIDE(const std::shared_ptr&, Ast, get_statement_block, ); + std::shared_ptr get_statement_block() const override { + PYBIND11_OVERRIDE(std::shared_ptr, Ast, get_statement_block, ); } void set_symbol_table(symtab::SymbolTable* newsymtab) override { diff --git a/src/visitors/global_var_visitor.cpp b/src/visitors/global_var_visitor.cpp index f11956bb4..ae8472059 100644 --- a/src/visitors/global_var_visitor.cpp +++ b/src/visitors/global_var_visitor.cpp @@ -25,7 +25,7 @@ void GlobalToRangeVisitor::visit_neuron_block(ast::NeuronBlock& node) { std::unordered_set global_variables_to_remove; std::unordered_set global_statements_to_remove; - auto& statement_block = node.get_statement_block(); + auto const& statement_block = node.get_statement_block(); auto& statements = (*statement_block).get_statements(); const auto& symbol_table = ast.get_symbol_table(); diff --git a/src/visitors/neuron_solve_visitor.cpp b/src/visitors/neuron_solve_visitor.cpp index 77af74eff..d8da1e03c 100644 --- a/src/visitors/neuron_solve_visitor.cpp +++ b/src/visitors/neuron_solve_visitor.cpp @@ -32,7 +32,7 @@ void NeuronSolveVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); derivative_block = false; if (solve_blocks[derivative_block_name] == codegen::naming::EULER_METHOD) { - auto& statement_block = node.get_statement_block(); + const auto& statement_block = node.get_statement_block(); for (auto& e: euler_solution_expressions) { statement_block->emplace_back_statement(e); } diff --git a/src/visitors/sympy_replace_solutions_visitor.cpp b/src/visitors/sympy_replace_solutions_visitor.cpp index 34ba9e349..9906c5577 100644 --- a/src/visitors/sympy_replace_solutions_visitor.cpp +++ b/src/visitors/sympy_replace_solutions_visitor.cpp @@ -161,8 +161,8 @@ void SympyReplaceSolutionsVisitor::visit_statement_block(ast::StatementBlock& no void SympyReplaceSolutionsVisitor::try_replace_tagged_statement( const ast::Node& node, - const std::shared_ptr& get_lhs(const ast::Node& node), - const std::shared_ptr& get_rhs(const ast::Node& node)) { + std::shared_ptr get_lhs(const ast::Node& node), + std::shared_ptr get_rhs(const ast::Node& node)) { interleaves_counter.new_equation(true); const auto& statement = std::static_pointer_cast( @@ -212,11 +212,11 @@ void SympyReplaceSolutionsVisitor::try_replace_tagged_statement( void SympyReplaceSolutionsVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node)); - auto get_lhs = [](const ast::Node& node) -> const std::shared_ptr& { + auto get_lhs = [](const ast::Node& node) -> std::shared_ptr { return dynamic_cast(node).get_expression()->get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> const std::shared_ptr& { + auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { return dynamic_cast(node).get_expression()->get_rhs(); }; @@ -225,11 +225,11 @@ void SympyReplaceSolutionsVisitor::visit_diff_eq_expression(ast::DiffEqExpressio void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) { logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node)); - auto get_lhs = [](const ast::Node& node) -> const std::shared_ptr& { + auto get_lhs = [](const ast::Node& node) -> std::shared_ptr { return dynamic_cast(node).get_left_linxpression(); }; - auto get_rhs = [](const ast::Node& node) -> const std::shared_ptr& { + auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { return dynamic_cast(node).get_left_linxpression(); }; @@ -239,11 +239,11 @@ void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) { void SympyReplaceSolutionsVisitor::visit_non_lin_equation(ast::NonLinEquation& node) { logger->debug("SympyReplaceSolutionsVisitor :: visit {}", to_nmodl(node)); - auto get_lhs = [](const ast::Node& node) -> const std::shared_ptr& { + auto get_lhs = [](const ast::Node& node) -> std::shared_ptr { return dynamic_cast(node).get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> const std::shared_ptr& { + auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { return dynamic_cast(node).get_rhs(); }; diff --git a/src/visitors/sympy_replace_solutions_visitor.hpp b/src/visitors/sympy_replace_solutions_visitor.hpp index 371b4788d..42bc4da2d 100644 --- a/src/visitors/sympy_replace_solutions_visitor.hpp +++ b/src/visitors/sympy_replace_solutions_visitor.hpp @@ -253,8 +253,8 @@ class SympyReplaceSolutionsVisitor: public AstVisitor { */ void try_replace_tagged_statement( const ast::Node& node, - const std::shared_ptr& get_lhs(const ast::Node& node), - const std::shared_ptr& get_rhs(const ast::Node& node)); + std::shared_ptr get_lhs(const ast::Node& node), + std::shared_ptr get_rhs(const ast::Node& node)); /** * \struct InterleavesCounter