diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 7874aa2c9e304..8dfda3be99d5f 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -16,16 +16,64 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Dialect/EmitC/IR/EmitCBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" //===----------------------------------------------------------------------===// // EmitC type definitions //===----------------------------------------------------------------------===// -class EmitC_Type - : TypeDef { +class EmitC_Type traits = []> + : TypeDef { let mnemonic = typeMnemonic; } +def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> { + let summary = "EmitC array type"; + + let description = [{ + An array data type. + + Example: + + ```mlir + // Array emitted as `int32_t[10]` + !emitc.array<10xi32> + // Array emitted as `float[10][20]` + !emitc.ptr<10x20xf32> + ``` + }]; + + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "Type":$elementType + ); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "ArrayRef":$shape, + "Type":$elementType + ), [{ + return $_get(elementType.getContext(), shape, elementType); + }]> + ]; + let extraClassDeclaration = [{ + /// Returns if this type is ranked (always true). + bool hasRank() const { return true; } + + /// Clone this vector type with the given shape and element type. If the + /// provided shape is `std::nullopt`, the current shape of the type is used. + ArrayType cloneWith(std::optional> shape, + Type elementType) const; + + static bool isValidElementType(Type type) { + return type.isIntOrIndexOrFloat() || + llvm::isa(type); + } + }]; + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; +} + def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { let summary = "EmitC opaque type"; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e8ea4da0b089c..776285d842db9 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -578,6 +578,69 @@ void emitc::OpaqueAttr::print(AsmPrinter &printer) const { #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" +//===----------------------------------------------------------------------===// +// ArrayType +//===----------------------------------------------------------------------===// + +Type emitc::ArrayType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false, + /*withTrailingX=*/true)) + return Type(); + // Parse the element type. + auto typeLoc = parser.getCurrentLocation(); + Type elementType; + if (parser.parseType(elementType)) + return Type(); + + // Check that memref is formed from allowed types. + if (!isValidElementType(elementType)) + return parser.emitError(typeLoc, "invalid array element type"), Type(); + if (parser.parseGreater()) + return Type(); + return parser.getChecked(dimensions, elementType); +} + +void emitc::ArrayType::print(AsmPrinter &printer) const { + printer << "<"; + for (int64_t dim : getShape()) { + printer << dim << 'x'; + } + printer.printType(getElementType()); + printer << ">"; +} + +LogicalResult emitc::ArrayType::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + ::llvm::ArrayRef shape, Type elementType) { + if (shape.empty()) + return emitError() << "shape must not be empty"; + + for (auto d : shape) { + if (d <= 0) + return emitError() << "dimensions must have positive size"; + } + + if (!elementType) + return emitError() << "element type must not be none"; + + if (!isValidElementType(elementType)) + return emitError() << "invalid array element type"; + + return success(); +} + +emitc::ArrayType +emitc::ArrayType::cloneWith(std::optional> shape, + Type elementType) const { + if (!shape) + return emitc::ArrayType::get(getShape(), elementType); + return emitc::ArrayType::get(*shape, elementType); +} + //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index 54e3775ddb8ed..4c526aa93dffb 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -11,3 +11,57 @@ func.func @illegal_opaque_type_2() { // expected-error @+1 {{pointer not allowed as outer type with !emitc.opaque, use !emitc.ptr instead}} %1 = "emitc.variable"(){value = "nullptr" : !emitc.opaque<"int32_t*">} : () -> !emitc.opaque<"int32_t*"> } + +// ----- + +func.func @illegal_array_missing_spec( + // expected-error @+1 {{expected non-function type}} + %arg0: !emitc.array<>) { +} + +// ----- + +func.func @illegal_array_missing_shape( + // expected-error @+1 {{shape must not be empty}} + %arg9: !emitc.array) { +} + +// ----- + +func.func @illegal_array_missing_x( + // expected-error @+1 {{expected 'x' in dimension list}} + %arg0: !emitc.array<10> +) { +} + +// ----- + +func.func @illegal_array_non_positive_dimenson( + // expected-error @+1 {{dimensions must have positive size}} + %arg0: !emitc.array<0xi32> +) { +} + +// ----- + +func.func @illegal_array_missing_type( + // expected-error @+1 {{expected non-function type}} + %arg0: !emitc.array<10x> +) { +} + +// ----- + +func.func @illegal_array_dynamic_shape( + // expected-error @+1 {{expected static shape}} + %arg0: !emitc.array<10x?xi32> +) { +} + +// ----- + +func.func @illegal_array_unranked( + // expected-error @+1 {{expected non-function type}} + %arg0: !emitc.array<*xi32> +) { +} diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir index 26d6f43a5824e..8477b0ed05977 100644 --- a/mlir/test/Dialect/EmitC/types.mlir +++ b/mlir/test/Dialect/EmitC/types.mlir @@ -39,3 +39,17 @@ func.func @pointer_types() { return } + +// CHECK-LABEL: func @array_types( +func.func @array_types( + // CHECK-SAME: !emitc.array<1xf32>, + %arg0: !emitc.array<1xf32>, + // CHECK-SAME: !emitc.array<10x20x30xi32>, + %arg1: !emitc.array<10x20x30xi32>, + // CHECK-SAME: !emitc.array<30x!emitc.ptr>, + %arg2: !emitc.array<30x!emitc.ptr>, + // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">> + %arg3: !emitc.array<30x!emitc.opaque<"int">> +) { + return +}