Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offset Reduction #1151

Draft
wants to merge 38 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
1d5e8b1
fixed VarDeclStmt visitor
mroethlin Dec 12, 2019
556e5c6
Merge remote-tracking branch 'origin/master'
mroethlin Dec 17, 2019
092f7eb
Merge remote-tracking branch 'origin/master'
mroethlin Dec 17, 2019
f262804
Merge remote-tracking branch 'origin/master'
mroethlin Jan 6, 2020
3865d4d
Merge remote-tracking branch 'origin/master'
mroethlin Jan 7, 2020
83f4680
Merge remote-tracking branch 'origin/master'
mroethlin Jan 9, 2020
bf0a6d2
Merge remote-tracking branch 'origin/master'
mroethlin Jan 9, 2020
a317ec9
Merge remote-tracking branch 'origin/master'
mroethlin Jan 10, 2020
2b243c4
Merge remote-tracking branch 'origin/master'
mroethlin Jan 14, 2020
1e6c81a
Merge remote-tracking branch 'origin/master'
mroethlin Jan 15, 2020
1b53bd1
Merge remote-tracking branch 'origin/master'
mroethlin Jan 16, 2020
3a50412
updating git hook ignore list for new codegen tests
mroethlin Jan 16, 2020
6550116
Merge remote-tracking branch 'origin/master'
mroethlin Jan 16, 2020
d933ffe
Merge remote-tracking branch 'origin/master'
mroethlin Jan 23, 2020
2662ecb
Use experimental::fs if fs is not supported (#640)
havogt Jan 23, 2020
c383e3f
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Jan 23, 2020
72b547e
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Jan 23, 2020
328a0cf
Merge remote-tracking branch 'origin/master'
mroethlin Feb 3, 2020
8d560dc
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Jan 19, 2021
d25dab1
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Jan 19, 2021
bca7232
wip
mroethlin Jan 27, 2021
26e6347
rough but complete implementation
mroethlin Jan 28, 2021
b252516
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Feb 8, 2021
1400f05
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Feb 9, 2021
311c440
small improvements, fix stupid bug when constructing offset dequeue
mroethlin Feb 10, 2021
a14e95e
Merge branch 'master' of https://github.com/MeteoSwiss-APN/dawn
mroethlin Feb 22, 2021
695687f
handling the case where lower AND upper level is end
mroethlin Feb 22, 2021
0b910bb
sparse full field indexing bug fixed
mroethlin Feb 26, 2021
52e407b
only check errors in debug build
mroethlin Mar 5, 2021
4900e0f
merge master
mroethlin Mar 8, 2021
9f5e95c
merge fixes
mroethlin Mar 8, 2021
d1263a1
fix bug in cudaico-codegen with offset reductions
mroethlin Mar 16, 2021
51b0010
fixed a bug in the space collector for offsetReductions
mroethlin Mar 16, 2021
9a3a80f
fix code generation for indices, again. add integration test
mroethlin Mar 17, 2021
829d12c
fix merge conflicts
mroethlin May 19, 2021
8b67de2
update ref
mroethlin May 19, 2021
8e1bbfa
Merge branch 'master' into offsetReduction
mroethlin Aug 5, 2021
e179f64
Merge branch 'master' into offsetReduction
actions-user Aug 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions dawn/src/dawn/AST/ASTExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,25 +459,25 @@ ReductionOverNeighborExpr::ReductionOverNeighborExpr(std::string const& op,
std::shared_ptr<Expr> const& rhs,
std::shared_ptr<Expr> const& init,
std::vector<ast::LocationType> chain,
bool includeCenter, SourceLocation loc)
bool includeCenter, std::vector<int> offsets,
SourceLocation loc)
: Expr(Kind::ReductionOverNeighborExpr, loc), op_(op),
iterSpace_(std::move(chain), includeCenter), operands_{rhs, init} {}
iterSpace_(std::move(chain), includeCenter), operands_{rhs, init}, offsets_(offsets) {}

