From a309fe1383616bb0e5bac624d3af5189c3f49e2f Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 26 Jan 2024 19:38:54 +0100 Subject: [PATCH] Pdl: Allow to define builtin native calls --- mlir/include/mlir/Dialect/PDL/IR/Builtins.h | 36 +++++++ mlir/lib/Dialect/PDL/IR/Builtins.cpp | 56 ++++++++++ mlir/lib/Dialect/PDL/IR/CMakeLists.txt | 1 + mlir/lib/Rewrite/FrozenRewritePatternSet.cpp | 3 + mlir/lib/Tools/PDLL/Parser/Parser.cpp | 105 ++++++++++++++----- mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 54 +++------- mlir/test/mlir-pdll/Parser/expr-failure.pdll | 14 --- mlir/test/mlir-pdll/Parser/expr.pdll | 22 ++-- mlir/unittests/Dialect/CMakeLists.txt | 1 + mlir/unittests/Dialect/PDL/BuiltinTest.cpp | 72 +++++++++++++ mlir/unittests/Dialect/PDL/CMakeLists.txt | 8 ++ 11 files changed, 279 insertions(+), 93 deletions(-) create mode 100644 mlir/include/mlir/Dialect/PDL/IR/Builtins.h create mode 100644 mlir/lib/Dialect/PDL/IR/Builtins.cpp create mode 100644 mlir/unittests/Dialect/PDL/BuiltinTest.cpp create mode 100644 mlir/unittests/Dialect/PDL/CMakeLists.txt diff --git a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h new file mode 100644 index 0000000000000..72603b7ec100c --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -0,0 +1,36 @@ +//===- Builtins.h - Builtin functions of the PDL dialect --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines builtin functions of the PDL dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_ +#define MLIR_DIALECT_PDL_IR_BUILTINS_H_ + +namespace mlir { +class PDLPatternModule; +class Attribute; +class PatternRewriter; + +namespace pdl { +void registerBuiltins(PDLPatternModule &pdlPattern); + +namespace builtin { +Attribute createDictionaryAttr(PatternRewriter &rewriter); +Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter, + Attribute dictAttr, Attribute attrName, + Attribute attrEntry); +Attribute createArrayAttr(PatternRewriter &rewriter); +Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, + Attribute element); +} // namespace builtin +} // namespace pdl +} // namespace mlir + +#endif // MLIR_DIALECT_PDL_IR_BUILTINS_H_ diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp new file mode 100644 index 0000000000000..dfe635e8ab304 --- /dev/null +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -0,0 +1,56 @@ +#include +#include + +using namespace mlir; + +namespace mlir::pdl { +namespace builtin { +mlir::Attribute createDictionaryAttr(mlir::PatternRewriter &rewriter) { + return rewriter.getDictionaryAttr({}); +} + +mlir::Attribute addEntryToDictionaryAttr(mlir::PatternRewriter &rewriter, + mlir::Attribute dictAttr, + mlir::Attribute attrName, + mlir::Attribute attrEntry) { + assert(isa(dictAttr)); + auto attr = dictAttr.cast(); + auto name = attrName.cast(); + std::vector values = attr.getValue().vec(); + + // Remove entry if it exists in the dictionary. + llvm::erase_if(values, [&](NamedAttribute &namedAttr) { + return namedAttr.getName() == name.getValue(); + }); + + values.push_back(rewriter.getNamedAttr(name, attrEntry)); + return rewriter.getDictionaryAttr(values); +} + +mlir::Attribute createArrayAttr(mlir::PatternRewriter &rewriter) { + return rewriter.getArrayAttr({}); +} + +mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter, + mlir::Attribute attr, + mlir::Attribute element) { + assert(isa(attr)); + auto values = cast(attr).getValue().vec(); + values.push_back(element); + return rewriter.getArrayAttr(values); +} +} // namespace builtin + +void registerBuiltins(PDLPatternModule &pdlPattern) { + using namespace builtin; + // See Parser::defineBuiltins() + pdlPattern.registerRewriteFunction("__builtin_createDictionaryAttr", + createDictionaryAttr); + pdlPattern.registerRewriteFunction("__builtin_addEntryToDictionaryAttr", + addEntryToDictionaryAttr); + pdlPattern.registerRewriteFunction("__builtin_createArrayAttr", + createArrayAttr); + pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr", + addElemToArrayAttr); +} +} // namespace mlir::pdl diff --git a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt index a0bec9f51a623..49187d9726274 100644 --- a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRPDLDialect + Builtins.cpp PDL.cpp PDLTypes.cpp diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp index 43840d1e8cec2..fe37be3b5a100 100644 --- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp @@ -13,6 +13,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include #include using namespace mlir; @@ -132,6 +133,8 @@ FrozenRewritePatternSet::FrozenRewritePatternSet( llvm::report_fatal_error( "failed to lower PDL pattern module to the PDL Interpreter"); + pdl::registerBuiltins(pdlPatterns); + // Generate the pdl bytecode. impl->pdlByteCode = std::make_unique( pdlModule, pdlPatterns.takeConfigs(), configMap, diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index bf035bb07b363..6d3080c8f6941 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -106,6 +106,17 @@ class Parser { /// Pop the last decl scope from the lexer. void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } + /// Creates a native constraint taking a set of Attr as arguments. + /// The number of arguments and their names is given by argNames. + /// The native returns an Attr when returnsAttr is true, otherwise returns + /// nothing. + template + T *declareBuiltin(StringRef name, ArrayRef argNames, + bool returnsAttr); + + /// Register all builtin natives. + void declareBuiltins(); + /// Parse the body of an AST module. LogicalResult parseModuleBody(SmallVectorImpl &decls); @@ -416,12 +427,12 @@ class Parser { FailureOr createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc); - // Create a native call with \p nativeFuncName and \p arguments. + // Create a native call with \p function and \p arguments. // This should be accompanied by a C++ implementation of the function that // needs to be linked and registered in passes that process PDLL files. - FailureOr - createNativeCall(SMRange loc, StringRef nativeFuncName, - MutableArrayRef arguments); + FailureOr + createBuiltinCall(SMRange loc, ast::Decl *function, + MutableArrayRef arguments); /// Validate the member access `name` into the given parent expression. On /// success, this also returns the type of the member accessed. @@ -576,13 +587,64 @@ class Parser { /// The optional code completion context. CodeCompleteContext *codeCompleteContext; + + struct { + ast::UserRewriteDecl *createDictionaryAttr; + ast::UserRewriteDecl *addEntryToDictionaryAttr; + ast::UserRewriteDecl *createArrayAttr; + ast::UserRewriteDecl *addElemToArrayAttr; + } builtins{}; }; } // namespace +template +T *Parser::declareBuiltin(StringRef name, ArrayRef argNames, + bool returnsAttr) { + SMRange loc; + auto attrConstr = ast::ConstraintRef( + ast::AttrConstraintDecl::create(ctx, loc, nullptr), loc); + + pushDeclScope(); + SmallVector args; + for (auto argName : argNames) { + FailureOr arg = + createArgOrResultVariableDecl(argName, loc, attrConstr); + assert(succeeded(arg)); + args.push_back(*arg); + } + SmallVector results; + if (returnsAttr) { + auto result = createArgOrResultVariableDecl("", loc, attrConstr); + assert(succeeded(result)); + results.push_back(*result); + } + popDeclScope(); + + auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc), + args, results, {}, attrTy); + curDeclScope->add(constraintDecl); + return constraintDecl; +} + +void Parser::declareBuiltins() { + builtins.createDictionaryAttr = declareBuiltin( + "__builtin_createDictionaryAttr", {}, /*returnsAttr=*/true); + builtins.addEntryToDictionaryAttr = declareBuiltin( + "__builtin_addEntryToDictionaryAttr", {"attr", "attrName", "attrEntry"}, + /*returnsAttr=*/true); + builtins.createArrayAttr = declareBuiltin( + "__builtin_createArrayAttr", {}, /*returnsAttr=*/true); + builtins.addElemToArrayAttr = declareBuiltin( + "__builtin_addElemToArrayAttr", {"attr", "element"}, + /*returnsAttr=*/true); +} + FailureOr Parser::parseModule() { SMLoc moduleLoc = curToken.getStartLoc(); pushDeclScope(); + declareBuiltins(); + // Parse the top-level decls of the module. SmallVector decls; if (failed(parseModuleBody(decls))) @@ -1869,7 +1931,7 @@ FailureOr Parser::parseArrayAttrExpr() { "Parsing of array attributes as constraint not supported!"); auto arrayAttrCall = - createNativeCall(curToken.getLoc(), "createArrayAttr", {}); + createBuiltinCall(curToken.getLoc(), builtins.createArrayAttr, {}); if (failed(arrayAttrCall)) return failure(); @@ -1879,8 +1941,8 @@ FailureOr Parser::parseArrayAttrExpr() { return failure(); SmallVector arrayAttrArgs{*arrayAttrCall, *attr}; - auto elemToArrayCall = createNativeCall( - curToken.getLoc(), "addElemToArrayAttr", arrayAttrArgs); + auto elemToArrayCall = createBuiltinCall( + curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs); if (failed(elemToArrayCall)) return failure(); @@ -1961,7 +2023,7 @@ FailureOr Parser::parseDictAttrExpr() { return emitError( "Parsing of dictionary attributes as constraint not supported!"); - auto dictAttrCall = createNativeCall(loc, "createDictionaryAttr", {}); + auto dictAttrCall = createBuiltinCall(loc, builtins.createDictionaryAttr, {}); if (failed(dictAttrCall)) return failure(); @@ -1995,8 +2057,8 @@ FailureOr Parser::parseDictAttrExpr() { // Create addEntryToDictionaryAttr native call. SmallVector arrayAttrArgs{*dictAttrCall, *stringAttrRef, namedDecl->getValue()}; - auto entryToDictionaryCall = - createNativeCall(loc, "addEntryToDictionaryAttr", arrayAttrArgs); + auto entryToDictionaryCall = createBuiltinCall( + loc, builtins.addEntryToDictionaryAttr, arrayAttrArgs); if (failed(entryToDictionaryCall)) return failure(); @@ -2895,33 +2957,20 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, return ast::MemberAccessExpr::create(ctx, loc, parentExpr, name, *memberType); } -FailureOr -Parser::createNativeCall(SMRange loc, StringRef nativeFuncName, - MutableArrayRef arguments) { +FailureOr +Parser::createBuiltinCall(SMRange loc, ast::Decl *function, + MutableArrayRef arguments) { - FailureOr nativeFuncExpr = parseDeclRefExpr(nativeFuncName, loc); + FailureOr nativeFuncExpr = createDeclRefExpr(loc, function); if (failed(nativeFuncExpr)) return failure(); - if (!(*nativeFuncExpr)->getType().isa()) - return emitError(nativeFuncName + " should be defined as a rewriter."); - FailureOr nativeCall = createCallExpr(loc, *nativeFuncExpr, arguments); if (failed(nativeCall)) return failure(); - // Create a unique anonymous name declaration to use, as its name is not - // important. - std::string anonName = - llvm::formatv("{0}_{1}", nativeFuncName, anonymousDeclNameCounter++) - .str(); - FailureOr varDecl = defineVariableDecl( - anonName, loc, (*nativeCall)->getType(), *nativeCall, {}); - if (failed(varDecl)) - return failure(); - - return createDeclRefExpr(loc, *varDecl); + return *nativeCall; } FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index a3e1bb88e168a..752e3c8268ede 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -163,19 +163,13 @@ Pattern TypeExpr => erase op<> -> (type<"i32">); // Parse attributes and rewrite //===----------------------------------------------------------------------===// -// Rewriter helpers declarations. -Rewrite createDictionaryAttr() -> Attr; -Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; -Rewrite createArrayAttr() -> Attr; -Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; - // CHECK-LABEL: pdl.pattern @RewriteOneEntryDictionary // CHECK: %[[VAL_1:.*]] = operation "test.op" // CHECK: %[[VAL_2:.*]] = attribute = "test" // CHECK: rewrite %[[VAL_1]] { -// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" -// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]] +// CHECK: %[[VAL_5:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]] // CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_5]]} // CHECK: replace %[[VAL_1]] with %[[VAL_6]] Pattern RewriteOneEntryDictionary { @@ -189,23 +183,19 @@ Pattern RewriteOneEntryDictionary { // ----- -// Rewriter helpers declarations. -Rewrite createDictionaryAttr() -> Attr; -Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; - // CHECK-LABEL: pdl.pattern @RewriteMultipleEntriesDictionary // CHECK: %[[VAL_1:.*]] = operation "test.op" // CHECK: %[[VAL_2:.*]] = attribute = "test2" // CHECK: %[[VAL_3:.*]] = attribute = "test3" // CHECK: rewrite %[[VAL_1]] { -// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = attribute = "secondAttr" -// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_7]], %[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = attribute = "thirdAttr" -// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]] +// CHECK: %[[VAL_11:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]] // CHECK: %[[VAL_12:.*]] = operation "test.success" {"some_dictionary" = %[[VAL_11]]} // CHECK: replace %[[VAL_1]] with %[[VAL_12]] Pattern RewriteMultipleEntriesDictionary { @@ -220,21 +210,15 @@ Pattern RewriteMultipleEntriesDictionary { // ----- -// Rewriter helpers declarations. -Rewrite createDictionaryAttr() -> Attr; -Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; -Rewrite createArrayAttr() -> Attr; -Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; - // CHECK-LABEL: pdl.pattern @RewriteOneDictionaryArrayAttr // CHECK: %[[VAL_1:.*]] = operation "test.op" // CHECK: rewrite %[[VAL_1]] { -// CHECK: %[[VAL_2:.*]] = apply_native_rewrite "createArrayAttr" -// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_2:.*]] = apply_native_rewrite "__builtin_createArrayAttr" +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_5:.*]] = attribute = "test1" -// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] +// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]} // CHECK: replace %[[VAL_1]] with %[[VAL_8]] Pattern RewriteOneDictionaryArrayAttr { @@ -247,23 +231,17 @@ Pattern RewriteOneDictionaryArrayAttr { // ----- -// Rewriter helpers declarations. -Rewrite createDictionaryAttr() -> Attr; -Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; -Rewrite createArrayAttr() -> Attr; -Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; - // CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr // CHECK: %[[VAL_1:.*]] = operation "test.op" // CHECK: %[[VAL_2:.*]] = attribute = "test2" // CHECK: rewrite %[[VAL_1]] { -// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "createArrayAttr" -// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "createDictionaryAttr" +// CHECK: %[[VAL_3:.*]] = apply_native_rewrite "__builtin_createArrayAttr" +// CHECK: %[[VAL_4:.*]] = apply_native_rewrite "__builtin_createDictionaryAttr" // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] -// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] -// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]} // CHECK: replace %[[VAL_1]] with %[[VAL_10]] Pattern RewriteMultiplyElementsArrayAttr { @@ -273,4 +251,4 @@ Pattern RewriteMultiplyElementsArrayAttr { let newRoot = op() { some_array = [ {"firstAttr" = attr<"\"test1\"">}, attr2]} -> (); replace root with newRoot; }; -} \ No newline at end of file +} diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 9c299c55fc311..b0430abbf02ad 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -253,20 +253,6 @@ Pattern { // ----- -Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; - -Pattern { - let root = op -> (); - let attr = attr<"\"test\"">; - rewrite root with { - // CHECK: undefined reference to `createArrayAttr` - let newRoot = op() { some_array = [ attr<"\"test\""> ]} -> (); - replace root with newRoot; - }; -} - -// ----- - //===----------------------------------------------------------------------===// // `op` Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 938e181587030..2cb3242f39c69 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -14,21 +14,17 @@ Pattern { // ----- -// CHECK: |-NamedAttributeDecl {{.*}} Name -// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: NamedAttributeDecl {{.*}} Name +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttr> ResultType // CHECK: `Arguments` -// CHECK: `-CallExpr {{.*}} Type -// CHECK: `-UserRewriteDecl {{.*}} Name ResultType -// CHECK: `-CallExpr {{.*}} Type -// CHECK: `-UserRewriteDecl {{.*}} Name ResultType +// CHECK: CallExpr {{.*}} Type +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_createArrayAttr> ResultType +// CHECK: CallExpr {{.*}} Type +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_addEntryToDictionaryAttr> ResultType // CHECK: `Arguments` -// CHECK: `-CallExpr {{.*}} Type -// CHECK: `-UserRewriteDecl {{.*}} Name ResultType -// CHECK: `-AttributeExpr {{.*}} Value<""firstAttr""> -Rewrite createDictionaryAttr() -> Attr; -Rewrite addEntryToDictionaryAttr(dict: Attr, attrName: Attr, attr : Attr) -> Attr; -Rewrite createArrayAttr() -> Attr; -Rewrite addElemToArrayAttr(arrayAttr: Attr, newElement: Attr) -> Attr; +// CHECK: CallExpr {{.*}} Type +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_createDictionaryAttr> ResultType +// CHECK: AttributeExpr {{.*}} Value<""firstAttr""> Pattern { let root = op -> (); diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 522aeca29146d..16671d2fca174 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -8,6 +8,7 @@ target_link_libraries(MLIRDialectTests add_subdirectory(LLVMIR) add_subdirectory(MemRef) +add_subdirectory(PDL) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Transform) diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp new file mode 100644 index 0000000000000..f653477a5ac87 --- /dev/null +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -0,0 +1,72 @@ +//===- BuiltinTest.cpp - PDL Builtin Tests --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/PDL/IR/Builtins.h" +#include "gmock/gmock.h" +#include + +using namespace mlir; +using namespace mlir::pdl; + +namespace { + +class TestPatternRewriter : public PatternRewriter { +public: + TestPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {} +}; + +class BuiltinTest : public ::testing::Test { +public: + MLIRContext ctx; + TestPatternRewriter rewriter{&ctx}; +}; + +TEST_F(BuiltinTest, createDictionaryAttr) { + auto attr = builtin::createDictionaryAttr(rewriter); + auto dict = dyn_cast(attr); + EXPECT_TRUE(dict); + EXPECT_TRUE(dict.empty()); +} + +TEST_F(BuiltinTest, addEntryToDictionaryAttr) { + auto dictAttr = rewriter.getDictionaryAttr({}); + + mlir::Attribute updated = builtin::addEntryToDictionaryAttr( + rewriter, dictAttr, rewriter.getStringAttr("testAttr"), + rewriter.getI16IntegerAttr(0)); + + EXPECT_TRUE(updated.cast().contains("testAttr")); + + auto second = builtin::addEntryToDictionaryAttr( + rewriter, updated, rewriter.getStringAttr("testAttr2"), + rewriter.getI16IntegerAttr(0)); + EXPECT_TRUE(second.cast().contains("testAttr")); + EXPECT_TRUE(second.cast().contains("testAttr2")); +} + +TEST_F(BuiltinTest, createArrayAttr) { + auto attr = builtin::createArrayAttr(rewriter); + auto dict = dyn_cast(attr); + EXPECT_TRUE(dict); + EXPECT_TRUE(dict.empty()); +} + +TEST_F(BuiltinTest, addElemToArrayAttr) { + auto dict = rewriter.getDictionaryAttr( + rewriter.getNamedAttr("key", rewriter.getStringAttr("value"))); + rewriter.getArrayAttr({}); + + auto arrAttr = builtin::createArrayAttr(rewriter); + mlir::Attribute updatedArrAttr = + builtin::addElemToArrayAttr(rewriter, arrAttr, dict); + + auto dictInsideArrAttr = + cast(*cast(updatedArrAttr).begin()); + EXPECT_EQ(dictInsideArrAttr, dict); +} +} // namespace diff --git a/mlir/unittests/Dialect/PDL/CMakeLists.txt b/mlir/unittests/Dialect/PDL/CMakeLists.txt new file mode 100644 index 0000000000000..01a94fe1877aa --- /dev/null +++ b/mlir/unittests/Dialect/PDL/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRSPLBuiltinTests + BuiltinTest.cpp +) +target_link_libraries(MLIRSPLBuiltinTests + PRIVATE + MLIRIR + MLIRPDLDialect +)