Skip to content

Commit 2d3bbb6

Browse files
[mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (llvm#136489)
When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with `eraseOp`. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks. This commit also adds a new optional parameter to `OpBuilder::tryFold` to get hold of the materialized constant ops.
1 parent 8639b36 commit 2d3bbb6

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
564564

565565
/// Attempts to fold the given operation and places new results within
566566
/// `results`. Returns success if the operation was folded, failure otherwise.
567-
/// If the fold was in-place, `results` will not be filled.
567+
/// If the fold was in-place, `results` will not be filled. Optionally, newly
568+
/// materialized constant operations can be returned to the caller.
569+
///
568570
/// Note: This function does not erase the operation on a successful fold.
569-
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
571+
LogicalResult
572+
tryFold(Operation *op, SmallVectorImpl<Value> &results,
573+
SmallVectorImpl<Operation *> *materializedConstants = nullptr);
570574

571575
/// Creates a deep copy of the specified operation, remapping any operands
572576
/// that use values outside of the operation using the map that is provided

mlir/lib/IR/Builders.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
465465
return create(state);
466466
}
467467

468-
LogicalResult OpBuilder::tryFold(Operation *op,
469-
SmallVectorImpl<Value> &results) {
468+
LogicalResult
469+
OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
470+
SmallVectorImpl<Operation *> *materializedConstants) {
470471
assert(results.empty() && "expected empty results");
471472
ResultRange opResults = op->getResults();
472473

@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
528529
for (Operation *cst : generatedConstants)
529530
insert(cst);
530531

532+
// Return materialized constant operations.
533+
if (materializedConstants)
534+
*materializedConstants = std::move(generatedConstants);
535+
531536
return success();
532537
}
533538

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,37 +2090,34 @@ LogicalResult
20902090
OperationLegalizer::legalizeWithFold(Operation *op,
20912091
ConversionPatternRewriter &rewriter) {
20922092
auto &rewriterImpl = rewriter.getImpl();
2093-
RewriterState curState = rewriterImpl.getCurrentState();
2094-
20952093
LLVM_DEBUG({
20962094
rewriterImpl.logger.startLine() << "* Fold {\n";
20972095
rewriterImpl.logger.indent();
20982096
});
20992097

21002098
// Try to fold the operation.
21012099
SmallVector<Value, 2> replacementValues;
2100+
SmallVector<Operation *, 2> newOps;
21022101
rewriter.setInsertionPoint(op);
2103-
if (failed(rewriter.tryFold(op, replacementValues))) {
2102+
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
21042103
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
21052104
return failure();
21062105
}
2106+
21072107
// An empty list of replacement values indicates that the fold was in-place.
21082108
// As the operation changed, a new legalization needs to be attempted.
21092109
if (replacementValues.empty())
21102110
return legalize(op, rewriter);
21112111

21122112
// Recursively legalize any new constant operations.
2113-
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2114-
i != e; ++i) {
2115-
auto *createOp =
2116-
dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2117-
if (!createOp)
2118-
continue;
2119-
if (failed(legalize(createOp->getOperation(), rewriter))) {
2113+
for (Operation *newOp : newOps) {
2114+
if (failed(legalize(newOp, rewriter))) {
21202115
LLVM_DEBUG(logFailure(rewriterImpl.logger,
21212116
"failed to legalize generated constant '{0}'",
2122-
createOp->getOperation()->getName()));
2123-
rewriterImpl.resetState(curState);
2117+
newOp->getName()));
2118+
// Legalization failed: erase all materialized constants.
2119+
for (Operation *op : newOps)
2120+
rewriter.eraseOp(op);
21242121
return failure();
21252122
}
21262123
}

0 commit comments

Comments
 (0)