ReductionOverNeighborExpr::ReductionOverNeighborExpr(std::string const& op,
std::shared_ptr<Expr> const& rhs,
std::shared_ptr<Expr> const& init,
std::vector<std::shared_ptr<Expr>> weights,
std::vector<ast::LocationType> chain,
bool includeCenter, SourceLocation loc)
: ReductionOverNeighborExpr(op, rhs, init, chain, includeCenter, loc) {
ReductionOverNeighborExpr::ReductionOverNeighborExpr(
std::string const& op, std::shared_ptr<Expr> const& rhs, std::shared_ptr<Expr> const& init,
std::vector<std::shared_ptr<Expr>> weights, std::vector<ast::LocationType> chain,
bool includeCenter, std::vector<int> offsets, SourceLocation loc)
: ReductionOverNeighborExpr(op, rhs, init, chain, includeCenter, offsets, loc) {
DAWN_ASSERT_MSG(weights.size() > 0, "empty weights vector passed!\n");
weights_ = weights;
operands_.insert(operands_.end(), weights.begin(), weights.end());
}

ReductionOverNeighborExpr::ReductionOverNeighborExpr(ReductionOverNeighborExpr const& expr)
: Expr(Kind::ReductionOverNeighborExpr, expr.getSourceLocation()), op_(expr.getOp()),
weights_(expr.getWeights()), iterSpace_(expr.iterSpace_), operands_(expr.operands_) {}
weights_(expr.getWeights()), iterSpace_(expr.iterSpace_), operands_(expr.operands_),
offsets_(expr.offsets_) {}

ReductionOverNeighborExpr&
ReductionOverNeighborExpr::operator=(ReductionOverNeighborExpr const& expr) {
Expand All @@ -486,6 +486,7 @@ ReductionOverNeighborExpr::operator=(ReductionOverNeighborExpr const& expr) {
operands_ = expr.operands_;
iterSpace_ = expr.iterSpace_;
weights_ = expr.getWeights();
offsets_ = expr.offsets_;
return *this;
}

Expand Down Expand Up @@ -516,7 +517,7 @@ bool ReductionOverNeighborExpr::equals(const Expr* other, bool compareData) cons

return otherPtr && otherPtr->getInit()->equals(getInit().get(), compareData) &&
otherPtr->getOp() == getOp() && otherPtr->getRhs()->equals(getRhs().get(), compareData) &&
otherPtr->iterSpace_ == iterSpace_;
otherPtr->iterSpace_ == iterSpace_ && offsets_ == otherPtr->offsets_;
}

bool ReductionOverNeighborExpr::isArithmetic() const {
Expand Down
7 changes: 5 additions & 2 deletions dawn/src/dawn/AST/ASTExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,19 +647,21 @@ class ReductionOverNeighborExpr : public Expr {
// hold a copy of the (shared pointer to) the weights
std::vector<std::shared_ptr<Expr>> operands_ = std::vector<std::shared_ptr<Expr>>(2);
bool chainIsValid() const;
std::vector<int> offsets_ = {};

public:
inline static const std::vector<std::string> arithmeticOps{"+", "-", "*", "/", "%"};
/// @name Constructor & Destructor
/// @{
ReductionOverNeighborExpr(std::string const& op, std::shared_ptr<Expr> const& rhs,
std::shared_ptr<Expr> const& init, std::vector<ast::LocationType> chain,
bool includeCenter = false, SourceLocation loc = SourceLocation());
bool includeCenter = false, std::vector<int> offsets_ = {},
SourceLocation loc = SourceLocation());
ReductionOverNeighborExpr(std::string const& op, std::shared_ptr<Expr> const& rhs,
std::shared_ptr<Expr> const& init,
std::vector<std::shared_ptr<Expr>> weights,
std::vector<ast::LocationType> chain, bool includeCenter = false,
SourceLocation loc = SourceLocation());
std::vector<int> offsets_ = {}, SourceLocation loc = SourceLocation());
ReductionOverNeighborExpr(ReductionOverNeighborExpr const& stmt);
ReductionOverNeighborExpr& operator=(ReductionOverNeighborExpr const& stmt);
/// @}
Expand All @@ -672,6 +674,7 @@ class ReductionOverNeighborExpr : public Expr {
std::vector<ast::LocationType> getNbhChain() const { return iterSpace_; };
ast::LocationType getLhsLocation() const { return iterSpace_.Chain.front(); };
const std::optional<std::vector<std::shared_ptr<Expr>>>& getWeights() const { return weights_; };
const std::vector<int>& getOffsets() const { return offsets_; };
bool getIncludeCenter() const { return iterSpace_.IncludeCenter; };
ast::UnstructuredIterationSpace getIterSpace() const { return iterSpace_; }

Expand Down
38 changes: 15 additions & 23 deletions dawn/src/dawn/AST/proto/AST/statements.proto
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ message CartesianDimension {
// It could also have a sparse dimension. In such case the sparse part is
// non-empty.
// @ingroup ast_proto
message UnstructuredDimension {
UnstructuredIterationSpace iter_space = 1;
}
message UnstructuredDimension { UnstructuredIterationSpace iter_space = 1; }

// @brief Dimensions spanned by a field
//
Expand Down Expand Up @@ -323,9 +321,7 @@ message LoopDescriptorGeneral {
// dummy message for future use
}

message LoopDescriptorChain {
UnstructuredIterationSpace iter_space = 1;
}
message LoopDescriptorChain { UnstructuredIterationSpace iter_space = 1; }

message LoopDescriptor {
oneof desc {
Expand All @@ -335,7 +331,7 @@ message LoopDescriptor {
}

message LoopStmt {
Stmt statements = 1; // List of statements (must be a BlockStmt)
Stmt statements = 1; // List of statements (must be a BlockStmt)
LoopDescriptor loop_descriptor = 2; // Loop bounds description
SourceLocation loc = 3; // Source location
StmtData data = 4; // Generic Stmt's data container
Expand Down Expand Up @@ -366,8 +362,7 @@ message ReturnStmt {
//
// @ingroup ast_proto
message VarDeclStmtData {
google.protobuf.Int32Value accessID =
1; // ID of the variable declared in the statement
google.protobuf.Int32Value accessID = 1; // ID of the variable declared in the statement
}

// @brief Declaration of a variable
Expand Down Expand Up @@ -543,7 +538,7 @@ message FunCallExpr {
// @endcode
// @ingroup ast_proto
message StencilFunCallExpr {
string callee = 1; // Identifier of the stencil function (i.e callee)
string callee = 1; // Identifier of the stencil function (i.e callee)
repeated Expr arguments = 2; // List of arguments
SourceLocation loc = 3; // Source location
int32 ID = 4; // ID of the Expr
Expand Down Expand Up @@ -601,18 +596,16 @@ message StencilFunArgExpr {
//
// @ingroup ast_proto
message AccessExprData {
google.protobuf.Int32Value accessID =
1; // Access ID of variable/literal/field accessed
google.protobuf.Int32Value accessID = 1; // Access ID of variable/literal/field accessed
}

// @brief Access to a variable
//
// @ingroup ast_proto
message VarAccessExpr {
string name = 1; // Name of the variable
Expr index = 2; // Is it an array access (i.e var[2])?
bool is_external =
3; // Is this an access to a external variable (e.g a global)?
string name = 1; // Name of the variable
Expr index = 2; // Is it an array access (i.e var[2])?
bool is_external = 3; // Is this an access to a external variable (e.g a global)?
SourceLocation loc = 4; // Source location
AccessExprData data = 5; // Access data
int32 ID = 6; // ID of the Expr
Expand Down Expand Up @@ -709,15 +702,15 @@ message LiteralAccessExpr {
//
// @ingroup ast_proto
message ReductionOverNeighborExpr {
string op = 1; // Reduction operation
Expr rhs = 2; // Operation to be applied for each neighbor
Expr init = 3; // Initial value of reduction
repeated Expr weights =
4; // weights (required to be of equal type, e.g. just floats)
string op = 1; // Reduction operation
Expr rhs = 2; // Operation to be applied for each neighbor
Expr init = 3; // Initial value of reduction
repeated Expr weights = 4; // weights (required to be of equal type, e.g. just floats)
UnstructuredIterationSpace iter_space =
5; // Neighbor chain definining the neighbors to reduce from and the
// location type to reduce to (first element)
SourceLocation loc = 6;
SourceLocation loc = 6;
repeated int32 offsets = 7;
}

// @brief Abstract syntax tree of the AST
Expand Down Expand Up @@ -760,4 +753,3 @@ message GlobalVariableValue {
message GlobalVariableMap {
map<string, GlobalVariableValue> map = 1; // Mape of global variables (name to value)
}

24 changes: 20 additions & 4 deletions dawn/src/dawn/CodeGen/CXXNaive-ico/ASTStencilBody.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "dawn/IIR/StencilFunctionInstantiation.h"
#include "dawn/SIR/SIR.h"
#include "dawn/Support/Unreachable.h"
#include <iterator>
#include <optional>

static std::string nbhChainToVectorString(const std::vector<dawn::ast::LocationType>& chain) {
auto getLocationTypeString = [](dawn::ast::LocationType type) {
Expand Down Expand Up @@ -228,6 +230,10 @@ std::string ASTStencilBody::makeIndexString(const std::shared_ptr<ast::FieldAcce
std::string sparseIdx = parentIsReduction_
? ASTStencilBody::ReductionSparseIndexVarName(reductionDepth_ - 1)
: ASTStencilBody::LoopLinearIndexVarName();
if(offsets_.has_value()) {
sparseIdx = std::to_string(offsets_->front());
offsets_->pop_front();
}
return "deref(LibTag{}, " + sparseArgName_ + ")," + sparseIdx + ", " + kiterStr;
}

Expand Down Expand Up @@ -256,6 +262,10 @@ std::string ASTStencilBody::makeIndexString(const std::shared_ptr<ast::FieldAcce
std::string sparseIdx = parentIsReduction_
? ASTStencilBody::ReductionSparseIndexVarName(reductionDepth_ - 1)
: ASTStencilBody::LoopLinearIndexVarName();
if(offsets_.has_value()) {
sparseIdx = std::to_string(offsets_->front());
offsets_->pop_front();
}
return "deref(LibTag{}, " + sparseArgName_ + ")," + sparseIdx;
}

Expand Down Expand Up @@ -408,10 +418,7 @@ void ASTStencilBody::visit(const std::shared_ptr<ast::ReductionOverNeighborExpr>
reductionDepth_++;
expr->getRhs()->accept(*this);
reductionDepth_--;
if(reductionDepth_ == 0) {
parentIsReduction_ = false;
currentChain_.clear();
}

// "pop" argName
denseArgName_ = argName;
if(!expr->isArithmetic()) {
Expand All @@ -421,6 +428,9 @@ void ASTStencilBody::visit(const std::shared_ptr<ast::ReductionOverNeighborExpr>
ss_ << ASTStencilBody::ReductionSparseIndexVarName(reductionDepth_) << "++;\n";
ss_ << "return lhs;\n";
ss_ << "}";
if(!expr->getOffsets().empty()) {
offsets_ = std::deque<int>(expr->getOffsets().begin(), expr->getOffsets().end());
}
if(hasWeights) {
auto weights = expr->getWeights().value();
bool first = true;
Expand All @@ -440,6 +450,12 @@ void ASTStencilBody::visit(const std::shared_ptr<ast::ReductionOverNeighborExpr>
ss_ << ", /*include center*/ true";
}
ss_ << ")";
offsets_ = std::nullopt;

if(reductionDepth_ == 0) {
parentIsReduction_ = false;
currentChain_.clear();
}
}

void ASTStencilBody::setCurrentStencilFunction(
Expand Down
1 change: 1 addition & 0 deletions dawn/src/dawn/CodeGen/CXXNaive-ico/ASTStencilBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class ASTStencilBody : public ASTCodeGenCXX {
bool parentIsReduction_ = false;
bool parentIsForLoop_ = false;
std::vector<ast::LocationType> currentChain_;
std::optional<std::deque<int>> offsets_;

size_t reductionDepth_ = 0;

Expand Down
21 changes: 18 additions & 3 deletions dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "dawn/IIR/ASTExpr.h"
#include <algorithm>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -133,7 +134,7 @@ void ASTStencilBody::visit(const std::shared_ptr<ast::AssignmentExpr>& expr) {
}

std::string ASTStencilBody::makeIndexString(const std::shared_ptr<ast::FieldAccessExpr>& expr,
std::string kiterStr) const {
std::string kiterStr) {
bool isVertical = metadata_.getFieldDimensions(iir::getAccessID(expr)).isVertical();
if(isVertical) {
return kiterStr;
Expand Down Expand Up @@ -164,7 +165,12 @@ std::string ASTStencilBody::makeIndexString(const std::shared_ptr<ast::FieldAcce
if(isFullField && isSparse) {
DAWN_ASSERT_MSG(parentIsForLoop_ || parentIsReduction_,
"Sparse Field Access not allowed in this context");
return nbhIterStr() + " * kSize * " + denseSize + " + " + kiterStr + "*" + denseSize + " + " +
std::string nbhIter = nbhIterStr();
if(offsets_.has_value()) {
nbhIter = std::to_string(offsets_->front());
offsets_->pop_front();
}
return nbhIter + " * kSize * " + denseSize + " + " + kiterStr + "*" + denseSize + " + " +
pidxStr();
}

Expand All @@ -182,8 +188,13 @@ std::string ASTStencilBody::makeIndexString(const std::shared_ptr<ast::FieldAcce
if(isHorizontal && isSparse) {
DAWN_ASSERT_MSG(parentIsForLoop_ || parentIsReduction_,
"Sparse Field Access not allowed in this context");
std::string nbhIter = nbhIterStr();
if(offsets_.has_value()) {
nbhIter = std::to_string(offsets_->front());
offsets_->pop_front();
}
std::string sparseSize = chainToSparseSizeString(unstrDims.getIterSpace());
return nbhIterStr() + " * " + denseSize + " + " + pidxStr();
return nbhIter + " * " + denseSize + " + " + pidxStr();
}

DAWN_ASSERT_MSG(false, "Bad Field configuration found in code gen!");
Expand Down Expand Up @@ -311,6 +322,9 @@ void ASTStencilBody::evalNeighbourReduction(
expr->getInit()->accept(*this);
ss_ << ";\n";
auto weights = expr->getWeights();
if(!expr->getOffsets().empty()) {
offsets_ = std::deque<int>(expr->getOffsets().begin(), expr->getOffsets().end());
}
if(weights.has_value()) {
ss_ << "::dawn::float_type " << weights_name << "[" << weights->size() << "] = {";
bool first = true;
Expand All @@ -322,6 +336,7 @@ void ASTStencilBody::evalNeighbourReduction(
first = false;
}
ss_ << "};\n";
offsets_ = std::nullopt;
}

ss_ << "for (int " + nbhIterStr() + " = 0; " + nbhIterStr() + " < "
Expand Down
4 changes: 2 additions & 2 deletions dawn/src/dawn/CodeGen/Cuda-ico/ASTStencilBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class ASTStencilBody : public ASTCodeGenCXX {

bool parentIsReduction_ = false;
bool parentIsForLoop_ = false;
std::optional<std::deque<int>> offsets_;
bool genAtlasCompatCode_ = false;

std::map<int, std::unique_ptr<ASTStencilBody>> reductionParser_;
Expand All @@ -70,8 +71,7 @@ class ASTStencilBody : public ASTCodeGenCXX {
/// Nesting level of argument lists of stencil function *calls*
int nestingOfStencilFunArgLists_;

std::string makeIndexString(const std::shared_ptr<ast::FieldAccessExpr>& expr,
std::string kiter) const;
std::string makeIndexString(const std::shared_ptr<ast::FieldAccessExpr>& expr, std::string kiter);
bool hasIrregularPentagons(const std::vector<ast::LocationType>& chain) const;
void evalNeighbourReduction(const std::shared_ptr<ast::ReductionOverNeighborExpr>& expr);
void generateNeighbourRedLoop(std::stringstream& ss) const;
Expand Down
Loading