Skip to content

Commit

Permalink
[RotateForLoops] Support for negative bounds
Browse files Browse the repository at this point in the history
This commit improves the loop rotation pass by letting it rewrite for
loops that have negative bounds and provably execute at least once. As
a consequence of this change, the predicate of the comparison operation
inserted in the created do-while loop is now determined by the method
which assesses the legality of the rotation: it is signed if the loop's
lower bound + the step cannot be guaranteed to be non-negative, and
unsigned otherwise.

This commit also makes some cosmetic changes to the pass, generally
making its implementation cleaner.
  • Loading branch information
lucas-rami committed Aug 20, 2023
1 parent 7db4619 commit aa8eca2
Showing 1 changed file with 66 additions and 43 deletions.
109 changes: 66 additions & 43 deletions lib/Transforms/ScfRotateForLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,15 @@ using namespace mlir;
using namespace dynamatic;
using namespace circt::handshake;

/// Conservative legality check for loop rotation: returns true only when the
/// loop can be shown to run at least one iteration. This is established solely
/// for loops whose lower and upper bounds are both integer constants, with the
/// lower bound strictly below the upper bound when both are read as
/// zero-extended values. Any other loop is rejected.
static bool isLegalForRotation(scf::ForOp forOp) {
  // Block arguments have no defining operation; such bounds are not handled
  auto lowerDef = forOp.getLowerBound().getDefiningOp();
  auto upperDef = forOp.getUpperBound().getDefiningOp();
  if (!lowerDef || !upperDef)
    return false;

  // Both bounds must come from constant operations
  auto lowerCst = dyn_cast<arith::ConstantOp>(lowerDef);
  if (!lowerCst)
    return false;
  auto upperCst = dyn_cast<arith::ConstantOp>(upperDef);
  if (!upperCst)
    return false;

  // Both constants must hold integer attributes
  auto lowerAttr = dyn_cast<IntegerAttr>(lowerCst.getValue());
  auto upperAttr = dyn_cast<IntegerAttr>(upperCst.getValue());
  if (!lowerAttr || !upperAttr)
    return false;

  // The loop executes at least once iff lower bound < upper bound
  return lowerAttr.getValue().getZExtValue() <
         upperAttr.getValue().getZExtValue();
}

namespace {

struct RotateLoop : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
if (!isLegalForRotation(forOp))
arith::CmpIPredicate pred;
if (!isLegalForRotation(forOp, pred))
return failure();

rewriter.setInsertionPoint(forOp);
Expand All @@ -66,8 +42,8 @@ struct RotateLoop : public OpRewritePattern<scf::ForOp> {
rewriter.create<scf::WhileOp>(forOp.getLoc(), whileArgsRange.getTypes(),
whileOpArgs, nullptr, nullptr);

// Move all operations from the for loop body to the "before" region of
// the while loop
// Move all operations from the for loop body to the "before" region of the
// while loop
Block &beforeBlock = whileOp.getBefore().front();
rewriter.mergeBlocks(&forOp.getRegion().front(), &beforeBlock,
beforeBlock.getArguments());
Expand All @@ -77,21 +53,17 @@ struct RotateLoop : public OpRewritePattern<scf::ForOp> {
auto addOp = rewriter.create<arith::AddIOp>(
forOp->getLoc(), beforeBlock.getArguments().front(), forOp.getStep());
auto cmpOp = rewriter.create<arith::CmpIOp>(
forOp->getLoc(), arith::CmpIPredicate::ult, addOp.getResult(),
forOp.getUpperBound());
forOp->getLoc(), pred, addOp.getResult(), forOp.getUpperBound());

// Identify the yield operation that was moved from the for loop body to
// the before block
auto yields = beforeBlock.getOps<scf::YieldOp>();
assert(!yields.empty() && "no yields moved from for to while loop");
assert(++yields.begin() == yields.end() && "expected only one yield");
auto yieldOp = *yields.begin();
// Get the yield operation that was moved from the for loop body to the
// before block
scf::YieldOp yieldOp = *beforeBlock.getOps<scf::YieldOp>().begin();
assert(yieldOp && "expected to find a yield");

// Replace the for loop yield terminator with a while condition terminator
SmallVector<Value> condOperands;
condOperands.push_back(addOp.getResult());
for (auto op : yieldOp->getOperands())
condOperands.push_back(op);
llvm::copy(yieldOp->getOperands(), std::back_inserter(condOperands));
auto condOp = rewriter.replaceOpWithNewOp<scf::ConditionOp>(
yieldOp, cmpOp.getResult(), condOperands);

Expand All @@ -102,14 +74,65 @@ struct RotateLoop : public OpRewritePattern<scf::ForOp> {
rewriter.setInsertionPointToStart(&afterBlock);
rewriter.create<scf::YieldOp>(condOp->getLoc(), afterBlock.getArguments());

// Replace for's results with while's results (drop while's first
// result, which is the IV)
for (auto res :
llvm::zip(forOp->getResults(), whileOp.getResults().drop_front()))
std::get<0>(res).replaceAllUsesWith(std::get<1>(res));
// Replace for's results with while's results (drop while's first result,
// which is the IV)
rewriter.replaceOp(forOp, whileOp.getResults().drop_front());
return success();
}

private:
/// Determines whether a for loop is valid for rotation i.e., whether we can
/// determine that it will execute at least once. The heuristic implemented by
/// this function is necessarily conservative. If the function returns true,
/// pred contains the comparison predicate to use to evaluate the condition of
/// the to-be-created do-while loop; otherwise its value is undefined.
bool isLegalForRotation(scf::ForOp forOp, arith::CmpIPredicate &pred) const;
};
} // namespace

