From 3fda03148f854259116571a90b082b9e2366d1ee Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Thu, 7 Mar 2024 11:20:35 +0100 Subject: [PATCH] Move `IndexRemover` to its own .hpp and .cpp file. --- cmake/hpc-coding-conventions | 2 +- src/visitors/CMakeLists.txt | 1 + src/visitors/index_remover.cpp | 47 ++++++++++++++++++++ src/visitors/index_remover.hpp | 51 ++++++++++++++++++++++ src/visitors/loop_unroll_visitor.cpp | 65 +--------------------------- 5 files changed, 102 insertions(+), 64 deletions(-) create mode 100644 src/visitors/index_remover.cpp create mode 100644 src/visitors/index_remover.hpp diff --git a/cmake/hpc-coding-conventions b/cmake/hpc-coding-conventions index 8f8115597..f8f8d69a6 160000 --- a/cmake/hpc-coding-conventions +++ b/cmake/hpc-coding-conventions @@ -1 +1 @@ -Subproject commit 8f8115597817365c5c4fa39e217b3ab0b3640cb2 +Subproject commit f8f8d69a66c23978d1c9c5dce62de79466f26e5d diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index 03b98d228..c798c8325 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -16,6 +16,7 @@ add_library( global_var_visitor.cpp implicit_argument_visitor.cpp indexedname_visitor.cpp + index_remover.cpp inline_visitor.cpp kinetic_block_visitor.cpp local_to_assigned_visitor.cpp diff --git a/src/visitors/index_remover.cpp b/src/visitors/index_remover.cpp new file mode 100644 index 000000000..82db82ccf --- /dev/null +++ b/src/visitors/index_remover.cpp @@ -0,0 +1,47 @@ +#include "index_remover.hpp" + +#include "utils/logger.hpp" +#include "visitors/visitor_utils.hpp" + +namespace nmodl { +namespace visitor { + +IndexRemover::IndexRemover(std::string index, int value) + : index(std::move(index)) + , value(value) {} + +/// if expression we are visiting is `Name` then return new `Integer` node +std::shared_ptr IndexRemover::replace_for_name( + const std::shared_ptr& node) const { + if (node->is_name()) { + auto name = std::dynamic_pointer_cast(node); + if (name->get_node_name() == index) { + return std::make_shared(value, nullptr); + } + } + return node; +} + +void IndexRemover::visit_binary_expression(ast::BinaryExpression& node) { + node.visit_children(*this); + if (under_indexed_name) { + /// first recursively replaces children + /// replace lhs & rhs if they have matching index variable + auto lhs = replace_for_name(node.get_lhs()); + auto rhs = replace_for_name(node.get_rhs()); + node.set_lhs(std::move(lhs)); + node.set_rhs(std::move(rhs)); + } +} + +void IndexRemover::visit_indexed_name(ast::IndexedName& node) { + under_indexed_name = true; + node.visit_children(*this); + /// once all children are replaced, do the same for index + auto length = replace_for_name(node.get_length()); + node.set_length(std::move(length)); + under_indexed_name = false; +} + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/index_remover.hpp b/src/visitors/index_remover.hpp new file mode 100644 index 000000000..17b4c013e --- /dev/null +++ b/src/visitors/index_remover.hpp @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "ast/all.hpp" +#include "visitors/ast_visitor.hpp" + +namespace nmodl { +namespace visitor { + +/** + * \class IndexRemover + * \brief Helper visitor to replace index of array variable with integer + * + * When loop is unrolled, the index variable like `i` : + * + * ca[i] <-> ca[i+1] + * + * has type `Name` in the AST. This needs to be replaced with `Integer` + * for optimizations like constant folding. This pass look at name and + * binary expressions under index variables. + */ +class IndexRemover: public AstVisitor { + private: + /// index variable name + std::string index; + + /// integer value of index variable + int value; + + /// true if we are visiting index variable + bool under_indexed_name = false; + + public: + IndexRemover(std::string index, int value); + + /// if expression we are visiting is `Name` then return new `Integer` node + std::shared_ptr replace_for_name( + const std::shared_ptr& node) const; + + void visit_binary_expression(ast::BinaryExpression& node) override; + void visit_indexed_name(ast::IndexedName& node) override; +}; + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/loop_unroll_visitor.cpp b/src/visitors/loop_unroll_visitor.cpp index 35697f19b..6369fd737 100644 --- a/src/visitors/loop_unroll_visitor.cpp +++ b/src/visitors/loop_unroll_visitor.cpp @@ -7,77 +7,16 @@ #include "visitors/loop_unroll_visitor.hpp" + #include "ast/all.hpp" #include "parser/c11_driver.hpp" #include "utils/logger.hpp" +#include "visitors/index_remover.hpp" #include "visitors/visitor_utils.hpp" - namespace nmodl { namespace visitor { -/** - * \class IndexRemover - * \brief Helper visitor to replace index of array variable with integer - * - * When loop is unrolled, the index variable like `i` : - * - * ca[i] <-> ca[i+1] - * - * has type `Name` in the AST. This needs to be replaced with `Integer` - * for optimizations like constant folding. This pass look at name and - * binary expressions under index variables. - */ -class IndexRemover: public AstVisitor { - private: - /// index variable name - std::string index; - - /// integer value of index variable - int value; - - /// true if we are visiting index variable - bool under_indexed_name = false; - - public: - IndexRemover(std::string index, int value) - : index(std::move(index)) - , value(value) {} - - /// if expression we are visiting is `Name` then return new `Integer` node - std::shared_ptr replace_for_name( - const std::shared_ptr& node) const { - if (node->is_name()) { - auto name = std::dynamic_pointer_cast(node); - if (name->get_node_name() == index) { - return std::make_shared(value, nullptr); - } - } - return node; - } - - void visit_binary_expression(ast::BinaryExpression& node) override { - node.visit_children(*this); - if (under_indexed_name) { - /// first recursively replaces children - /// replace lhs & rhs if they have matching index variable - auto lhs = replace_for_name(node.get_lhs()); - auto rhs = replace_for_name(node.get_rhs()); - node.set_lhs(std::move(lhs)); - node.set_rhs(std::move(rhs)); - } - } - - void visit_indexed_name(ast::IndexedName& node) override { - under_indexed_name = true; - node.visit_children(*this); - /// once all children are replaced, do the same for index - auto length = replace_for_name(node.get_length()); - node.set_length(std::move(length)); - under_indexed_name = false; - } -}; - /// return underlying expression wrapped by WrappedExpression static std::shared_ptr unwrap(const std::shared_ptr& expr) {