From aa3d1134519d7fb3a9cb0c75554cd2313c2279eb Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 28 Feb 2024 13:24:36 +0100 Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20to=20feature/fused-ops=20this=20commit=20is=20based=20o?= =?UTF-8?q?n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- .../mlir/Dialect/EmitC/IR/EmitCTypes.td | 52 +++++++++++++++- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 59 +++++++++++++++++++ mlir/test/Dialect/EmitC/invalid_types.mlir | 54 +++++++++++++++++ mlir/test/Dialect/EmitC/types.mlir | 14 +++++ 4 files changed, 177 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 7874aa2c9e304..6e937dfedebf8 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -16,13 +16,14 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Dialect/EmitC/IR/EmitCBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" //===----------------------------------------------------------------------===// // EmitC type definitions //===----------------------------------------------------------------------===// -class EmitC_Type - : TypeDef { +class EmitC_Type traits = []> + : TypeDef { let mnemonic = typeMnemonic; } @@ -72,4 +73,51 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { let assemblyFormat = "`<` qualified($pointee) `>`"; } +def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> { + let summary = "EmitC array type"; + + let description = [{ + An array data type. + + Example: + + ```mlir + // Array emitted as `int32_t[10]` + !emitc.array<10xi32> + // Array emitted as `float[10][20]` + !emitc.ptr<10x20xf32> + ``` + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType + ); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "ArrayRef":$shape, + "Type":$elementType + ), [{ + return $_get(elementType.getContext(), shape, elementType); + }]> + ]; + let extraClassDeclaration = [{ + /// Returns if this type is ranked (always true). + bool hasRank() const { return true; } + + /// Clone this vector type with the given shape and element type. If the + /// provided shape is `std::nullopt`, the current shape of the type is used. + ArrayType cloneWith(std::optional> shape, + Type elementType) const; + + static bool isValidElementType(Type type) { + return type.isIntOrIndexOrFloat() || + llvm::isa(type); + } + }]; + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + #endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e8ea4da0b089c..3b154ccf136b3 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -606,3 +606,62 @@ void emitc::OpaqueType::print(AsmPrinter &printer) const { llvm::printEscapedString(getValue(), printer.getStream()); printer << "\">"; } + +Type emitc::ArrayType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false, + /*withTrailingX=*/true)) + return Type(); + // Parse the element type. + auto typeLoc = parser.getCurrentLocation(); + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + // Check that memref is formed from allowed types. + if (!isValidElementType(elementType)) + return parser.emitError(typeLoc, "invalid array element type"), Type(); + if (parser.parseGreater()) + return Type(); + return parser.getChecked(dimensions, elementType); +} + +void emitc::ArrayType::print(AsmPrinter &printer) const { + printer << "<"; + for (int64_t dim : getShape()) { + printer << dim << 'x'; + } + printer.printType(getElementType()); + printer << ">"; +} + +LogicalResult emitc::ArrayType::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + ::llvm::ArrayRef shape, Type elementType) { + if (shape.empty()) + return emitError() << "shape must not be empty"; + + for (auto d : shape) { + if (d <= 0) + return emitError() << "dimensions must have positive size"; + } + + if (!elementType) + return emitError() << "element type must not be none"; + + if (!isValidElementType(elementType)) + return emitError() << "invalid array element type"; + + return success(); +} + +emitc::ArrayType +emitc::ArrayType::cloneWith(std::optional> shape, + Type elementType) const { + if (!shape) + return emitc::ArrayType::get(getShape(), elementType); + return emitc::ArrayType::get(*shape, elementType); +} diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index 54e3775ddb8ed..4c526aa93dffb 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -11,3 +11,57 @@ func.func @illegal_opaque_type_2() { // expected-error @+1 {{pointer not allowed as outer type with !emitc.opaque, use !emitc.ptr instead}} %1 = "emitc.variable"(){value = "nullptr" : !emitc.opaque<"int32_t*">} : () -> !emitc.opaque<"int32_t*"> } + +// ----- + +func.func @illegal_array_missing_spec( + // expected-error @+1 {{expected non-function type}} + %arg0: !emitc.array<>) { +} + +// ----- + +func.func @illegal_array_missing_shape( + // expected-error @+1 {{shape must not be empty}} + %arg9: !emitc.array) { +} + +// ----- + +func.func @illegal_array_missing_x( + // expected-error @+1 {{expected 'x' in dimension list}} + %arg0: !emitc.array<10> +) { +} + +// ----- + +func.func @illegal_array_non_positive_dimenson( + // expected-error @+1 {{dimensions must have positive size}} + %arg0: !emitc.array<0xi32> +) { +} + +// ----- + +func.func @illegal_array_missing_type( + // expected-error @+1 {{expected non-function type}} + %arg0: !emitc.array<10x> +) { +} + +// ----- + +func.func @illegal_array_dynamic_shape( + // expected-error @+1 {{expected static shape}} + %arg0: !emitc.array<10x?xi32> +) { +} + +// ----- + +func.func @illegal_array_unranked( + // expected-error @+1 {{expected non-function type}} + %arg0: !emitc.array<*xi32> +) { +} diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir index 26d6f43a5824e..8477b0ed05977 100644 --- a/mlir/test/Dialect/EmitC/types.mlir +++ b/mlir/test/Dialect/EmitC/types.mlir @@ -39,3 +39,17 @@ func.func @pointer_types() { return } + +// CHECK-LABEL: func @array_types( +func.func @array_types( + // CHECK-SAME: !emitc.array<1xf32>, + %arg0: !emitc.array<1xf32>, + // CHECK-SAME: !emitc.array<10x20x30xi32>, + %arg1: !emitc.array<10x20x30xi32>, + // CHECK-SAME: !emitc.array<30x!emitc.ptr>, + %arg2: !emitc.array<30x!emitc.ptr>, + // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">> + %arg3: !emitc.array<30x!emitc.opaque<"int">> +) { + return +} From eaab31797a904c335f5d39a7d22e480a8e372731 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 28 Feb 2024 14:39:52 +0100 Subject: [PATCH 2/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20introduced=20through=20rebase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- .../mlir/Dialect/EmitC/IR/EmitCTypes.td | 92 +++++++++---------- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 56 +++++------ 2 files changed, 76 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 6e937dfedebf8..8dfda3be99d5f 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -27,52 +27,6 @@ class EmitC_Type traits = []> let mnemonic = typeMnemonic; } -def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { - let summary = "EmitC opaque type"; - - let description = [{ - An opaque data type of which the value gets emitted as is. - - Example: - - ```mlir - !emitc.opaque<"int"> - !emitc.opaque<"mytype"> - !emitc.opaque<"std::vector"> - ``` - }]; - - let parameters = (ins StringRefParameter<"the opaque value">:$value); - let hasCustomAssemblyFormat = 1; -} - -def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { - let summary = "EmitC pointer type"; - - let description = [{ - A pointer data type. - - Example: - - ```mlir - // Pointer emitted as `int32_t*` - !emitc.ptr - // Pointer emitted as `float*` - !emitc.ptr - // Pointer emitted as `int*` - !emitc.ptr> - ``` - }]; - - let parameters = (ins "Type":$pointee); - let builders = [ - TypeBuilderWithInferredContext<(ins "Type":$pointee), [{ - return $_get(pointee.getContext(), pointee); - }]> - ]; - let assemblyFormat = "`<` qualified($pointee) `>`"; -} - def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> { let summary = "EmitC array type"; @@ -120,4 +74,50 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> { let hasCustomAssemblyFormat = 1; } +def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { + let summary = "EmitC opaque type"; + + let description = [{ + An opaque data type of which the value gets emitted as is. + + Example: + + ```mlir + !emitc.opaque<"int"> + !emitc.opaque<"mytype"> + !emitc.opaque<"std::vector"> + ``` + }]; + + let parameters = (ins StringRefParameter<"the opaque value">:$value); + let hasCustomAssemblyFormat = 1; +} + +def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { + let summary = "EmitC pointer type"; + + let description = [{ + A pointer data type. + + Example: + + ```mlir + // Pointer emitted as `int32_t*` + !emitc.ptr + // Pointer emitted as `float*` + !emitc.ptr + // Pointer emitted as `int*` + !emitc.ptr> + ``` + }]; + + let parameters = (ins "Type":$pointee); + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$pointee), [{ + return $_get(pointee.getContext(), pointee); + }]> + ]; + let assemblyFormat = "`<` qualified($pointee) `>`"; +} + #endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 3b154ccf136b3..776285d842db9 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -579,34 +579,9 @@ void emitc::OpaqueAttr::print(AsmPrinter &printer) const { #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" //===----------------------------------------------------------------------===// -// OpaqueType +// ArrayType //===----------------------------------------------------------------------===// -Type emitc::OpaqueType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - std::string value; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseOptionalString(&value) || value.empty()) { - parser.emitError(loc) << "expected non empty string in !emitc.opaque type"; - return Type(); - } - if (value.back() == '*') { - parser.emitError(loc) << "pointer not allowed as outer type with " - "!emitc.opaque, use !emitc.ptr instead"; - return Type(); - } - if (parser.parseGreater()) - return Type(); - return get(parser.getContext(), value); -} - -void emitc::OpaqueType::print(AsmPrinter &printer) const { - printer << "<\""; - llvm::printEscapedString(getValue(), printer.getStream()); - printer << "\">"; -} - Type emitc::ArrayType::parse(AsmParser &parser) { if (parser.parseLess()) return Type(); @@ -665,3 +640,32 @@ emitc::ArrayType::cloneWith(std::optional> shape, return emitc::ArrayType::get(getShape(), elementType); return emitc::ArrayType::get(*shape, elementType); } + +//===----------------------------------------------------------------------===// +// OpaqueType +//===----------------------------------------------------------------------===// + +Type emitc::OpaqueType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + std::string value; + SMLoc loc = parser.getCurrentLocation(); + if (parser.parseOptionalString(&value) || value.empty()) { + parser.emitError(loc) << "expected non empty string in !emitc.opaque type"; + return Type(); + } + if (value.back() == '*') { + parser.emitError(loc) << "pointer not allowed as outer type with " + "!emitc.opaque, use !emitc.ptr instead"; + return Type(); + } + if (parser.parseGreater()) + return Type(); + return get(parser.getContext(), value); +} + +void emitc::OpaqueType::print(AsmPrinter &printer) const { + printer << "<\""; + llvm::printEscapedString(getValue(), printer.getStream()); + printer << "\">"; +}