/// Determines whether the loop's lower bound added to the step value can be
/// proved to be non-negative (0 or more).
static bool isIteratorProvablyPositive(APInt &lb, Value step) {
if (!lb.isNegative())
return true;
auto stepCst = dyn_cast_if_present<arith::ConstantOp>(step.getDefiningOp());
if (!stepCst)
return false;
IntegerAttr stepVal = dyn_cast<IntegerAttr>(stepCst.getValue());
return lb.getSExtValue() + stepVal.getValue().getZExtValue() >= 0;
}

/// Returns true when the loop has constant integer bounds and its lower bound
/// is strictly less (signed comparison) than its upper bound, i.e., when the
/// loop provably executes at least once. On a true return, pred holds the
/// comparison predicate to use for the condition of the rotated loop; on a
/// false return its value must not be relied upon.
bool RotateLoop::isLegalForRotation(scf::ForOp forOp,
                                    arith::CmpIPredicate &pred) const {
  // Check that both bounds are constant
  auto lbCst = dyn_cast_if_present<arith::ConstantOp>(
      forOp.getLowerBound().getDefiningOp());
  auto ubCst = dyn_cast_if_present<arith::ConstantOp>(
      forOp.getUpperBound().getDefiningOp());
  if (!lbCst || !ubCst)
    return false;

  // Check that both constants hold integer values; dyn_cast yields a null
  // attribute otherwise, which must not be dereferenced
  auto lbVal = dyn_cast<IntegerAttr>(lbCst.getValue());
  auto ubVal = dyn_cast<IntegerAttr>(ubCst.getValue());
  if (!lbVal || !ubVal)
    return false;

  // Give up on bounds that are too wide to be compared as 64-bit signed
  // integers below
  APInt lb = lbVal.getValue();
  APInt ub = ubVal.getValue();
  if (lb.getSignificantBits() >= 64 || ub.getSignificantBits() >= 64)
    return false;

  // Determine comparison predicate to use when rotating the loop. We can insert
  // an unsigned comparison only if the lower bound added to the (guaranteed
  // positive) step can be guaranteed to be non-negative, since the first
  // comparison will occur after the first iteration of the old for loop body /
  // new do-while body
  pred = isIteratorProvablyPositive(lb, forOp.getStep())
             ? arith::CmpIPredicate::ult
             : arith::CmpIPredicate::slt;

  // The loop executes at least once iff the lower bound is strictly lower than
  // the upper bound
  return lb.getSExtValue() < ub.getSExtValue();
}

namespace {

/// Simple greedy pattern rewrite driver for SCF loop rotation pass.
struct ScfForLoopRotationPass
Expand Down

0 comments on commit aa8eca2

Please sign in to comment.