diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index bf035bb07b363..14fcfbf0d9628 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -766,6 +766,15 @@ LogicalResult Parser::convertTupleExpressionTo( return convertToRange({valueTy, valueRangeTy}, valueRangeTy); if (type == typeRangeTy) return convertToRange({typeTy, typeRangeTy}, typeRangeTy); + if (type == attrTy && exprType.size() == 1 && + exprType.getElementTypes()[0] == type) { + // Parenthesis become tuples. Allow to unpack single element tuples + // to expressions. + expr = ast::MemberAccessExpr::create(ctx, expr->getLoc(), expr, + llvm::to_string(0), + exprType.getElementTypes()[0]); + return success(); + } return emitErrorFn(); } diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 9c299c55fc311..8cdad55e7ff8b 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -124,6 +124,16 @@ Pattern { // ----- +Constraint checkAttr(attr: Attr); +Pattern { + let tuple = (result1 = value: Value); + // CHECK: unable to convert expression of type `Tuple` to the expected type of `Attr` + checkAttr(tuple); + erase _: Op; +} + +// ----- + //===----------------------------------------------------------------------===// // Range Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 938e181587030..34fa259bf5726 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -114,6 +114,20 @@ Pattern { // ----- +// Implicitly convert single element Tuple to Attr +// CHECK: Module +// CHECK: `-MemberAccessExpr {{.*}} Member<0> Type +// CHECK: `-TupleExpr {{.*}} Type> +// CHECK: `-AttributeExpr {{.*}} Value<"10: i32"> +Constraint checkAttr(attr: Attr); +Pattern { + let tuple = (attr<"10: i32">); + checkAttr(tuple); + erase _: Op; +} + +// ----- + #include "include/ops.td" // CHECK: Module