From 7986168fb0c1467158b30e06ae415073e5106b06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Tue, 20 Jun 2023 11:02:04 +0100 Subject: [PATCH 01/11] add isNegated attribute to constraint operations --- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 2 +- mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 45279e95edbba..b0fb911143203 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -45,7 +45,7 @@ def PDL_ApplyNativeConstraintOp ``` }]; - let arguments = (ins StrAttr:$name, Variadic:$args); + let arguments = (ins StrAttr:$name, Variadic:$args, DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); let assemblyFormat = "$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict"; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 742de481c25ea..413310416d0c1 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -100,8 +100,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; - - let arguments = (ins StrAttr:$name, Variadic:$args); + let arguments = (ins StrAttr:$name, Variadic:$args, DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); let assemblyFormat = [{ $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict `->` successors From fad9d3be8ae5e14b8b1aff310f56ea05986f1534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Tue, 20 Jun 2023 11:05:32 +0100 Subject: [PATCH 02/11] add support for negated constraints to ConvertPDLToPDLInterp pass --- .../Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp | 2 +- mlir/lib/Conversion/PDLToPDLInterp/Predicate.h | 15 ++++++++++----- .../Conversion/PDLToPDLInterp/PredicateTree.cpp | 3 ++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index cb2e40184fc20..128b3a0ae7919 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -468,7 +468,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, auto *cstQuestion = cast(question); auto applyConstraintOp = builder.create( loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, - success, failure); + cstQuestion->getIsNegated(), success, failure); // Replace the generated placeholders with the results of the constraint and // erase them for (auto result : llvm::enumerate(applyConstraintOp.getResults())) { diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index 20304bd35dc28..c6ee9cd3baaea 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -472,7 +472,7 @@ struct AttributeQuestion struct ConstraintQuestion : public PredicateBase< ConstraintQuestion, Qualifier, - std::tuple, ArrayRef>, + std::tuple, ArrayRef, bool>, Predicates::ConstraintQuestion> { using Base::Base; @@ -485,13 +485,18 @@ struct ConstraintQuestion /// Return the result types of the constraint. ArrayRef getResultTypes() const { return std::get<2>(key); } + bool getIsNegated() const { return std::get<3>(key); } + /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), alloc.copyInto(std::get<1>(key)), - alloc.copyInto(std::get<2>(key))}); + alloc.copyInto(std::get<2>(key)), + std::get<3>(key)}); } + + static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); } }; /// Compare the equality of two values. @@ -698,9 +703,9 @@ class PredicateBuilder { /// Create a predicate that applies a generic constraint. Predicate getConstraint(StringRef name, ArrayRef args, - ArrayRef resultTypes) { - return {ConstraintQuestion::get(uniquer, - std::make_tuple(name, args, resultTypes)), + ArrayRef resultTypes, bool isNegated) { + return {ConstraintQuestion::get( + uniquer, std::make_tuple(name, args, resultTypes, isNegated)), TrueAnswer::get(uniquer)}; } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 1d04e427c0167..424089e469b61 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -274,7 +274,8 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, comparePosDepth); ResultRange results = op.getResults(); PredicateBuilder::Predicate pred = builder.getConstraint( - op.getName(), allPositions, SmallVector(results.getTypes())); + op.getName(), allPositions, SmallVector(results.getTypes()), + op.getIsNegated()); // for each result register a position so it can be used later for (auto result : llvm::enumerate(results)) { From 93bded745bb1b478ae233ec167905894e91cc2dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Tue, 20 Jun 2023 11:07:17 +0100 Subject: [PATCH 03/11] add support for negated constraints to Bytecode interpreter --- mlir/lib/Rewrite/ByteCode.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 1e0c0f05bf19c..dd1866c023a26 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -791,6 +791,7 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op, // TODO: Handle result ranges writer.append(result); } + writer.append(ByteCodeField(op.getIsNegated())); writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, @@ -1447,7 +1448,9 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx]; LogicalResult rewriteResult = constraintFn(rewriter, args); // Depending on the constraint jump to the proper destination. - selectJump(succeeded(rewriteResult)); + ByteCodeField isNegated = read(); + llvm::dbgs() << " * isNegated: " << isNegated; + selectJump(isNegated != succeeded(rewriteResult)); } else { const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; ByteCodeRewriteResultList results(numResults); @@ -1474,7 +1477,9 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { : 0; } // Depending on the constraint jump to the proper destination. - selectJump(succeeded(rewriteResult)); + ByteCodeField isNegated = read(); + llvm::dbgs() << " * isNegated: " << isNegated; + selectJump(isNegated != succeeded(rewriteResult)); } } From 0f00bff0316f4c087a3f859336b917bb3b003651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Tue, 20 Jun 2023 11:09:25 +0100 Subject: [PATCH 04/11] add support for negated constraints to PDLL AST --- mlir/include/mlir/Tools/PDLL/AST/Nodes.h | 13 ++++++++++--- mlir/lib/Tools/PDLL/AST/Nodes.cpp | 7 ++++--- mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp | 22 +++++++++++++--------- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h index 5760592cae1a2..fda83efc987a2 100644 --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -390,7 +390,8 @@ class CallExpr final : public Node::NodeBase, private llvm::TrailingObjects { public: static CallExpr *create(Context &ctx, SMRange loc, Expr *callable, - ArrayRef arguments, Type resultType); + ArrayRef arguments, Type resultType, + bool isNegated = false); /// Return the callable of this call. Expr *getCallableExpr() const { return callable; } @@ -403,9 +404,13 @@ class CallExpr final : public Node::NodeBase, return const_cast(this)->getArguments(); } + bool getIsNegated() const { return isNegated; } + private: - CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs) - : Base(loc, type), callable(callable), numArgs(numArgs) {} + CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs, + bool isNegated) + : Base(loc, type), callable(callable), numArgs(numArgs), + isNegated(isNegated) {} /// The callable of this call. Expr *callable; @@ -415,6 +420,8 @@ class CallExpr final : public Node::NodeBase, /// TrailingObject utilities. friend llvm::TrailingObjects; + + bool isNegated; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp index 04d978ddd043b..7c838e103ac02 100644 --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -266,12 +266,13 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc, //===----------------------------------------------------------------------===// CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable, - ArrayRef arguments, Type resultType) { + ArrayRef arguments, Type resultType, + bool isNegated) { unsigned allocSize = CallExpr::totalSizeToAlloc(arguments.size()); void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr)); - CallExpr *expr = - new (rawData) CallExpr(loc, resultType, callable, arguments.size()); + CallExpr *expr = new (rawData) + CallExpr(loc, resultType, callable, arguments.size(), isNegated); std::uninitialized_copy(arguments.begin(), arguments.end(), expr->getArguments().begin()); return expr; diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index 3d7a9263e42be..5ade71267aed0 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -103,12 +103,14 @@ class CodeGen { Value genExprImpl(const ast::TypeExpr *expr); SmallVector genConstraintCall(const ast::UserConstraintDecl *decl, - Location loc, ValueRange inputs); + Location loc, ValueRange inputs, + bool isNegated = false); SmallVector genRewriteCall(const ast::UserRewriteDecl *decl, Location loc, ValueRange inputs); template SmallVector genConstraintOrRewriteCall(const T *decl, Location loc, - ValueRange inputs); + ValueRange inputs, + bool isNegated = false); //===--------------------------------------------------------------------===// // Fields @@ -419,7 +421,7 @@ SmallVector CodeGen::genExprImpl(const ast::CallExpr *expr) { // Generate the PDL based on the type of callable. const ast::Decl *callable = callableExpr->getDecl(); if (const auto *decl = dyn_cast(callable)) - return genConstraintCall(decl, loc, arguments); + return genConstraintCall(decl, loc, arguments, expr->getIsNegated()); if (const auto *decl = dyn_cast(callable)) return genRewriteCall(decl, loc, arguments); llvm_unreachable("unhandled CallExpr callable"); @@ -553,15 +555,15 @@ Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { SmallVector CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, - ValueRange inputs) { + ValueRange inputs, bool isNegated) { // Apply any constraints defined on the arguments to the input values. for (auto it : llvm::zip(decl->getInputs(), inputs)) applyVarConstraints(std::get<0>(it), std::get<1>(it)); // Generate the constraint call. SmallVector results = - genConstraintOrRewriteCall(decl, loc, - inputs); + genConstraintOrRewriteCall( + decl, loc, inputs, isNegated); // Apply any constraints defined on the results of the constraint. for (auto it : llvm::zip(decl->getResults(), results)) @@ -576,9 +578,9 @@ SmallVector CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, } template -SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, - Location loc, - ValueRange inputs) { +SmallVector +CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc, + ValueRange inputs, bool isNegated) { const ast::CompoundStmt *cstBody = decl->getBody(); // If the decl doesn't have a statement body, it is a native decl. @@ -593,6 +595,8 @@ SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, } Operation *pdlOp = builder.create( loc, resultTypes, decl->getName().getName(), inputs); + if (isNegated) + pdlOp->setAttr("isNegated", builder.getBoolAttr(true)); return pdlOp->getResults(); } From e624a6bd93a5fd3115a475591cdae8cc62fc9ea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Tue, 20 Jun 2023 11:18:47 +0100 Subject: [PATCH 05/11] add support for negated constraints to PDLL parsing --- mlir/lib/Tools/PDLL/Parser/Lexer.cpp | 2 ++ mlir/lib/Tools/PDLL/Parser/Lexer.h | 1 + mlir/lib/Tools/PDLL/Parser/Parser.cpp | 23 +++++++++++++++-------- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp index 337a8c0b53e4e..388dab188bbaf 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -233,6 +233,8 @@ Token Lexer::lexToken() { return formToken(Token::l_paren, tokStart); case ')': return formToken(Token::r_paren, tokStart); + case '!': + return formToken(Token::exclam, tokStart); case '/': if (*curPtr == '/') { lexComment(); diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h index 6a78669f854d5..b649d1af1f3fe 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -79,6 +79,7 @@ class Token { equal, equal_arrow, semicolon, + exclam, /// Paired punctuation. less, greater, diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 3b04a065a2b6c..7670bac73ab95 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -316,7 +316,8 @@ class Parser { /// Identifier expressions. FailureOr parseArrayAttrExpr(); FailureOr parseAttributeExpr(); - FailureOr parseCallExpr(ast::Expr *parentExpr); + FailureOr parseCallExpr(ast::Expr *parentExpr, + bool isNegated = false); FailureOr parseDeclRefExpr(StringRef name, SMRange loc); FailureOr parseDictAttrExpr(); FailureOr parseIdentifierExpr(); @@ -406,7 +407,7 @@ class Parser { FailureOr createCallExpr(SMRange loc, ast::Expr *parentExpr, - MutableArrayRef arguments); + MutableArrayRef arguments, bool isNegated); FailureOr createDeclRefExpr(SMRange loc, ast::Decl *decl); FailureOr createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc, @@ -1844,6 +1845,11 @@ FailureOr Parser::parseExpr() { case Token::dot: lhsExpr = parseMemberAccessExpr(*lhsExpr); break; + case Token::exclam: + // TODO: Fx: This parses the "!" as suffix instead of prefix. + consumeToken(Token::exclam); + lhsExpr = parseCallExpr(*lhsExpr, /*isNegated = */ true); + break; case Token::l_paren: lhsExpr = parseCallExpr(*lhsExpr); break; @@ -1912,7 +1918,8 @@ FailureOr Parser::parseAttributeExpr() { return ast::AttributeExpr::create(ctx, loc, attrExpr); } -FailureOr Parser::parseCallExpr(ast::Expr *parentExpr) { +FailureOr Parser::parseCallExpr(ast::Expr *parentExpr, + bool isNegated) { consumeToken(Token::l_paren); // Parse the arguments of the call. @@ -1936,7 +1943,7 @@ FailureOr Parser::parseCallExpr(ast::Expr *parentExpr) { if (failed(parseToken(Token::r_paren, "expected `)` after argument list"))) return failure(); - return createCallExpr(loc, parentExpr, arguments); + return createCallExpr(loc, parentExpr, arguments, isNegated); } FailureOr Parser::parseDeclRefExpr(StringRef name, SMRange loc) { @@ -2789,7 +2796,8 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) { FailureOr Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, - MutableArrayRef arguments) { + MutableArrayRef arguments, + bool isNegated = false) { ast::Type parentType = parentExpr->getType(); ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr); @@ -2835,7 +2843,7 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, } return ast::CallExpr::create(ctx, loc, parentExpr, arguments, - callableDecl->getResultType()); + callableDecl->getResultType(), isNegated); } FailureOr Parser::createDeclRefExpr(SMRange loc, @@ -2959,8 +2967,7 @@ FailureOr Parser::createOperationExpr( OpResultTypeContext resultTypeContext, SmallVectorImpl &operands, MutableArrayRef attributes, - SmallVectorImpl &results, - unsigned numRegions) { + SmallVectorImpl &results, unsigned numRegions) { std::optional opNameRef = name->getName(); const ods::Operation *odsOp = lookupODSOperation(opNameRef); From fa9a1b40a2184acbf4ccc3152d552742f1bd2c57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Wed, 21 Jun 2023 09:12:19 +0100 Subject: [PATCH 06/11] support prefix parsing of ! + add parsing test --- mlir/lib/Tools/PDLL/AST/NodePrinter.cpp | 5 ++++- mlir/lib/Tools/PDLL/Parser/Parser.cpp | 26 +++++++++++++++++++------ mlir/test/mlir-pdll/Parser/expr.pdll | 15 ++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp index 9a6f14b0e7e5f..34082415d52b8 100644 --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -225,7 +225,10 @@ void NodePrinter::printImpl(const AttributeExpr *expr) { void NodePrinter::printImpl(const CallExpr *expr) { os << "CallExpr " << expr << " Type<"; print(expr->getType()); - os << ">\n"; + os << ">"; + if (expr->getIsNegated()) + os << " negated"; + os << "\n"; printChildren(expr->getCallableExpr()); printChildren("Arguments", expr->getArguments()); } diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 7670bac73ab95..e54ffb5818d48 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -31,8 +31,8 @@ #include "llvm/Support/ScopedPrinter.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Parser.h" -#include #include +#include using namespace mlir; using namespace mlir::pdll; @@ -324,6 +324,7 @@ class Parser { FailureOr parseInlineConstraintLambdaExpr(); FailureOr parseInlineRewriteLambdaExpr(); FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); + FailureOr parseNegatedExpr(); FailureOr parseOperationName(bool allowEmptyName = false); FailureOr parseWrappedOperationName(bool allowEmptyName); FailureOr @@ -1830,6 +1831,9 @@ FailureOr Parser::parseExpr() { case Token::l_square: lhsExpr = parseArrayAttrExpr(); break; + case Token::exclam: + lhsExpr = parseNegatedExpr(); + break; case Token::string_block: return emitError("expected expression. If you are trying to create an " "ArrayAttr, use a space between `[` and `{`."); @@ -1845,11 +1849,11 @@ FailureOr Parser::parseExpr() { case Token::dot: lhsExpr = parseMemberAccessExpr(*lhsExpr); break; - case Token::exclam: - // TODO: Fx: This parses the "!" as suffix instead of prefix. - consumeToken(Token::exclam); - lhsExpr = parseCallExpr(*lhsExpr, /*isNegated = */ true); - break; + // case Token::exclam: + // // TODO: Fx: This parses the "!" as suffix instead of prefix. + // consumeToken(Token::exclam); + // lhsExpr = parseCallExpr(*lhsExpr, /*isNegated = */ true); + // break; case Token::l_paren: lhsExpr = parseCallExpr(*lhsExpr); break; @@ -2068,6 +2072,16 @@ FailureOr Parser::parseMemberAccessExpr(ast::Expr *parentExpr) { return createMemberAccessExpr(parentExpr, memberName, loc); } +FailureOr Parser::parseNegatedExpr() { + consumeToken(Token::exclam); + if (!curToken.is(Token::identifier)) + return emitError("expected native constraint"); + FailureOr identifierExpr = parseIdentifierExpr(); + if (failed(identifierExpr)) + return failure(); + return parseCallExpr(*identifierExpr, /*isNegated = */ true); +} + FailureOr Parser::parseOperationName(bool allowEmptyName) { SMRange loc = curToken.getLoc(); diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 6961de6c9bfb8..266825295a7b6 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -77,6 +77,21 @@ Pattern { // ----- +// CHECK: Module {{.*}} +// CHECK: -UserConstraintDecl {{.*}} Name ResultType> +// CHECK: `-PatternDecl {{.*}} +// CHECK: -CallExpr {{.*}} Type> negated +// CHECK: `-DeclRefExpr {{.*}} Type +// CHECK: `-UserConstraintDecl {{.*}} Name ResultType> +Constraint checkOp(op: Op); + +Pattern { + let inputOp = op; + !checkOp(inputOp); + erase inputOp; +} +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===// From bf50d70f009db03d59d61fd08bbd6af5688e99ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Wed, 21 Jun 2023 12:37:07 +0100 Subject: [PATCH 07/11] remove debug prints --- mlir/lib/Rewrite/ByteCode.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index dd1866c023a26..42fadb7c939ea 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1449,7 +1449,6 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LogicalResult rewriteResult = constraintFn(rewriter, args); // Depending on the constraint jump to the proper destination. ByteCodeField isNegated = read(); - llvm::dbgs() << " * isNegated: " << isNegated; selectJump(isNegated != succeeded(rewriteResult)); } else { const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx]; @@ -1478,7 +1477,6 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { } // Depending on the constraint jump to the proper destination. ByteCodeField isNegated = read(); - llvm::dbgs() << " * isNegated: " << isNegated; selectJump(isNegated != succeeded(rewriteResult)); } } From 5c457ff0dce757166bc5d48801113e66ce021db0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Wed, 21 Jun 2023 12:38:42 +0100 Subject: [PATCH 08/11] add tests --- .../pdl-to-pdl-interp-matcher.mlir | 14 ++++++++ mlir/test/Rewrite/pdl-bytecode.mlir | 32 +++++++++++++++++++ mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 14 ++++++++ mlir/test/mlir-pdll/Parser/expr.pdll | 1 + 4 files changed, 61 insertions(+) diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 13578274b35db..6b194091ae8e9 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -130,6 +130,20 @@ module @constraint_with_result_multiple { // ----- +// CHECK-LABEL: module @negated_constraint +module @negated_constraint { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: pdl_interp.apply_constraint "constraint"(%[[ROOT]] : !pdl.operation) {isNegated = true} + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + pdl.apply_native_constraint "constraint"(%root : !pdl.operation) {isNegated = true} + rewrite %root with "rewriter" + } +} + +// ----- + // CHECK-LABEL: module @inputs module @inputs { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 7976512f5e88a..7ddd0f7e6314f 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -103,6 +103,38 @@ module @ir attributes { test.apply_constraint_3 } { // ----- +// Test support for negated constraints. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) {isNegated = true} -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_4 +// CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern" + +module @ir attributes { test.apply_constraint_4 } { + "test.op"() : () -> () + "test.foo"() : () -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::ApplyRewriteOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index de9d467c82b8f..a3e1bb88e168a 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -36,6 +36,20 @@ Pattern TestExternalCall => replace root: Op with TestRewrite(root); // ----- +// CHECK: pdl.pattern @TestExternalNegatedCall +// CHECK: %[[ROOT:.*]] = operation +// CHECK: apply_native_constraint "TestConstraint"(%[[ROOT]] : !pdl.operation) {isNegated = true} +// CHECK: rewrite %[[ROOT]] +// CHECK: erase %[[ROOT]] +Constraint TestConstraint(op: Op); +Pattern TestExternalNegatedCall with benefit(1) { + let root = op : Op; + !TestConstraint(root); + erase root; +} + +// ----- + //===----------------------------------------------------------------------===// // MemberAccessExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 266825295a7b6..938e181587030 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -90,6 +90,7 @@ Pattern { !checkOp(inputOp); erase inputOp; } + // ----- //===----------------------------------------------------------------------===// From 126ea5c58dc41ef4b0dda8194b4b0c4f68e96b8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Thu, 22 Jun 2023 10:45:08 +0100 Subject: [PATCH 09/11] fix interpreter test --- mlir/test/Rewrite/pdl-bytecode.mlir | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 7ddd0f7e6314f..26d1f6e710b66 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -106,9 +106,14 @@ module @ir attributes { test.apply_constraint_3 } { // Test support for negated constraints. module @patterns { pdl_interp.func @matcher(%root : !pdl.operation) { - pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) {isNegated = true} -> ^pat, ^end + %test_attr = pdl_interp.create_attribute unit + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end ^pat: + pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) {isNegated = true} -> ^pat1, ^end + + ^pat1: pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end ^end: @@ -129,8 +134,8 @@ module @patterns { // CHECK: "test.replaced_by_pattern" module @ir attributes { test.apply_constraint_4 } { - "test.op"() : () -> () - "test.foo"() : () -> () + "test.op"() { test_attr } : () -> () + "test.foo"() { test_attr } : () -> () } // ----- From 0e497f8bdfa31777275fa63f53de84ac5726d9c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Fri, 23 Jun 2023 15:18:55 +0100 Subject: [PATCH 10/11] address review comments --- mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 1 + mlir/lib/Tools/PDLL/Parser/Parser.cpp | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 413310416d0c1..c830020eff46b 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -100,6 +100,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; + let arguments = (ins StrAttr:$name, Variadic:$args, DefaultValuedAttr:$isNegated); let results = (outs Variadic:$results); let assemblyFormat = [{ diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index e54ffb5818d48..bf035bb07b363 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -1849,11 +1849,6 @@ FailureOr Parser::parseExpr() { case Token::dot: lhsExpr = parseMemberAccessExpr(*lhsExpr); break; - // case Token::exclam: - // // TODO: Fx: This parses the "!" as suffix instead of prefix. - // consumeToken(Token::exclam); - // lhsExpr = parseCallExpr(*lhsExpr, /*isNegated = */ true); - // break; case Token::l_paren: lhsExpr = parseCallExpr(*lhsExpr); break; @@ -2825,6 +2820,8 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr, if (isa(callableDecl)) return emitError( loc, "unable to invoke `Constraint` within a rewrite section"); + if (isNegated) + return emitError(loc, "negation of Rewrites is not supported"); } else if (isa(callableDecl)) { return emitError(loc, "unable to invoke `Rewrite` within a match section"); } From 3ec1761d0be131347c2fcfc136c8bf53a4690a6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Fri, 23 Jun 2023 15:30:22 +0100 Subject: [PATCH 11/11] add test --- mlir/test/mlir-pdll/Parser/expr-failure.pdll | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index c55cf78de55ca..9c299c55fc311 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -183,6 +183,18 @@ Pattern { // ----- +Rewrite Foo(op: Op); + +Pattern { + // CHECK: negation of Rewrites is not supported + let root = op<>; + rewrite root with { + !Foo(root); + } +} + +// ----- + Pattern { // CHECK: expected expression let tuple = (10 = _: Value);