From 1ca6b4475c02e5d022ec6b35dbb65d0f11409a88 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Apr 2024 12:39:57 +0300 Subject: [PATCH] [mlir][scf] `scf.while` uplifting: optimize op matching (#88813) Instead of iterating over potential induction var uses looking for suitable `arith.addi`, try to trace it back from yield argument. --- .../SCF/Transforms/UpliftWhileToFor.cpp | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index fea2f659535bb..7b4024b6861a7 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -101,38 +101,30 @@ FailureOr mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, Block *afterBody = loop.getAfterBody(); scf::YieldOp afterTerm = loop.getYieldOp(); - auto argNumber = inductionVar.getArgNumber(); - auto afterTermIndArg = afterTerm.getResults()[argNumber]; + unsigned argNumber = inductionVar.getArgNumber(); + Value afterTermIndArg = afterTerm.getResults()[argNumber]; - auto inductionVarAfter = afterBody->getArgument(argNumber); - - Value step; + Value inductionVarAfter = afterBody->getArgument(argNumber); // Find suitable `addi` op inside `after` block, one of the args must be an // Induction var passed from `before` block and second arg must be defined // outside of the loop and will be considered step value. // TODO: Add `subi` support? - for (auto &use : inductionVarAfter.getUses()) { - auto owner = dyn_cast(use.getOwner()); - if (!owner) - continue; - - auto other = - (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs()); - if (!dom.properlyDominates(other, loop)) - continue; - - if (afterTermIndArg != owner.getResult()) - continue; + auto addOp = afterTermIndArg.getDefiningOp(); + if (!addOp) + return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); - step = other; - break; + Value step; + if (addOp.getLhs() == inductionVarAfter) { + step = addOp.getRhs(); + } else if (addOp.getRhs() == inductionVarAfter) { + step = addOp.getLhs(); } - if (!step) - return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); + if (!step || !dom.properlyDominates(step, loop)) + return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form"); - auto lb = loop.getInits()[argNumber]; + Value lb = loop.getInits()[argNumber]; assert(lb.getType().isIntOrIndex()); assert(lb.getType() == ub.getType());