diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 45279e95edbba..b0fb911143203 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -45,7 +45,7 @@ def PDL_ApplyNativeConstraintOp ``` }]; - let arguments = (ins StrAttr:$name, Variadic:$args); + let arguments = (ins StrAttr:$name, Variadic:$args, DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict"; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 742de481c25ea..c830020eff46b 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -100,8 +100,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; - - let arguments = (ins StrAttr:$name, Variadic:$args); + + let arguments = (ins StrAttr:$name, Variadic:$args, DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); let assemblyFormat = [{ $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h index 5760592cae1a2..fda83efc987a2 100644 --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -390,7 +390,8 @@ class CallExpr final : public Node::NodeBase, private llvm::TrailingObjects { public: static CallExpr *create(Context &ctx, SMRange loc, Expr *callable, - ArrayRef arguments, Type resultType); + ArrayRef arguments, Type resultType, + bool isNegated = false); /// Return the callable of this call. Expr *getCallableExpr() const { return callable; } @@ -403,9 +404,13 @@ class CallExpr final : public Node::NodeBase, return const_cast(this)->getArguments(); } + bool getIsNegated() const { return isNegated; } + private: - CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs) - : Base(loc, type), callable(callable), numArgs(numArgs) {} + CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs, + bool isNegated) + : Base(loc, type), callable(callable), numArgs(numArgs), + isNegated(isNegated) {} /// The callable of this call. Expr *callable; @@ -415,6 +420,8 @@ class CallExpr final : public Node::NodeBase, /// TrailingObject utilities. friend llvm::TrailingObjects; + + bool isNegated; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index cb2e40184fc20..128b3a0ae7919 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -468,7 +468,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, auto *cstQuestion = cast(question); auto applyConstraintOp = builder.create( loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, - success, failure); + cstQuestion->getIsNegated(), success, failure); // Replace the generated placeholders with the results of the constraint and // erase them for (auto result : llvm::enumerate(applyConstraintOp.getResults())) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index 20304bd35dc28..c6ee9cd3baaea 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -472,7 +472,7 @@ struct AttributeQuestion struct ConstraintQuestion : public PredicateBase< ConstraintQuestion, Qualifier, - std::tuple, ArrayRef>, + std::tuple, ArrayRef, bool>, Predicates::ConstraintQuestion> { using Base::Base; @@ -485,13 +485,18 @@ struct ConstraintQuestion /// Return the result types of the constraint. ArrayRef getResultTypes() const { return std::get<2>(key); } + bool getIsNegated() const { return std::get<3>(key); } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), alloc.copyInto(std::get<1>(key)), - alloc.copyInto(std::get<2>(key))}); + alloc.copyInto(std::get<2>(key)), + std::get<3>(key)}); } + + static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); } }; /// Compare the equality of two values. @@ -698,9 +703,9 @@ class PredicateBuilder { /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef args, - ArrayRef resultTypes) { - return {ConstraintQuestion::get(uniquer, - std::make_tuple(name, args, resultTypes)), + ArrayRef resultTypes, bool isNegated) { + return {ConstraintQuestion::get( + uniquer, std::make_tuple(name, args, resultTypes, isNegated)), TrueAnswer::get(uniquer)}; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 1d04e427c0167..424089e469b61 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -274,7 +274,8 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, comparePosDepth); ResultRange results = op.getResults(); PredicateBuilder::Predicate pred = builder.getConstraint( - op.getName(), allPositions, SmallVector(results.getTypes())); + op.getName(), allPositions, SmallVector(results.getTypes()), + op.getIsNegated()); // for each result register a position so it can be used later for (auto result : llvm::enumerate(results)) { diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 1e0c0f05bf19c..42fadb7c939ea 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -791,6 +791,7 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op, // TODO: Handle result ranges writer.append(result); } + writer.append(ByteCodeField(op.getIsNegated())); writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, @@ -1447,7 +1448,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx]; LogicalResult rewriteResult = constraintFn(rewriter, args); // Depending on the constraint jump to the proper destination. - selectJump(succeeded(rewriteResult)); + ByteCodeField isNegated = read(); + selectJump(isNegated != succeeded(rewriteResult)); } else { const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; ByteCodeRewriteResultList results(numResults); @@ -1474,7 +1476,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { : 0; } // Depending on the constraint jump to the proper destination. - selectJump(succeeded(rewriteResult)); + ByteCodeField isNegated = read(); + selectJump(isNegated != succeeded(rewriteResult)); } } diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp index 9a6f14b0e7e5f..34082415d52b8 100644 --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -225,7 +225,10 @@ void NodePrinter::printImpl(const AttributeExpr *expr) { void NodePrinter::printImpl(const CallExpr *expr) { os << "CallExpr " << expr << " Type<"; print(expr->getType()); - os << ">\n"; + os << ">"; + if (expr->getIsNegated()) + os << " negated"; + os << "\n"; printChildren(expr->getCallableExpr()); printChildren("Arguments", expr->getArguments()); } diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp index 04d978ddd043b..7c838e103ac02 100644 --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -266,12 +266,13 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc, //===----------------------------------------------------------------------===// CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable, - ArrayRef arguments, Type resultType) { + ArrayRef arguments, Type resultType, + bool isNegated) { unsigned allocSize = CallExpr::totalSizeToAlloc(arguments.size()); void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr)); - CallExpr *expr = - new (rawData) CallExpr(loc, resultType, callable, arguments.size()); + CallExpr *expr = new (rawData) + CallExpr(loc, resultType, callable, arguments.size(), isNegated); std::uninitialized_copy(arguments.begin(), arguments.end(), expr->getArguments().begin()); return expr; diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index 3d7a9263e42be..5ade71267aed0 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -103,12 +103,14 @@ class CodeGen { Value genExprImpl(const ast::TypeExpr *expr); SmallVector genConstraintCall(const ast::UserConstraintDecl *decl, - Location loc, ValueRange inputs); + Location loc, ValueRange inputs, + bool isNegated = false); SmallVector genRewriteCall(const ast::UserRewriteDecl *decl, Location loc, ValueRange inputs); template SmallVector genConstraintOrRewriteCall(const T *decl, Location loc, - ValueRange inputs); + ValueRange inputs, + bool isNegated = false); //===--------------------------------------------------------------------===// // Fields @@ -419,7 +421,7 @@ SmallVector CodeGen::genExprImpl(const ast::CallExpr *expr) { // Generate the PDL based on the type of callable. const ast::Decl *callable = callableExpr->getDecl(); if (const auto *decl = dyn_cast(callable)) - return genConstraintCall(decl, loc, arguments); + return genConstraintCall(decl, loc, arguments, expr->getIsNegated()); if (const auto *decl = dyn_cast(callable)) return genRewriteCall(decl, loc, arguments); llvm_unreachable("unhandled CallExpr callable"); @@ -553,15 +555,15 @@ Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { SmallVector CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, - ValueRange inputs) { + ValueRange inputs, bool isNegated) { // Apply any constraints defined on the arguments to the input values. for (auto it : llvm::zip(decl->getInputs(), inputs)) applyVarConstraints(std::get<0>(it), std::get<1>(it)); // Generate the constraint call. SmallVector results = - genConstraintOrRewriteCall(decl, loc, - inputs); + genConstraintOrRewriteCall( + decl, loc, inputs, isNegated); // Apply any constraints defined on the results of the constraint. for (auto it : llvm::zip(decl->getResults(), results)) @@ -576,9 +578,9 @@ SmallVector CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, } template -SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, - Location loc, - ValueRange inputs) { +SmallVector +CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, + ValueRange inputs, bool isNegated) { const ast::CompoundStmt *cstBody = decl->getBody(); // If the decl doesn't have a statement body, it is a native decl. @@ -593,6 +595,8 @@ SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, } Operation *pdlOp = builder.create( loc, resultTypes, decl->getName().getName(), inputs); + if (isNegated) + pdlOp->setAttr("isNegated", builder.getBoolAttr(true)); return pdlOp->getResults(); } diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp index 337a8c0b53e4e..388dab188bbaf 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -233,6 +233,8 @@ Token Lexer::lexToken() { return formToken(Token::l_paren, tokStart); case ')': return formToken(Token::r_paren, tokStart); + case '!': + return formToken(Token::exclam, tokStart); case '/': if (*curPtr == '/') { lexComment(); diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h index 6a78669f854d5..b649d1af1f3fe 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -79,6 +79,7 @@ class Token { equal, equal_arrow, semicolon, + exclam, /// Paired punctuation. less, greater, diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 3b04a065a2b6c..bf035bb07b363 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -31,8 +31,8 @@ #include "llvm/Support/ScopedPrinter.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Parser.h" -#include #include +#include using namespace mlir; using namespace mlir::pdll; @@ -316,13 +316,15 @@ class Parser { /// Identifier expressions. FailureOr parseArrayAttrExpr(); FailureOr parseAttributeExpr(); - FailureOr parseCallExpr(ast::Expr *parentExpr); + FailureOr parseCallExpr(ast::Expr *parentExpr, + bool isNegated = false); FailureOr parseDeclRefExpr(StringRef name, SMRange loc); FailureOr parseDictAttrExpr(); FailureOr parseIdentifierExpr(); FailureOr parseInlineConstraintLambdaExpr(); FailureOr parseInlineRewriteLambdaExpr(); FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); + FailureOr parseNegatedExpr(); FailureOr parseOperationName(bool allowEmptyName = false); FailureOr parseWrappedOperationName(bool allowEmptyName); FailureOr @@ -406,7 +408,7 @@ class Parser { FailureOr createCallExpr(SMRange loc, ast::Expr *parentExpr, - MutableArrayRef arguments); + MutableArrayRef arguments, bool isNegated); FailureOr createDeclRefExpr(SMRange loc, ast::Decl *decl); FailureOr createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, @@ -1829,6 +1831,9 @@ FailureOr Parser::parseExpr() { case Token::l_square: lhsExpr = parseArrayAttrExpr(); break; + case Token::exclam: + lhsExpr = parseNegatedExpr(); + break; case Token::string_block: return emitError("expected expression. If you are trying to create an " "ArrayAttr, use a space between `[` and `{`."); @@ -1912,7 +1917,8 @@ FailureOr Parser::parseAttributeExpr() { return ast::AttributeExpr::create(ctx, loc, attrExpr); } -FailureOr Parser::parseCallExpr(ast::Expr *parentExpr) { +FailureOr Parser::parseCallExpr(ast::Expr *parentExpr, + bool isNegated) { consumeToken(Token::l_paren); // Parse the arguments of the call. @@ -1936,7 +1942,7 @@ FailureOr Parser::parseCallExpr(ast::Expr *parentExpr) { if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) return failure(); - return createCallExpr(loc, parentExpr, arguments); + return createCallExpr(loc, parentExpr, arguments, isNegated); } FailureOr Parser::parseDeclRefExpr(StringRef name, SMRange loc) { @@ -2061,6 +2067,16 @@ FailureOr Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { return createMemberAccessExpr(parentExpr, memberName, loc); } +FailureOr Parser::parseNegatedExpr() { + consumeToken(Token::exclam); + if (!curToken.is(Token::identifier)) + return emitError("expected native constraint"); + FailureOr identifierExpr = parseIdentifierExpr(); + if (failed(identifierExpr)) + return failure(); + return parseCallExpr(*identifierExpr, /*isNegated = */ true); +} + FailureOr Parser::parseOperationName(bool allowEmptyName) { SMRange loc = curToken.getLoc(); @@ -2789,7 +2805,8 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { FailureOr Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, - MutableArrayRef arguments) { + MutableArrayRef arguments, + bool isNegated = false) { ast::Type parentType = parentExpr->getType(); ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); @@ -2803,6 +2820,8 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, if (isa(callableDecl)) return emitError( loc, "unable to invoke `Constraint` within a rewrite section"); + if (isNegated) + return emitError(loc, "negation of Rewrites is not supported"); } else if (isa(callableDecl)) { return emitError(loc, "unable to invoke `Rewrite` within a match section"); } @@ -2835,7 +2854,7 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, } return ast::CallExpr::create(ctx, loc, parentExpr, arguments, - callableDecl->getResultType()); + callableDecl->getResultType(), isNegated); } FailureOr Parser::createDeclRefExpr(SMRange loc, @@ -2959,8 +2978,7 @@ FailureOr Parser::createOperationExpr( OpResultTypeContext resultTypeContext, SmallVectorImpl &operands, MutableArrayRef attributes, - SmallVectorImpl &results, - unsigned numRegions) { + SmallVectorImpl &results, unsigned numRegions) { std::optional opNameRef = name->getName(); const ods::Operation *odsOp = lookupODSOperation(opNameRef); diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 13578274b35db..6b194091ae8e9 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -130,6 +130,20 @@ module @constraint_with_result_multiple { // ----- +// CHECK-LABEL: module @negated_constraint +module @negated_constraint { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: pdl_interp.apply_constraint "constraint"(%[[ROOT]] : !pdl.operation) {isNegated = true} + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + pdl.apply_native_constraint "constraint"(%root : !pdl.operation) {isNegated = true} + rewrite %root with "rewriter" + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 7976512f5e88a..26d1f6e710b66 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -103,6 +103,43 @@ module @ir attributes { test.apply_constraint_3 } { // ----- +// Test support for negated constraints. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + %test_attr = pdl_interp.create_attribute unit + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end + + ^pat: + pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) {isNegated = true} -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_4 +// CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern" + +module @ir attributes { test.apply_constraint_4 } { + "test.op"() { test_attr } : () -> () + "test.foo"() { test_attr } : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::ApplyRewriteOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index de9d467c82b8f..a3e1bb88e168a 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -36,6 +36,20 @@ Pattern TestExternalCall => replace root: Op with TestRewrite(root); // ----- +// CHECK: pdl.pattern @TestExternalNegatedCall +// CHECK: %[[ROOT:.*]] = operation +// CHECK: apply_native_constraint "TestConstraint"(%[[ROOT]] : !pdl.operation) {isNegated = true} +// CHECK: rewrite %[[ROOT]] +// CHECK: erase %[[ROOT]] +Constraint TestConstraint(op: Op); +Pattern TestExternalNegatedCall with benefit(1) { + let root = op : Op; + !TestConstraint(root); + erase root; +} + +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index c55cf78de55ca..9c299c55fc311 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -183,6 +183,18 @@ Pattern { // ----- +Rewrite Foo(op: Op); + +Pattern { + // CHECK: negation of Rewrites is not supported + let root = op<>; + rewrite root with { + !Foo(root); + } +} + +// ----- + Pattern { // CHECK: expected expression let tuple = (10 = _: Value); diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 6961de6c9bfb8..938e181587030 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -77,6 +77,22 @@ Pattern { // ----- +// CHECK: Module {{.*}} +// CHECK: -UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `-PatternDecl {{.*}} +// CHECK: -CallExpr {{.*}} Type> negated +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +Constraint checkOp(op: Op); + +Pattern { + let inputOp = op; + !checkOp(inputOp); + erase inputOp; +} + +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===//