diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 59b4d06f9eca7..ddadf698f3e8d 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -151,7 +151,7 @@ struct PatternLowering { /// A mapping between constraint questions that refer to values created by /// constraints and the temporary placeholder values created for them. - DenseMap, Value> substitutions; + std::multimap, Value> substitutions; }; } // namespace @@ -377,8 +377,9 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { auto *constrResPos = cast(pos); Value placeholderValue = builder.create( loc, StringAttr::get(builder.getContext(), "placeholder")); - substitutions[{constrResPos->getQuestion(), constrResPos->getIndex()}] = - placeholderValue; + substitutions.insert( + {{constrResPos->getQuestion(), constrResPos->getIndex()}, + placeholderValue}); value = placeholderValue; break; } @@ -474,11 +475,15 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, std::pair substitutionKey = { cstQuestion, result.index()}; // Check if there are substitutions to perform. If the result is never - // used no substitutions will have been generated. - if (substitutions.count(substitutionKey)) { - substitutions[substitutionKey].replaceAllUsesWith(result.value()); - substitutions[substitutionKey].getDefiningOp()->erase(); - } + // used or multiple calls to the same constraint have been merged, + // no substitutions will have been generated for this specific op. + auto range = substitutions.equal_range(substitutionKey); + std::for_each(range.first, range.second, [&](const auto &elem) { + Value placeholder = elem.second; + placeholder.replaceAllUsesWith(result.value()); + placeholder.getDefiningOp()->erase(); + }); + substitutions.erase(substitutionKey); } break; } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index b179841fec0de..bb9ada4794ec8 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1438,6 +1438,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n"; }); ByteCodeField numResults = read(); @@ -1450,12 +1451,26 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; ByteCodeRewriteResultList results(numResults); LogicalResult rewriteResult = constraintFn(rewriter, results, args); - assert(results.getResults().size() == numResults && - "native PDL rewrite function returned unexpected number of results"); - - for (PDLValue &result : results.getResults()) { - LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); - memory[read()] = result.getAsOpaquePointer(); + ArrayRef constraintResults = results.getResults(); + LLVM_DEBUG({ + if (succeeded(rewriteResult)) { + llvm::dbgs() << " * Constraint succeeded\n"; + llvm::dbgs() << " * Results: "; + llvm::interleaveComma(constraintResults, llvm::dbgs()); + llvm::dbgs() << "\n"; + } else { + llvm::dbgs() << " * Constraint failed\n"; + } + }); + assert((failed(rewriteResult) || constraintResults.size() == numResults) && + "native PDL rewrite function returned " + "unexpected number of results"); + // Populate memory either with the results or with 0s to preserve memory + // structure as expected + for (int i = 0; i < numResults; i++) { + memory[read()] = succeeded(rewriteResult) + ? constraintResults[i].getAsOpaquePointer() + : 0; } // Depending on the constraint jump to the proper destination. selectJump(succeeded(rewriteResult)); diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 14445beadef30..13578274b35db 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -107,6 +107,29 @@ module @constraint_with_unused_result { // ----- +// CHECK-LABEL: module @constraint_with_result_multiple +module @constraint_with_result_multiple { + // check that native constraints work as expected even when multiple identical constraints are fused + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr" + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)