Skip to content

Commit

Permalink
Move IndexRemover to its own .hpp and .cpp file.
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc committed Mar 7, 2024
1 parent 347f786 commit 3fda031
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 64 deletions.
1 change: 1 addition & 0 deletions src/visitors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_library(
global_var_visitor.cpp
implicit_argument_visitor.cpp
indexedname_visitor.cpp
index_remover.cpp
inline_visitor.cpp
kinetic_block_visitor.cpp
local_to_assigned_visitor.cpp
Expand Down
47 changes: 47 additions & 0 deletions src/visitors/index_remover.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "index_remover.hpp"

#include "utils/logger.hpp"
#include "visitors/visitor_utils.hpp"

namespace nmodl {
namespace visitor {

IndexRemover::IndexRemover(std::string index, int value)
: index(std::move(index))
, value(value) {}

/// if expression we are visiting is `Name` then return new `Integer` node
std::shared_ptr<ast::Expression> IndexRemover::replace_for_name(
const std::shared_ptr<ast::Expression>& node) const {
if (node->is_name()) {
auto name = std::dynamic_pointer_cast<ast::Name>(node);
if (name->get_node_name() == index) {
return std::make_shared<ast::Integer>(value, nullptr);
}
}
return node;
}

void IndexRemover::visit_binary_expression(ast::BinaryExpression& node) {
node.visit_children(*this);
if (under_indexed_name) {
/// first recursively replaces children
/// replace lhs & rhs if they have matching index variable
auto lhs = replace_for_name(node.get_lhs());
auto rhs = replace_for_name(node.get_rhs());
node.set_lhs(std::move(lhs));
node.set_rhs(std::move(rhs));
}
}

void IndexRemover::visit_indexed_name(ast::IndexedName& node) {
under_indexed_name = true;
node.visit_children(*this);
/// once all children are replaced, do the same for index
auto length = replace_for_name(node.get_length());
node.set_length(std::move(length));
under_indexed_name = false;
}

} // namespace visitor
} // namespace nmodl
51 changes: 51 additions & 0 deletions src/visitors/index_remover.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2023 Blue Brain Project, EPFL.
* See the top-level LICENSE file for details.
*
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include "ast/all.hpp"
#include "visitors/ast_visitor.hpp"

namespace nmodl {
namespace visitor {

/**
* \class IndexRemover
* \brief Helper visitor to replace index of array variable with integer
*
* When loop is unrolled, the index variable like `i` :
*
* ca[i] <-> ca[i+1]
*
* has type `Name` in the AST. This needs to be replaced with `Integer`
* for optimizations like constant folding. This pass look at name and
* binary expressions under index variables.
*/
class IndexRemover: public AstVisitor {
private:
/// index variable name
std::string index;

/// integer value of index variable
int value;

/// true if we are visiting index variable
bool under_indexed_name = false;

public:
IndexRemover(std::string index, int value);

/// if expression we are visiting is `Name` then return new `Integer` node
std::shared_ptr<ast::Expression> replace_for_name(
const std::shared_ptr<ast::Expression>& node) const;

void visit_binary_expression(ast::BinaryExpression& node) override;
void visit_indexed_name(ast::IndexedName& node) override;
};

} // namespace visitor
} // namespace nmodl
65 changes: 2 additions & 63 deletions src/visitors/loop_unroll_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,77 +7,16 @@

#include "visitors/loop_unroll_visitor.hpp"


#include "ast/all.hpp"
#include "parser/c11_driver.hpp"
#include "utils/logger.hpp"
#include "visitors/index_remover.hpp"
#include "visitors/visitor_utils.hpp"


namespace nmodl {
namespace visitor {

/**
* \class IndexRemover
* \brief Helper visitor to replace index of array variable with integer
*
* When loop is unrolled, the index variable like `i` :
*
* ca[i] <-> ca[i+1]
*
* has type `Name` in the AST. This needs to be replaced with `Integer`
* for optimizations like constant folding. This pass look at name and
* binary expressions under index variables.
*/
class IndexRemover: public AstVisitor {
private:
/// index variable name
std::string index;

/// integer value of index variable
int value;

/// true if we are visiting index variable
bool under_indexed_name = false;

public:
IndexRemover(std::string index, int value)
: index(std::move(index))
, value(value) {}

/// if expression we are visiting is `Name` then return new `Integer` node
std::shared_ptr<ast::Expression> replace_for_name(
const std::shared_ptr<ast::Expression>& node) const {
if (node->is_name()) {
auto name = std::dynamic_pointer_cast<ast::Name>(node);
if (name->get_node_name() == index) {
return std::make_shared<ast::Integer>(value, nullptr);
}
}
return node;
}

void visit_binary_expression(ast::BinaryExpression& node) override {
node.visit_children(*this);
if (under_indexed_name) {
/// first recursively replaces children
/// replace lhs & rhs if they have matching index variable
auto lhs = replace_for_name(node.get_lhs());
auto rhs = replace_for_name(node.get_rhs());
node.set_lhs(std::move(lhs));
node.set_rhs(std::move(rhs));
}
}

void visit_indexed_name(ast::IndexedName& node) override {
under_indexed_name = true;
node.visit_children(*this);
/// once all children are replaced, do the same for index
auto length = replace_for_name(node.get_length());
node.set_length(std::move(length));
under_indexed_name = false;
}
};


/// return underlying expression wrapped by WrappedExpression
static std::shared_ptr<ast::Expression> unwrap(const std::shared_ptr<ast::Expression>& expr) {
Expand Down

0 comments on commit 3fda031

Please sign in to comment.