diff --git a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h index 0c6cceb54b68a..5e60391e278eb 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -43,8 +43,9 @@ enum class UnaryOpKind { LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, PDLResultList &results, ArrayRef args); -Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, - Attribute element); +LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args); LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args); LogicalResult div(PatternRewriter &rewriter, PDLResultList &results, diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 9e4efbf7e71c0..bc238092c141d 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -38,13 +38,19 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, return success(); } -mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter, - mlir::Attribute attr, - mlir::Attribute element) { - assert(isa(attr)); - auto values = cast(attr).getValue().vec(); - values.push_back(element); - return rewriter.getArrayAttr(values); +LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + + assert(args.size() == 2 && + "Expected two arguments, one ArrayAttr and one Attr"); + auto arrayAttr = cast(args[0].cast()); + auto attrElement = args[1].cast(); + llvm::SmallVector values(arrayAttr.getValue()); + values.push_back(attrElement); + + results.push_back(rewriter.getArrayAttr(values)); + return success(); } template @@ -344,11 +350,15 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { // See Parser::defineBuiltins() pdlPattern.registerRewriteFunction( "__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr); - pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr", - addElemToArrayAttr); pdlPattern.registerConstraintFunction( "__builtin_addEntryToDictionaryAttr_constraint", addEntryToDictionaryAttr); + + pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttrRewriter", + addElemToArrayAttr); + pdlPattern.registerConstraintFunction( + "__builtin_addElemToArrayAttrConstraint", addElemToArrayAttr); + pdlPattern.registerRewriteFunction("__builtin_mulRewrite", mul); pdlPattern.registerRewriteFunction("__builtin_divRewrite", div); pdlPattern.registerRewriteFunction("__builtin_modRewrite", mod); @@ -357,22 +367,14 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2); pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2); pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs); - pdlPattern.registerConstraintFunction("__builtin_mulConstraint", - mul); - pdlPattern.registerConstraintFunction("__builtin_divConstraint", - div); - pdlPattern.registerConstraintFunction("__builtin_modConstraint", - mod); - pdlPattern.registerConstraintFunction("__builtin_addConstraint", - add); - pdlPattern.registerConstraintFunction("__builtin_subConstraint", - sub); - pdlPattern.registerConstraintFunction("__builtin_log2Constraint", - log2); - pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", - exp2); - pdlPattern.registerConstraintFunction("__builtin_absConstraint", - abs); + pdlPattern.registerConstraintFunction("__builtin_mulConstraint", mul); + pdlPattern.registerConstraintFunction("__builtin_divConstraint", div); + pdlPattern.registerConstraintFunction("__builtin_modConstraint", mod); + pdlPattern.registerConstraintFunction("__builtin_addConstraint", add); + pdlPattern.registerConstraintFunction("__builtin_subConstraint", sub); + pdlPattern.registerConstraintFunction("__builtin_log2Constraint", log2); + pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", exp2); + pdlPattern.registerConstraintFunction("__builtin_absConstraint", abs); pdlPattern.registerConstraintFunction("__builtin_equals", equals); } } // namespace mlir::pdl diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 0250ecb0f7f28..aacb049f32b09 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -625,7 +625,8 @@ class Parser { struct { ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite; ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint; - ast::UserRewriteDecl *addElemToArrayAttr; + ast::UserRewriteDecl *addElemToArrayAttrRewrite; + ast::UserConstraintDecl *addElemToArrayAttrConstraint; ast::UserRewriteDecl *mulRewrite; ast::UserRewriteDecl *divRewrite; ast::UserRewriteDecl *modRewrite; @@ -691,9 +692,13 @@ void Parser::declareBuiltins() { "__builtin_addEntryToDictionaryAttr_constraint", {"attr", "attrName", "attrEntry"}, /*returnsAttr=*/true); - builtins.addElemToArrayAttr = declareBuiltin( - "__builtin_addElemToArrayAttr", {"attr", "element"}, + builtins.addElemToArrayAttrRewrite = declareBuiltin( + "__builtin_addElemToArrayAttrRewriter", {"attr", "element"}, /*returnsAttr=*/true); + builtins.addElemToArrayAttrConstraint = + declareBuiltin( + "__builtin_addElemToArrayAttrConstraint", {"attr", "element"}, + /*returnsAttr=*/true); builtins.mulRewrite = declareBuiltin( "__builtin_mulRewrite", {"lhs", "rhs"}, true); builtins.divRewrite = declareBuiltin( @@ -2323,27 +2328,35 @@ FailureOr Parser::parseArrayAttrExpr() { consumeToken(Token::l_square); + ast::Decl *builtinFunction = builtins.addElemToArrayAttrRewrite; if (parserContext != ParserContext::Rewrite) - return emitError( - "Parsing of array attributes as constraint not supported!"); + builtinFunction = builtins.addElemToArrayAttrConstraint; - FailureOr arrayAttr = ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]"); + FailureOr arrayAttr = + ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]"); if (failed(arrayAttr)) return failure(); + // No values inside the array + if (consumeIf(Token::r_square)) { + return arrayAttr; + } + do { FailureOr attr = parseExpr(); if (failed(attr)) return failure(); SmallVector arrayAttrArgs{*arrayAttr, *attr}; - auto elemToArrayCall = createBuiltinCall( - curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs); + + auto elemToArrayCall = + createBuiltinCall(curToken.getLoc(), builtinFunction, arrayAttrArgs); if (failed(elemToArrayCall)) return failure(); // Uses the new array for the next element. arrayAttr = elemToArrayCall; + } while (consumeIf(Token::comma)); if (failed( @@ -2415,7 +2428,8 @@ FailureOr Parser::parseDictAttrExpr() { consumeToken(Token::l_brace); SMRange loc = curToken.getLoc(); - FailureOr dictAttrCall = ast::AttributeExpr::create(ctx, loc, "{}"); + FailureOr dictAttrCall = + ast::AttributeExpr::create(ctx, loc, "{}"); if (failed(dictAttrCall)) return failure(); diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 84cba9035123f..b876eecadfbce 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -218,7 +218,7 @@ Pattern RewriteMultipleEntriesDictionary { // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_5:.*]] = attribute = "test1" // CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_2]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]} // CHECK: replace %[[VAL_1]] with %[[VAL_8]] Pattern RewriteOneDictionaryArrayAttr { @@ -229,6 +229,43 @@ Pattern RewriteOneDictionaryArrayAttr { }; } +// ----- + +// CHECK-LABEL: pdl.pattern @ConstraintWithArrayAttr +// CHECK: %[[VAL_0:.*]] = attribute = "test1" +// CHECK: %[[VAL_1:.*]] = attribute = "test2" +// CHECK: %[[VAL_2:.*]] = attribute = [] +// CHECK: %[[VAL_3:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_2]], %[[VAL_0]] +// CHECK: %[[VAL_4:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_3]], %[[VAL_1]] +// CHECK: %[[VAL_5:.*]] = operation "test.op" +// CHECK: rewrite %[[VAL_5]] { +// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_array" = %[[VAL_4]]} +// CHECK: replace %[[VAL_5]] with %[[VAL_6]] + +Pattern ConstraintWithArrayAttr { + let attr1 = attr<"\"test1\"">; + let attr2 = attr<"\"test2\"">; + let array = [attr1, attr2]; + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = array} -> (); + replace root with newRoot; + }; +} + +// ----- + +// CHECK-LABEL: pdl.pattern @ConstraintNotMatchingArrayAttrInAttrType +// CHECK-NOT: apply_native_constraint "__builtin_addElemToArrayAttrConstraint" + + +Constraint I64Value(value: Value); +Pattern ConstraintNotMatchingArrayAttrInAttrType { + let root = op(arg: Value, arg2: Value, arg3: [Value, I64Value], arg); + replace root with arg; +} + + // ----- // CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr @@ -240,8 +277,8 @@ Pattern RewriteOneDictionaryArrayAttr { // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" // CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] -// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] -// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_3]], %[[VAL_7]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]} // CHECK: replace %[[VAL_1]] with %[[VAL_10]] Pattern RewriteMultiplyElementsArrayAttr { diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 34cf54fb7c23d..9d1218f124009 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -134,6 +134,22 @@ Pattern { // ----- +Pattern ConstraintArrayAttrWithAttrAndValue { + let root = op(arg: Value) -> (); + let attr1 = attr<"\"test1\"">; + let array = [attr1, arg]; + // CHECK: unable to convert expression of type `Value` to the expected type of `Attr` + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = array} -> (); + replace root with newRoot; + }; +} + +// ----- + + + //===----------------------------------------------------------------------===// // Range Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index fe3d8956b3dd7..8acdb7e9bba77 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -34,7 +34,7 @@ Pattern { // CHECK-LABEL: Module // CHECK: |-NamedAttributeDecl {{.*}} Name -// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttr> ResultType +// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttrRewriter> ResultType // CHECK: `Arguments` // CHECK: CallExpr {{.*}} Type // CHECK: AttributeExpr {{.*}} Value<"[]"> @@ -87,6 +87,77 @@ Constraint getPopulatedDict() -> Attr { return dictionary; } + + +// ----- + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getEmtpyArray() -> Attr { + let array = []; + return array; +} + +// ----- + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK: `Arguments` +//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<""attr1""> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getPopulateArray() -> Attr { + let array = ["attr1"]; + return array; +} + + +// ----- + + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK-DAG: `Arguments` +//CHECK-NEXT: |-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: | `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK-DAG: `Arguments` +//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK-DAG: -ReturnStmt {{.*}} + +Constraint getA() -> Attr { + return "A"; +} + +Constraint getB() -> Attr { + return "B"; +} + +Constraint getPopulateArrayFromOtherConstraints() -> Attr { + let array = [getA(), getB()]; + return array; +} + + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index 21a620e3b6675..113fc1ff8640f 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -66,13 +66,17 @@ TEST_F(BuiltinTest, addEntryToDictionaryAttr) { } TEST_F(BuiltinTest, addElemToArrayAttr) { + TestPDLResultList results(1); + auto dict = rewriter.getDictionaryAttr( rewriter.getNamedAttr("key", rewriter.getStringAttr("value"))); rewriter.getArrayAttr({}); auto arrAttr = rewriter.getArrayAttr({}); + EXPECT_TRUE(succeeded( + builtin::addElemToArrayAttr(rewriter, results, {arrAttr, dict}))); mlir::Attribute updatedArrAttr = - builtin::addElemToArrayAttr(rewriter, arrAttr, dict); + results.getResults().front().cast(); auto dictInsideArrAttr = cast(*cast(updatedArrAttr).begin()); @@ -617,7 +621,7 @@ TEST_F(BuiltinTest, log2) { cast(result.cast()).getValue().convertToFloat(), 2.0); } - + auto threeF16 = rewriter.getF16FloatAttr(3.0); // check correctness @@ -626,7 +630,8 @@ TEST_F(BuiltinTest, log2) { EXPECT_TRUE(builtin::log2(rewriter, results, {threeF16}).succeeded()); PDLValue result = results.getResults()[0]; - float resultVal = cast(result.cast()).getValue().convertToFloat(); + float resultVal = + cast(result.cast()).getValue().convertToFloat(); EXPECT_TRUE(resultVal > 1.58 && resultVal < 1.59); } }