Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,13 +728,12 @@ class OpenMPIRBuilder {
LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop.
/// Calculate the trip count of a canonical loop.
///
/// Instead of a logical iteration space, this allows specifying user-defined
/// loop counter values using increment, upper- and lower bounds. To
/// disambiguate the terminology when counting downwards, instead of lower
/// bounds we use \p Start for the loop counter value in the first body
/// iteration.
/// This allows specifying user-defined loop counter values using increment,
/// upper- and lower bounds. To disambiguate the terminology when counting
/// downwards, instead of lower bounds we use \p Start for the loop counter
/// value in the first body iteration.
///
/// Consider the following limitations:
///
Expand All @@ -758,7 +757,32 @@ class OpenMPIRBuilder {
///
/// for (int i = 0; i < 42; i -= 1u)
///
//
/// \param Loc The insert and source location description.
/// \param Start Value of the loop counter for the first iterations.
/// \param Stop Loop counter values past this will stop the loop.
/// \param Step Loop counter increment after each iteration; negative
/// means counting down.
/// \param IsSigned Whether Start, Stop and Step are signed integers.
/// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
/// counter.
/// \param Name Base name used to derive instruction names.
///
/// \returns The value holding the calculated trip count.
Value *calculateCanonicalLoopTripCount(const LocationDescription &Loc,
Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop,
const Twine &Name = "loop");

/// Generator for the control flow structure of an OpenMP canonical loop.
///
/// Instead of a logical iteration space, this allows specifying user-defined
/// loop counter values using increment, upper- and lower bounds. To
/// disambiguate the terminology when counting downwards, instead of lower
/// bounds we use \p Start for the loop counter value in the first body
///
/// It calls \see calculateCanonicalLoopTripCount for trip count calculations,
/// so limitations of that method apply here as well.
///
/// \param Loc The insert and source location description.
/// \param BodyGenCB Callback that will generate the loop body code.
/// \param Start Value of the loop counter for the first iterations.
Expand Down
28 changes: 18 additions & 10 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4032,11 +4032,9 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
return CL;
}

Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name) {

Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
// DO I = 1, 100, 50
Expand All @@ -4048,9 +4046,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
assert(IndVarTy == Stop->getType() && "Stop type mismatch");
assert(IndVarTy == Step->getType() && "Step type mismatch");

LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
updateToLocation(ComputeLoc);
updateToLocation(Loc);

ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
ConstantInt *One = ConstantInt::get(IndVarTy, 1);
Expand Down Expand Up @@ -4090,8 +4086,20 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
}
Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
"omp_" + Name + ".tripcount");

return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
"omp_" + Name + ".tripcount");
}

Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
InsertPointTy ComputeIP, const Twine &Name) {
LocationDescription ComputeLoc =
ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;

Value *TripCount = calculateCanonicalLoopTripCount(
ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);

auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
Builder.restoreIP(CodeGenIP);
Expand Down
16 changes: 3 additions & 13 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,8 +1427,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
}

TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
IRBuilder<> Builder(BB);
Expand All @@ -1444,17 +1443,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
Value *StartVal = ConstantInt::get(LCTy, Start);
Value *StopVal = ConstantInt::get(LCTy, Stop);
Value *StepVal = ConstantInt::get(LCTy, Step);
auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
return Error::success();
};
Expected<CanonicalLoopInfo *> LoopResult =
OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
StepVal, IsSigned, InclusiveStop);
assert(LoopResult && "unexpected error");
CanonicalLoopInfo *Loop = *LoopResult;
Loop->assertOK();
Builder.restoreIP(Loop->getAfterIP());
Value *TripCount = Loop->getTripCount();
Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
Comment on lines +1446 to +1447

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's actually pretty nice

return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
};

Expand Down
67 changes: 23 additions & 44 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1772,55 +1772,34 @@ LogicalResult TargetOp::verify() {
Operation *TargetOp::getInnermostCapturedOmpOp() {
Dialect *ompDialect = (*this)->getDialect();
Operation *capturedOp = nullptr;
Region *capturedParentRegion = nullptr;

walk<WalkOrder::PostOrder>([&](Operation *op) {
// Process in pre-order to check operations from outermost to innermost,
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region with no operation to be captured.
walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == *this)
return;

// Reset captured op if crossing through an omp.loop_nest, so that the top
// level one will be the one captured.
if (llvm::isa<LoopNestOp>(op)) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
return WalkResult::advance();

// Ignore operations of other dialects or omp operations with no regions,
// because these will only be checked if they are siblings of an omp
// operation that can potentially be captured.
bool isOmpDialect = op->getDialect() == ompDialect;
bool hasRegions = op->getNumRegions() > 0;

if (capturedOp) {
bool isImmediateParent = false;
for (Region &region : op->getRegions()) {
if (&region == capturedParentRegion) {
isImmediateParent = true;
capturedParentRegion = op->getParentRegion();
break;
}
}

// Make sure the captured op is part of a (possibly multi-level) nest of
// OpenMP-only operations containing no unsupported siblings at any level.
if ((hasRegions && isOmpDialect != isImmediateParent) ||
(!isImmediateParent && !siblingAllowedInCapture(op))) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
} else {
// The first OpenMP dialect op containing a region found while visiting
// in post-order should be the innermost captured OpenMP operation.
if (isOmpDialect && hasRegions) {
capturedOp = op;
capturedParentRegion = op->getParentRegion();

// Don't capture this op if it has a not-allowed sibling.
for (Operation &sibling : op->getParentRegion()->getOps()) {
if (&sibling != op && !siblingAllowedInCapture(&sibling)) {
capturedOp = nullptr;
capturedParentRegion = nullptr;
}
}
}
}
if (!isOmpDialect || !hasRegions)
return WalkResult::skip();

// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
if (&sibling != op && !siblingAllowedInCapture(&sibling))
return WalkResult::interrupt();

// Don't continue capturing nested operations if we reach an omp.loop_nest.
// Otherwise, process the contents of this operation.
capturedOp = op;
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
: WalkResult::advance();
});

return capturedOp;
Expand Down
Loading