Skip to content

Commit

Permalink
[mlir][scf] scf.while uplifting: optimize op matching (llvm#88813)
Browse files Browse the repository at this point in the history
Instead of iterating over potential induction var uses looking for
suitable `arith.addi`, try to trace it back from yield argument.
  • Loading branch information
Hardcode84 committed Apr 16, 2024
1 parent 61717c1 commit 1ca6b44
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,

Block *afterBody = loop.getAfterBody();
scf::YieldOp afterTerm = loop.getYieldOp();
auto argNumber = inductionVar.getArgNumber();
auto afterTermIndArg = afterTerm.getResults()[argNumber];
unsigned argNumber = inductionVar.getArgNumber();
Value afterTermIndArg = afterTerm.getResults()[argNumber];

auto inductionVarAfter = afterBody->getArgument(argNumber);

Value step;
Value inductionVarAfter = afterBody->getArgument(argNumber);

// Find suitable `addi` op inside `after` block, one of the args must be an
// Induction var passed from `before` block and second arg must be defined
// outside of the loop and will be considered step value.
// TODO: Add `subi` support?
for (auto &use : inductionVarAfter.getUses()) {
auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
if (!owner)
continue;

auto other =
(inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
if (!dom.properlyDominates(other, loop))
continue;

if (afterTermIndArg != owner.getResult())
continue;
auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
if (!addOp)
return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");

step = other;
break;
Value step;
if (addOp.getLhs() == inductionVarAfter) {
step = addOp.getRhs();
} else if (addOp.getRhs() == inductionVarAfter) {
step = addOp.getLhs();
}

if (!step)
return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
if (!step || !dom.properlyDominates(step, loop))
return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");

auto lb = loop.getInits()[argNumber];
Value lb = loop.getInits()[argNumber];

assert(lb.getType().isIntOrIndex());
assert(lb.getType() == ub.getType());
Expand Down

0 comments on commit 1ca6b44

Please sign in to comment.