Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ enum class UnaryOpKind {
LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
LogicalResult addElemToArrayAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args);
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
Expand Down
52 changes: 27 additions & 25 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,19 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
return success();
}

mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
mlir::Attribute attr,
mlir::Attribute element) {
assert(isa<ArrayAttr>(attr));
auto values = cast<ArrayAttr>(attr).getValue().vec();
values.push_back(element);
return rewriter.getArrayAttr(values);
LogicalResult addElemToArrayAttr(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {

assert(args.size() == 2 &&
"Expected two arguments, one ArrayAttr and one Attr");
auto arrayAttr = cast<ArrayAttr>(args[0].cast<Attribute>());
auto attrElement = args[1].cast<Attribute>();
llvm::SmallVector<Attribute> values(arrayAttr.getValue());
values.push_back(attrElement);

results.push_back(rewriter.getArrayAttr(values));
return success();
}

template <UnaryOpKind T>
Expand Down Expand Up @@ -344,11 +350,15 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
// See Parser::defineBuiltins()
pdlPattern.registerRewriteFunction(
"__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
pdlPattern.registerConstraintFunction(
"__builtin_addEntryToDictionaryAttr_constraint",
addEntryToDictionaryAttr);

pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttrRewriter",
addElemToArrayAttr);
pdlPattern.registerConstraintFunction(
"__builtin_addElemToArrayAttrConstraint", addElemToArrayAttr);

pdlPattern.registerRewriteFunction("__builtin_mulRewrite", mul);
pdlPattern.registerRewriteFunction("__builtin_divRewrite", div);
pdlPattern.registerRewriteFunction("__builtin_modRewrite", mod);
Expand All @@ -357,22 +367,14 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);
pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs);
pdlPattern.registerConstraintFunction("__builtin_mulConstraint",
mul);
pdlPattern.registerConstraintFunction("__builtin_divConstraint",
div);
pdlPattern.registerConstraintFunction("__builtin_modConstraint",
mod);
pdlPattern.registerConstraintFunction("__builtin_addConstraint",
add);
pdlPattern.registerConstraintFunction("__builtin_subConstraint",
sub);
pdlPattern.registerConstraintFunction("__builtin_log2Constraint",
log2);
pdlPattern.registerConstraintFunction("__builtin_exp2Constraint",
exp2);
pdlPattern.registerConstraintFunction("__builtin_absConstraint",
abs);
pdlPattern.registerConstraintFunction("__builtin_mulConstraint", mul);
pdlPattern.registerConstraintFunction("__builtin_divConstraint", div);
pdlPattern.registerConstraintFunction("__builtin_modConstraint", mod);
pdlPattern.registerConstraintFunction("__builtin_addConstraint", add);
pdlPattern.registerConstraintFunction("__builtin_subConstraint", sub);
pdlPattern.registerConstraintFunction("__builtin_log2Constraint", log2);
pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", exp2);
pdlPattern.registerConstraintFunction("__builtin_absConstraint", abs);
pdlPattern.registerConstraintFunction("__builtin_equals", equals);
}
} // namespace mlir::pdl
32 changes: 23 additions & 9 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,8 @@ class Parser {
struct {
ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite;
ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint;
ast::UserRewriteDecl *addElemToArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttrRewrite;
ast::UserConstraintDecl *addElemToArrayAttrConstraint;
ast::UserRewriteDecl *mulRewrite;
ast::UserRewriteDecl *divRewrite;
ast::UserRewriteDecl *modRewrite;
Expand Down Expand Up @@ -691,9 +692,13 @@ void Parser::declareBuiltins() {
"__builtin_addEntryToDictionaryAttr_constraint",
{"attr", "attrName", "attrEntry"},
/*returnsAttr=*/true);
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttr", {"attr", "element"},
builtins.addElemToArrayAttrRewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttrRewriter", {"attr", "element"},
/*returnsAttr=*/true);
builtins.addElemToArrayAttrConstraint =
declareBuiltin<ast::UserConstraintDecl>(
"__builtin_addElemToArrayAttrConstraint", {"attr", "element"},
/*returnsAttr=*/true);
builtins.mulRewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_mulRewrite", {"lhs", "rhs"}, true);
builtins.divRewrite = declareBuiltin<ast::UserRewriteDecl>(
Expand Down Expand Up @@ -2323,27 +2328,35 @@ FailureOr<ast::Expr *> Parser::parseArrayAttrExpr() {

consumeToken(Token::l_square);

ast::Decl *builtinFunction = builtins.addElemToArrayAttrRewrite;
if (parserContext != ParserContext::Rewrite)
return emitError(
"Parsing of array attributes as constraint not supported!");
builtinFunction = builtins.addElemToArrayAttrConstraint;

FailureOr<ast::Expr *> arrayAttr = ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]");
FailureOr<ast::Expr *> arrayAttr =
ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]");
if (failed(arrayAttr))
return failure();

// No values inside the array
if (consumeIf(Token::r_square)) {
return arrayAttr;
}

do {
FailureOr<ast::Expr *> attr = parseExpr();
if (failed(attr))
return failure();

SmallVector<ast::Expr *> arrayAttrArgs{*arrayAttr, *attr};
auto elemToArrayCall = createBuiltinCall(
curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs);

auto elemToArrayCall =
createBuiltinCall(curToken.getLoc(), builtinFunction, arrayAttrArgs);
if (failed(elemToArrayCall))
return failure();

// Uses the new array for the next element.
arrayAttr = elemToArrayCall;

} while (consumeIf(Token::comma));

if (failed(
Expand Down Expand Up @@ -2415,7 +2428,8 @@ FailureOr<ast::Expr *> Parser::parseDictAttrExpr() {
consumeToken(Token::l_brace);
SMRange loc = curToken.getLoc();

FailureOr<ast::Expr *> dictAttrCall = ast::AttributeExpr::create(ctx, loc, "{}");
FailureOr<ast::Expr *> dictAttrCall =
ast::AttributeExpr::create(ctx, loc, "{}");
if (failed(dictAttrCall))
return failure();

Expand Down
43 changes: 40 additions & 3 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ Pattern RewriteMultipleEntriesDictionary {
// CHECK: %[[VAL_4:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_5:.*]] = attribute = "test1"
// CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]]
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_2]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_8]]
Pattern RewriteOneDictionaryArrayAttr {
Expand All @@ -229,6 +229,43 @@ Pattern RewriteOneDictionaryArrayAttr {
};
}

// -----

// CHECK-LABEL: pdl.pattern @ConstraintWithArrayAttr
// CHECK: %[[VAL_0:.*]] = attribute = "test1"
// CHECK: %[[VAL_1:.*]] = attribute = "test2"
// CHECK: %[[VAL_2:.*]] = attribute = []
// CHECK: %[[VAL_3:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_2]], %[[VAL_0]]
// CHECK: %[[VAL_4:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_3]], %[[VAL_1]]
// CHECK: %[[VAL_5:.*]] = operation "test.op"
// CHECK: rewrite %[[VAL_5]] {
// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_array" = %[[VAL_4]]}
// CHECK: replace %[[VAL_5]] with %[[VAL_6]]

Pattern ConstraintWithArrayAttr {
let attr1 = attr<"\"test1\"">;
let attr2 = attr<"\"test2\"">;
let array = [attr1, attr2];
let root = op<test.op> -> ();
rewrite root with {
let newRoot = op<test.success>() { some_array = array} -> ();
replace root with newRoot;
};
}

// -----

// CHECK-LABEL: pdl.pattern @ConstraintNotMatchingArrayAttrInAttrType
// CHECK-NOT: apply_native_constraint "__builtin_addElemToArrayAttrConstraint"


Constraint I64Value(value: Value);
Pattern ConstraintNotMatchingArrayAttrInAttrType {
let root = op<my_dialect.foo>(arg: Value, arg2: Value, arg3: [Value, I64Value], arg);
replace root with arg;
}


// -----

// CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr
Expand All @@ -240,8 +277,8 @@ Pattern RewriteOneDictionaryArrayAttr {
// CHECK: %[[VAL_5:.*]] = attribute = "firstAttr"
// CHECK: %[[VAL_6:.*]] = attribute = "test1"
// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]]
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_3]], %[[VAL_7]]
// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_8]], %[[VAL_2]]
// CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]}
// CHECK: replace %[[VAL_1]] with %[[VAL_10]]
Pattern RewriteMultiplyElementsArrayAttr {
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ Pattern {

// -----

Pattern ConstraintArrayAttrWithAttrAndValue {
let root = op<test.op>(arg: Value) -> ();
let attr1 = attr<"\"test1\"">;
let array = [attr1, arg];
// CHECK: unable to convert expression of type `Value` to the expected type of `Attr`
let root = op<test.op> -> ();
rewrite root with {
let newRoot = op<test.success>() { some_array = array} -> ();
replace root with newRoot;
};
}

// -----



//===----------------------------------------------------------------------===//
// Range Expr
//===----------------------------------------------------------------------===//
Expand Down
73 changes: 72 additions & 1 deletion mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Pattern {

// CHECK-LABEL: Module
// CHECK: |-NamedAttributeDecl {{.*}} Name<some_array>
// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttr> ResultType<Attr>
// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttrRewriter> ResultType<Attr>
// CHECK: `Arguments`
// CHECK: CallExpr {{.*}} Type<Attr>
// CHECK: AttributeExpr {{.*}} Value<"[]">
Expand Down Expand Up @@ -87,6 +87,77 @@ Constraint getPopulatedDict() -> Attr {
return dictionary;
}



// -----

// CHECK-LABEL: Module
// CHECK:LetStmt {{.*}}
//CHECK-NEXT:`-VariableDecl {{.*}} Name<array> Type<Attr>
//CHECK-NEXT: `-AttributeExpr {{.*}} Value<"[]">
//CHECK-NEXT:ReturnStmt {{.*}}

Constraint getEmtpyArray() -> Attr {
let array = [];
return array;
}

// -----

// CHECK-LABEL: Module
// CHECK:LetStmt {{.*}}
//CHECK-NEXT:`-VariableDecl {{.*}} Name<array> Type<Attr>
//CHECK-NEXT: `-CallExpr {{.*}} Type<Attr>
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType<Attr>
// CHECK: `Arguments`
//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]">
//CHECK-NEXT: `-AttributeExpr {{.*}} Value<""attr1"">
//CHECK-NEXT:ReturnStmt {{.*}}

Constraint getPopulateArray() -> Attr {
let array = ["attr1"];
return array;
}


// -----


// CHECK-LABEL: Module
// CHECK:LetStmt {{.*}}
//CHECK-NEXT:`-VariableDecl {{.*}} Name<array> Type<Attr>
//CHECK-NEXT: `-CallExpr {{.*}} Type<Attr>
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType<Attr>
// CHECK-DAG: `Arguments`
//CHECK-NEXT: |-CallExpr {{.*}} Type<Attr>
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
//CHECK-NEXT: | `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType<Attr>
// CHECK-DAG: `Arguments`
//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]">
//CHECK-NEXT: `-CallExpr {{.*}} Type<Attr>
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<getA> ResultType<Attr>
// CHECK: `-CallExpr {{.*}} Type<Attr>
//CHECK-NEXT: `-DeclRefExpr {{.*}} Type<Constraint>
//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<getB> ResultType<Attr>
// CHECK-DAG: -ReturnStmt {{.*}}

Constraint getA() -> Attr {
return "A";
}

Constraint getB() -> Attr {
return "B";
}

Constraint getPopulateArrayFromOtherConstraints() -> Attr {
let array = [getA(), getB()];
return array;
}


// -----

//===----------------------------------------------------------------------===//
Expand Down
11 changes: 8 additions & 3 deletions mlir/unittests/Dialect/PDL/BuiltinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,17 @@ TEST_F(BuiltinTest, addEntryToDictionaryAttr) {
}

TEST_F(BuiltinTest, addElemToArrayAttr) {
TestPDLResultList results(1);

auto dict = rewriter.getDictionaryAttr(
rewriter.getNamedAttr("key", rewriter.getStringAttr("value")));
rewriter.getArrayAttr({});

auto arrAttr = rewriter.getArrayAttr({});
EXPECT_TRUE(succeeded(
builtin::addElemToArrayAttr(rewriter, results, {arrAttr, dict})));
mlir::Attribute updatedArrAttr =
builtin::addElemToArrayAttr(rewriter, arrAttr, dict);
results.getResults().front().cast<Attribute>();

auto dictInsideArrAttr =
cast<DictionaryAttr>(*cast<ArrayAttr>(updatedArrAttr).begin());
Expand Down Expand Up @@ -617,7 +621,7 @@ TEST_F(BuiltinTest, log2) {
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat(),
2.0);
}

auto threeF16 = rewriter.getF16FloatAttr(3.0);

// check correctness
Expand All @@ -626,7 +630,8 @@ TEST_F(BuiltinTest, log2) {
EXPECT_TRUE(builtin::log2(rewriter, results, {threeF16}).succeeded());

PDLValue result = results.getResults()[0];
float resultVal = cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat();
float resultVal =
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat();
EXPECT_TRUE(resultVal > 1.58 && resultVal < 1.59);
}
}
Expand Down