diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp index 74b02cc5209d9..337a8c0b53e4e 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -378,7 +378,7 @@ Token Lexer::lexString(const char *tokStart, bool isStringBlock) { --curPtr; StringRef expectedEndStr = isStringBlock ? "}]" : "\""; - return emitError(curPtr - 1, + return emitError(tokStart, "expected '" + expectedEndStr + "' in string literal"); } diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index e89b8c8908a35..3b04a065a2b6c 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -314,9 +314,11 @@ class Parser { FailureOr parseExpr(); /// Identifier expressions. + FailureOr parseArrayAttrExpr(); FailureOr parseAttributeExpr(); FailureOr parseCallExpr(ast::Expr *parentExpr); FailureOr parseDeclRefExpr(StringRef name, SMRange loc); + FailureOr parseDictAttrExpr(); FailureOr parseIdentifierExpr(); FailureOr parseInlineConstraintLambdaExpr(); FailureOr parseInlineRewriteLambdaExpr(); @@ -329,7 +331,6 @@ class Parser { FailureOr parseTupleExpr(); FailureOr parseTypeExpr(); FailureOr parseUnderscoreExpr(); - //===--------------------------------------------------------------------===// // Stmts @@ -413,6 +414,13 @@ class Parser { FailureOr createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); + // Create a native call with \p nativeFuncName and \p arguments. + // This should be accompanied by a C++ implementation of the function that + // needs to be linked and registered in passes that process PDLL files. + FailureOr + createNativeCall(SMRange loc, StringRef nativeFuncName, + MutableArrayRef arguments); + /// Validate the member access `name` into the given parent expression. On /// success, this also returns the type of the member accessed. FailureOr validateMemberAccess(ast::Expr *parentExpr, @@ -1815,6 +1823,15 @@ FailureOr Parser::parseExpr() { case Token::l_paren: lhsExpr = parseTupleExpr(); break; + case Token::l_brace: + lhsExpr = parseDictAttrExpr(); + break; + case Token::l_square: + lhsExpr = parseArrayAttrExpr(); + break; + case Token::string_block: + return emitError("expected expression. If you are trying to create an " + "ArrayAttr, use a space between `[` and `{`."); default: return emitError("expected expression"); } @@ -1838,6 +1855,40 @@ FailureOr Parser::parseExpr() { } } +FailureOr Parser::parseArrayAttrExpr() { + + consumeToken(Token::l_square); + + if (parserContext != ParserContext::Rewrite) + return emitError( + "Parsing of array attributes as constraint not supported!"); + + auto arrayAttrCall = + createNativeCall(curToken.getLoc(), "createArrayAttr", {}); + if (failed(arrayAttrCall)) + return failure(); + + do { + FailureOr attr = parseExpr(); + if (failed(attr)) + return failure(); + + SmallVector arrayAttrArgs{*arrayAttrCall, *attr}; + auto elemToArrayCall = createNativeCall( + curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs); + if (failed(elemToArrayCall)) + return failure(); + + // Uses the new array for the next element. + arrayAttrCall = elemToArrayCall; + } while (consumeIf(Token::comma)); + + if (failed( + parseToken(Token::r_square, "expected `]` to close array attribute"))) + return failure(); + return arrayAttrCall; +} + FailureOr Parser::parseAttributeExpr() { SMRange loc = curToken.getLoc(); consumeToken(Token::kw_attr); @@ -1896,6 +1947,62 @@ FailureOr Parser::parseDeclRefExpr(StringRef name, SMRange loc) { return createDeclRefExpr(loc, decl); } +FailureOr Parser::parseDictAttrExpr() { + consumeToken(Token::l_brace); + SMRange loc = curToken.getLoc(); + + if (parserContext != ParserContext::Rewrite) + return emitError( + "Parsing of dictionary attributes as constraint not supported!"); + + auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {}); + if (failed(dictAttrCall)) + return failure(); + + // Add each nested attribute to the dict + do { + FailureOr decl = + parseNamedAttributeDecl(std::nullopt); + if (failed(decl)) + return failure(); + + ast::NamedAttributeDecl *namedDecl = *decl; + + std::string stringAttrValue = + "\"" + std::string((*namedDecl).getName().getName()) + "\""; + auto *stringAttr = ast::AttributeExpr::create(ctx, loc, stringAttrValue); + + // Declare it as a variable + std::string anonName = + llvm::formatv("dict{0}", anonymousDeclNameCounter++).str(); + FailureOr stringAttrDecl = + createVariableDecl(anonName, namedDecl->getLoc(), stringAttr, {}); + if (failed(stringAttrDecl)) + return failure(); + + // Get its reference + auto stringAttrRef = parseDeclRefExpr( + (*stringAttrDecl)->getName().getName(), namedDecl->getLoc()); + if (failed(stringAttrRef)) + return failure(); + + // Create addEntryToDictionaryAttr native call. + SmallVector arrayAttrArgs{*dictAttrCall, *stringAttrRef, + namedDecl->getValue()}; + auto entryToDictionaryCall = + createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs); + if (failed(entryToDictionaryCall)) + return failure(); + + // Uses the new array for the next element. + dictAttrCall = entryToDictionaryCall; + } while (consumeIf(Token::comma)); + if (failed(parseToken(Token::r_brace, + "expected `}` to close dictionary attribute"))) + return failure(); + return dictAttrCall; +} + FailureOr Parser::parseIdentifierExpr() { StringRef name = curToken.getSpelling(); SMRange nameLoc = curToken.getLoc(); @@ -2769,6 +2876,35 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); } +FailureOr +Parser::createNativeCall(SMRange loc, StringRef nativeFuncName, + MutableArrayRef arguments) { + + FailureOr nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc); + if (failed(nativeFuncExpr)) + return failure(); + + if (!(*nativeFuncExpr)->getType().isa()) + return emitError(nativeFuncName + " should be defined as a rewriter."); + + FailureOr nativeCall = + createCallExpr(loc, *nativeFuncExpr, arguments); + if (failed(nativeCall)) + return failure(); + + // Create a unique anonymous name declaration to use, as its name is not + // important. + std::string anonName = + llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++) + .str(); + FailureOr varDecl = defineVariableDecl( + anonName, loc, (*nativeCall)->getType(), *nativeCall, {}); + if (failed(varDecl)) + return failure(); + + return createDeclRefExpr(loc, *varDecl); +} + FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, StringRef name, SMRange loc) { ast::Type parentType = parentExpr->getType(); diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 950b90d75d6a4..de9d467c82b8f 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -142,3 +142,121 @@ Pattern RangeExpr { // CHECK: %[[TYPE:.*]] = type : i32 // CHECK: operation({{.*}}) -> (%[[TYPE]] : !pdl.type) Pattern TypeExpr => erase op<> -> (type<"i32">); + +// ----- + +//===----------------------------------------------------------------------===// +// Parse attributes and rewrite +//===----------------------------------------------------------------------===// + +// Rewriter helpers declarations. +Rewrite createDictionaryAttr() -> Attr; +Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; +Rewrite createArrayAttr() -> Attr; +Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; + +// CHECK-LABEL: pdl.pattern @RewriteOneEntryDictionary +// CHECK: %[[VAL_1:.*]] = operation "test.op" +// CHECK: %[[VAL_2:.*]] = attribute = "test" +// CHECK: rewrite %[[VAL_1]] { +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" +// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]] +// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_5]]} +// CHECK: replace %[[VAL_1]] with %[[VAL_6]] +Pattern RewriteOneEntryDictionary { + let root = op -> (); + let attr1 = attr<"\"test\"">; + rewrite root with { + let newRoot = op() { some_dictionary = {firstAttr=attr1} } -> (); + replace root with newRoot; + }; +} + +// ----- + +// Rewriter helpers declarations. +Rewrite createDictionaryAttr() -> Attr; +Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; + +// CHECK-LABEL: pdl.pattern @RewriteMultipleEntriesDictionary +// CHECK: %[[VAL_1:.*]] = operation "test.op" +// CHECK: %[[VAL_2:.*]] = attribute = "test2" +// CHECK: %[[VAL_3:.*]] = attribute = "test3" +// CHECK: rewrite %[[VAL_1]] { +// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" +// CHECK: %[[VAL_6:.*]] = attribute = "test1" +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = attribute = "secondAttr" +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_10:.*]] = attribute = "thirdAttr" +// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]] +// CHECK: %[[VAL_12:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_11]]} +// CHECK: replace %[[VAL_1]] with %[[VAL_12]] +Pattern RewriteMultipleEntriesDictionary { + let root = op -> (); + let attr2 = attr<"\"test2\"">; + let attr3 = attr<"\"test3\"">; + rewrite root with { + let newRoot = op() { some_dictionary = {"firstAttr" = attr<"\"test1\"">, secondAttr = attr2, thirdAttr = attr3} } -> (); + replace root with newRoot; + }; +} + +// ----- + +// Rewriter helpers declarations. +Rewrite createDictionaryAttr() -> Attr; +Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; +Rewrite createArrayAttr() -> Attr; +Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; + +// CHECK-LABEL: pdl.pattern @RewriteOneDictionaryArrayAttr +// CHECK: %[[VAL_1:.*]] = operation "test.op" +// CHECK: rewrite %[[VAL_1]] { +// CHECK: %[[VAL_2:.*]] = apply_native_rewrite "createArrayAttr" +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" +// CHECK: %[[VAL_5:.*]] = attribute = "test1" +// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]} +// CHECK: replace %[[VAL_1]] with %[[VAL_8]] +Pattern RewriteOneDictionaryArrayAttr { + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = [ {"firstAttr" = attr<"\"test1\"">}]} -> (); + replace root with newRoot; + }; +} + +// ----- + +// Rewriter helpers declarations. +Rewrite createDictionaryAttr() -> Attr; +Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; +Rewrite createArrayAttr() -> Attr; +Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; + +// CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr +// CHECK: %[[VAL_1:.*]] = operation "test.op" +// CHECK: %[[VAL_2:.*]] = attribute = "test2" +// CHECK: rewrite %[[VAL_1]] { +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createArrayAttr" +// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" +// CHECK: %[[VAL_6:.*]] = attribute = "test1" +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]} +// CHECK: replace %[[VAL_1]] with %[[VAL_10]] +Pattern RewriteMultiplyElementsArrayAttr { + let root = op -> (); + let attr2 = attr<"\"test2\"">; + rewrite root with { + let newRoot = op() { some_array = [ {"firstAttr" = attr<"\"test1\"">}, attr2]} -> (); + replace root with newRoot; + }; +} \ No newline at end of file diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 31258cb99ebc4..c55cf78de55ca 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -218,6 +218,43 @@ Pattern { // ----- +Pattern { + let root = op -> (); + // CHECK: expected expression. If you are trying to create an ArrayAttr, use a space between `[` and `{`. + rewrite root with { + let newRoot = op() { some_array = [{"firstAttr" = attr<"\"test\"">}]} -> (); + replace root with newRoot; + };; +} + + +// ----- + +Pattern { + let root = op -> (); + // CHECK: expected '}]' in string literal + rewrite root with { + let newRoot = op() { some_array = [{"firstAttr" = attr<"\"test\"">}, attr<"\"test\"">] } -> (); + replace root with newRoot; + }; +} + +// ----- + +Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; + +Pattern { + let root = op -> (); + let attr = attr<"\"test\"">; + rewrite root with { + // CHECK: undefined reference to `createArrayAttr` + let newRoot = op() { some_array = [ attr<"\"test\""> ]} -> (); + replace root with newRoot; + }; +} + +// ----- + //===----------------------------------------------------------------------===// // `op` Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index a9e24a7b20b4b..6961de6c9bfb8 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -14,6 +14,33 @@ Pattern { // ----- +// CHECK: |-NamedAttributeDecl {{.*}} Name +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: `Arguments` +// CHECK: `-CallExpr {{.*}} Type +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: `-CallExpr {{.*}} Type +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: `Arguments` +// CHECK: `-CallExpr {{.*}} Type +// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: `-AttributeExpr {{.*}} Value<""firstAttr""> +Rewrite createDictionaryAttr() -> Attr; +Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; +Rewrite createArrayAttr() -> Attr; +Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; + +Pattern { + let root = op -> (); + let attr = attr<"\"test\"">; + rewrite root with { + let newRoot = op() { some_array = [ {"firstAttr" = attr<"\"test\"">}], attr} -> (); + replace root with newRoot; + }; +} + +// ----- + //===----------------------------------------------------------------------===// // CallExpr //===----------------------------------------------------------------------===//