diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt index 1ee105f0ceb98..cc8d5ed9b0044 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -3,6 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc) add_mlir_interface(TosaInterfaces) set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa) +mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa) mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa) mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa) add_public_tablegen_target(MLIRTosaAttributesIncGen) @@ -10,4 +12,3 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen) set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td) mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa") add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen) - diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index d3f12c34421b0..47cda3c9f481e 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect { let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } //===----------------------------------------------------------------------===// @@ -217,12 +218,21 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> { let cppNamespace = "mlir::OpTrait::tosa"; } +//===----------------------------------------------------------------------===// +// TOSA Operator Trait. +//===----------------------------------------------------------------------===// +// Op operands with TOSA shape types must be compile time resolvable +def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + //===----------------------------------------------------------------------===// // TOSA Operator Class. //===----------------------------------------------------------------------===// class Tosa_Op traits = []> : - Op { + Op { } class Tosa_ElementwiseOp traits = []> : diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 66512cbe350ec..e4f5d09064cd7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -90,14 +90,55 @@ template class TosaElementwiseOperator : public TraitBase {}; +LogicalResult verifyTosaResolvableShapeOperands(Operation *op); +/// This class verifies that tosa shape operands are compile time resolvable +template +class TosaResolvableShapeOperands + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return verifyTosaResolvableShapeOperands(op); + } +}; + +LogicalResult verifyTosaShapeOperator(Operation *op); +/// This class indicates that op operates on tosa shape types +template +class TosaShapeOperator : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return verifyTosaShapeOperator(op); + } +}; + +LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op); +/// This class indicates that op operates on tosa shape types +template +class TosaShapeOperatorWithSameRanks + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return verifyTosaShapeOperatorWithSameRanks(op); + } +}; + } // namespace tosa } // namespace OpTrait +namespace tosa { + +bool isa_tosa_shape_type(mlir::Type t); + +} // namespace tosa + } // namespace mlir #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 5e91f6e13e145..23cb3685795d4 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1713,12 +1713,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { let arguments = (ins Tosa_Tensor:$input1, - DenseI64ArrayAttr:$multiples); + Tosa_Shape:$multiples); let results = (outs Tosa_Tensor:$output ); + let extraClassDeclaration = [{ + LogicalResult getConstantMultiples(llvm::SmallVector &multiples); + }]; + let hasFolder = 1; let hasVerifier = 1; } @@ -2130,4 +2134,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ include "mlir/Dialect/Tosa/IR/TosaUtilOps.td" +include "mlir/Dialect/Tosa/IR/TosaShapeOps.td" + #endif // TOSA_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td new file mode 100644 index 0000000000000..597dc32e84402 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td @@ -0,0 +1,77 @@ +//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines shape operators for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_SHAPE_OPS +#define TOSA_SHAPE_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" + +// Op trait: operator has operands and results with TOSA shape type +def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + +class Tosa_ShapeOp traits = []> + : Tosa_Op { + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; + + let hasFolder = 1; +} + +// op trait: shape operator has same ranks for operands and results +def TosaShapeOperatorWithSameRanks + : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + +class Tosa_ElementwiseShapeOp traits = []> + : Tosa_ShapeOp { +} + + +//===----------------------------------------------------------------------===// +// Operator: ConstShape +//===----------------------------------------------------------------------===// +def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> { + let summary = "Constant Shape op."; + + let description = [{ + A node containing constant data for use as the input to an shape operation. May + hold data only in index data type. + + Example: + + ```mlir + // Generic form + %out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + ``` + }]; + + let arguments = (ins IndexElementsAttr : $value); + + let results = (outs Tosa_Shape : $output); + + let hasVerifier = 1; +} + +#endif // TOSA_SHAPE_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 8a56d3212b2af..5ca7720508d54 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -13,8 +13,11 @@ #ifndef TOSA_TYPES_BASE #define TOSA_TYPES_BASE +include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" + //===----------------------------------------------------------------------===// // Tosa Type Definitions. //===----------------------------------------------------------------------===// @@ -218,4 +221,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class for Tosa dialect types. +class Tosa_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// ShapeType +//===----------------------------------------------------------------------===// +def Tosa_Shape : Tosa_Type<"shape", "shape"> { + let summary = "Shape with static rank and Index element type"; + let description = [{ + Syntax: + + ``` shape - type :: = `shape` `<` rank `>` + ``` Values with shape type represents a shape with a fixed rank and a list + of dimensions + .Rank must be zero or a positive integer + .Each dimension is represented by the builtin + Index type. + + Examples: + + ```mlir + // Shape with rank of four, for example, [1, 1, 8, 16]: + !tosa + .shape<4> + + // Shape with rank of one, for example, [16]: + !tosa + .shape<1> + + // Shape with rank zero, for example, [] (i.e., shape of scalar values): + !tosa.shape<0> + ``` + }]; + let parameters = (ins "int" : $rank); + let builders = [TypeBuilder<(ins "int" : $rank)>]; + let assemblyFormat = "`<` $rank `>`"; + + let genVerifyDecl = 1; +} + +def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">; + +// Whether a Tosa Shape type has a rank equal to the specified rank. +class IsTosaShapeOfRankPred : And<[ + IsTosaShapeType, + CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank> +]>; + +class TosaShapeOfRank + : Type, "Tosa shape type of rank " #rank>; + +def Rank1TosaShape : TosaShapeOfRank<1>; +def Rank2TosaShape : TosaShapeOfRank<2>; +def Rank4TosaShape : TosaShapeOfRank<4>; + #endif // TOSA_TYPES_BASE diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 85aa5d2dc381b..12fa94b9e62fb 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2033,7 +2033,9 @@ struct TileConverter : public OpConversionPattern { auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); - ArrayRef multiples = operands.getMultiples(); + SmallVector multiples; + if (failed(op.getConstantMultiples(multiples))) + return failure(); // Broadcast the newly added dimensions to their appropriate multiple. SmallVector genericShape; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index fc7a0a5a7175b..78b3d61a80423 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -56,6 +56,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase { target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 9ee416f774a19..7d21317d5e235 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1137,6 +1137,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } +OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } + #define REDUCE_FOLDER(OP) \ OpFoldResult OP::fold(FoldAdaptor adaptor) { \ ShapedType inputTy = llvm::cast(getInput().getType()); \ @@ -1318,9 +1320,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { } OpFoldResult TileOp::fold(FoldAdaptor adaptor) { - bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; }); - if (allOnes && getInput1().getType() == getType()) - return getInput1(); + if (getInput1().getType() == getType()) { + if (auto multiples = llvm::dyn_cast_if_present( + adaptor.getMultiples())) { + if (multiples.isSplat() && + multiples.getSplatValue().getSExtValue() == 1) + return getInput1(); + if (auto int_array_attr = + llvm::dyn_cast(multiples)) { + if (llvm::all_of(int_array_attr.getValues(), + [](APInt v) { return v.getSExtValue() == 1; })) + return getInput1(); + } + } + } return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b96b0eb16104b..f55f99aaef738 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -130,6 +130,10 @@ SmallVector tosa::WhileOp::getLoopRegions() { return {&getBody()}; } //===----------------------------------------------------------------------===// void TosaDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" @@ -153,6 +157,10 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { // Tosa dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. + if (llvm::isa(type) && llvm::isa(value)) { + return builder.create( + loc, type, llvm::cast(value)); + } if (llvm::isa(value)) return builder.create(loc, type, llvm::cast(value)); @@ -1024,11 +1032,30 @@ LogicalResult tosa::TableOp::verify() { return success(); } +LogicalResult +tosa::TileOp::getConstantMultiples(SmallVector &multiples) { + // Multiples must be constants. + DenseIntElementsAttr multiplesAttr; + if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) + return failure(); + multiples = llvm::to_vector( + llvm::map_range(multiplesAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + return success(); +} + LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { - ArrayRef multiples = adaptor.getMultiples(); + DenseIntElementsAttr multiplesAttr; + if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr))) + return failure(); + + SmallVector multiples = llvm::to_vector( + llvm::map_range(multiplesAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + ShapeAdaptor inputShape(adaptor.getInput1().getType()); SmallVector outputShape; if (!inputShape.hasRank()) { @@ -1054,20 +1081,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents( LogicalResult tosa::TileOp::verify() { ShapedType inputType = llvm::cast(getInput1().getType()); ShapedType outputType = llvm::cast(getType()); - auto multiples = getMultiples(); + + shapeType multiplesType = + llvm::cast(getMultiples().getType()); + + auto multiplesRank = multiplesType.getRank(); if (inputType.hasRank()) { - if (static_cast(inputType.getRank()) != multiples.size()) - return emitOpError("expect 'multiples' array to have length ") - << inputType.getRank() << " but got " << multiples.size() << "."; + if (inputType.getRank() != multiplesRank) + return emitOpError("expect 'multiples' to have rank ") + << inputType.getRank() << " but got " << multiplesRank << "."; if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError("expect same input and output tensor rank."); - } else if (outputType.hasRank() && - static_cast(outputType.getRank()) != multiples.size()) + } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' array to have length ") - << outputType.getRank() << " but got " << multiples.size() << "."; + << outputType.getRank() << " but got " << multiplesRank << "."; - if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) + SmallVector multiples; + if (getConstantMultiples(multiples).succeeded() && + llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) return emitOpError( "expect element of 'multiples' to be positive integer or -1."); @@ -2225,6 +2257,91 @@ void WhileOp::print(OpAsmPrinter &parser) { parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } +//===----------------------------------------------------------------------===// +// TOSA Shape and Shape Operators Helper functions. +//===----------------------------------------------------------------------===// + +bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) { + return mlir::isa(t); +} + +LogicalResult +mlir::tosa::shapeType::verify(function_ref emitError, + int rank) { + if (rank < 0) + return emitError() << "invalid rank (must be >= 0): " << rank; + return success(); +} + +LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) { + for (auto v : op->getOperands()) { + if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) { + Operation *definingOp = v.getDefiningOp(); + if (!definingOp || !definingOp->hasTrait()) { + return op->emitOpError("shape operand is not compile time resolvable"); + } + } + } + return success(); +} + +LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) { + for (auto type : op->getOperandTypes()) { + if (!mlir::isa(type)) { + return op->emitOpError("must have operands with tosa shape type"); + } + } + for (auto type : op->getResultTypes()) { + if (!mlir::isa(type)) { + return op->emitOpError("must have result with tosa shape type"); + } + } + return success(); +} + +LogicalResult +OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) { + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) || + failed(verifyTosaShapeOperator(op))) + return failure(); + + // delegate function that returns rank of shape type + auto getRank = [](const Type type) { + return mlir::cast(type).getRank(); + }; + auto operandTypes = op->getOperandTypes(); + auto resultTypes = op->getResultTypes(); + + auto rank = getRank(*op->getOperandTypes().begin()); + for (auto type : operandTypes) { + if (getRank(type) != rank) { + return op->emitOpError("operands don't have matching ranks"); + } + } + for (auto type : resultTypes) { + if (getRank(type) != rank) { + return op->emitOpError("result shape has different rank than operands"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TOSA Shape Operators verify functions. +//===----------------------------------------------------------------------===// + +LogicalResult tosa::ConstShapeOp::verify() { + // check that number of elements in value attr equal to rank of result shape + auto count = getValue().getNumElements(); + auto rank = (cast(getResult().getType())).getRank(); + if (!(count == rank || (count == 1 && rank == 0))) { + return emitOpError("expect number of elements in attribute value (") + << count << ") to be equal to the rank (" << rank + << ") for the result shape type"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TOSA Attribute Definitions. //===----------------------------------------------------------------------===// @@ -2232,6 +2349,12 @@ void WhileOp::print(OpAsmPrinter &parser) { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" +//===----------------------------------------------------------------------===// +// TOSA Type Definitions. +//===----------------------------------------------------------------------===// +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" + //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 8588c878bfe4f..a49870687fdc6 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -536,6 +536,8 @@ bool TosaValidation::isValidElementType(Type type) { return true; } } + } else if (mlir::isa(type)) { + return true; } return false; } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 070ade2b788b1..2c2fdcd7fffb1 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1467,21 +1467,24 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () { // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} - %0 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<4x3xi8> + %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8> // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} - %1 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<2x6xi8> + %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8> // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} - %2 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<10x21xi8> + %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2> + %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<10x21xi8> return } @@ -1501,7 +1504,8 @@ func.func @tile_dyn_input(%arg0 : tensor) -> () { // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array} - %0 = tosa.tile %arg0 {multiples = array} : (tensor) -> tensor + %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst21: (tensor, !tosa.shape<2>) -> tensor return } @@ -1521,7 +1525,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () { // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array} - %0 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<2x?xi8> + %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x?xi8> return } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 8e40ed39bd7dc..1fe85b6411a7a 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -693,7 +693,8 @@ func.func @slice_nofold(%arg0: tensor) -> tensor { // CHECK-LABEL: @tile_fold func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0 - %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x4xf32> + %cst = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -702,7 +703,8 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK-LABEL: @tile_nofold func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> { // CHECK: tosa.tile - %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x8xf32> + %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32> return %0 : tensor<3x8xf32> } @@ -915,7 +917,8 @@ func.func @fold_reduce_rank_zero() { func.func nested @fold_tile_rank_zero() -> tensor { // CHECK-NOT: tosa.tile %0 = tensor.empty() : tensor - %1 = tosa.tile %0 {multiples = array} : (tensor) -> tensor + %cst = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> + %1 = tosa.tile %0, %cst : (tensor, !tosa.shape<0>) -> tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Tosa/constant-tile.mlir b/mlir/test/Dialect/Tosa/constant-tile.mlir index 0b8d129935e39..a69e1d54009a7 100644 --- a/mlir/test/Dialect/Tosa/constant-tile.mlir +++ b/mlir/test/Dialect/Tosa/constant-tile.mlir @@ -6,7 +6,8 @@ func.func @tile_int_one_dim() -> (tensor<6xi32>) { // CHECK: "tosa.const"() <{value = dense // CHECK-SAME: {{\[}}1, 2, 3, 1, 2, 3] %0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = tosa.tile %0 {multiples = array} : (tensor<3xi32>) -> tensor<6xi32> + %cst = tosa.const_shape { value = dense<[2]> : tensor<1xindex> } : () -> !tosa.shape<1> + %1 = tosa.tile %0, %cst : (tensor<3xi32>, !tosa.shape<1>) -> tensor<6xi32> return %1 : tensor<6xi32> // NO-FOLDING-CHECK: tosa.tile } @@ -21,7 +22,8 @@ func.func @tile_bool() -> (tensor<1x3x2x3xi1>) { // CHECK-SAME: [true, true, true], // CHECK-SAME: [true, true, true]]]] %0 = "tosa.const"() {value = dense<[[[[true]], [[false]], [[true]]]]> : tensor<1x3x1x1xi1>} : () -> tensor<1x3x1x1xi1> - %1 = tosa.tile %0 {multiples = array} : (tensor<1x3x1x1xi1>) -> tensor<1x3x2x3xi1> + %cst = tosa.const_shape { value = dense<[1, 1, 2, 3]> : tensor<4xindex> } : () -> !tosa.shape<4> + %1 = tosa.tile %0, %cst : (tensor<1x3x1x1xi1>, !tosa.shape<4>) -> tensor<1x3x2x3xi1> return %1 : tensor<1x3x2x3xi1> // NO-FOLDING-CHECK: tosa.tile } @@ -36,7 +38,8 @@ func.func @tile_bf16() -> (tensor<1x3x2x2xbf16>) { // CHECK-SAME: [2.000000e+00, 4.000000e+00], // CHECK-SAME: [2.000000e+00, 4.000000e+00]]]] %0 = "tosa.const"() {value = dense<[[[[0.25, 0.125]], [[0.5, 1.0]], [[2.0, 4.0]]]]> : tensor<1x3x1x2xbf16>} : () -> tensor<1x3x1x2xbf16> - %1 = tosa.tile %0 {multiples = array} : (tensor<1x3x1x2xbf16>) -> tensor<1x3x2x2xbf16> + %cst = tosa.const_shape { value = dense<[1, 1, 2, 1]> : tensor<4xindex> } : () -> !tosa.shape<4> + %1 = tosa.tile %0, %cst : (tensor<1x3x1x2xbf16>, !tosa.shape<4>) -> tensor<1x3x2x2xbf16> return %1 : tensor<1x3x2x2xbf16> // NO-FOLDING-CHECK: tosa.tile } @@ -49,7 +52,8 @@ func.func @tile_f32() -> (tensor<4x4xf32>) { // CHECK-SAME: [2.500000e-01, 1.250000e+00, 2.500000e-01, 1.250000e+00], // CHECK-SAME: [2.250000e+00, 3.250000e+00, 2.250000e+00, 3.250000e+00]] %0 = "tosa.const"() {value = dense<[[0.25, 1.25],[2.25, 3.25]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> - %1 = tosa.tile %0 {multiples = array} : (tensor<2x2xf32>) -> tensor<4x4xf32> + %cst = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %1 = tosa.tile %0, %cst: (tensor<2x2xf32>, !tosa.shape<2>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> // NO-FOLDING-CHECK: tosa.tile } @@ -82,7 +86,8 @@ func.func @tile_int_many_dimensions() -> (tensor<4x6x4xi32>) { // CHECK-SAME: [9, 10, 9, 10], // CHECK-SAME: [11, 12, 11, 12]]] %0 = "tosa.const"() {value = dense<[[[1, 2],[3, 4],[5, 6]], [[7, 8],[9, 10],[11, 12]]]> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> - %1 = tosa.tile %0 {multiples = array} : (tensor<2x3x2xi32>) -> tensor<4x6x4xi32> + %cst = tosa.const_shape { value = dense<[2, 2, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + %1 = tosa.tile %0, %cst : (tensor<2x3x2xi32>, !tosa.shape<3>) -> tensor<4x6x4xi32> // NO-FOLDING-CHECK: tosa.tile return %1 : tensor<4x6x4xi32> } @@ -103,7 +108,8 @@ func.func @tile_f16_many_dimensions() -> (tensor<6x2x2xf16>) { // CHECK-SAME: {{\[\[}}3.000000e+00, 3.000000e+00], // CHECK-SAME: [3.000000e+00, 3.000000e+00]]] %0 = "tosa.const"() {value = dense<[[[1.0]], [[2.0]], [[3.0]]]> : tensor<3x1x1xf16>} : () -> tensor<3x1x1xf16> - %1 = tosa.tile %0 {multiples = array} : (tensor<3x1x1xf16>) -> tensor<6x2x2xf16> + %cst = tosa.const_shape { value = dense<[3, 2, 1]> : tensor<3xindex> } : () -> !tosa.shape<3> + %1 = tosa.tile %0, %cst : (tensor<3x1x1xf16>, !tosa.shape<3>) -> tensor<6x2x2xf16> // NO-FOLDING-CHECK: tosa.tile return %1 : tensor<6x2x2xf16> } @@ -112,7 +118,8 @@ func.func @tile_f16_many_dimensions() -> (tensor<6x2x2xf16>) { func.func @tile_i1_splat() -> (tensor<1x2x2x2xi1>) { // CHECK: "tosa.const"() <{value = dense : tensor<1x2x2x2xi1>}> %0 = "tosa.const"() <{value = dense : tensor<1x1x1x1xi1>}> : () -> tensor<1x1x1x1xi1> - %1 = tosa.tile %0 {multiples = array} : (tensor<1x1x1x1xi1>) -> tensor<1x2x2x2xi1> + %cst = tosa.const_shape { value = dense<[1, 2, 2, 2]> : tensor<4xindex> } : () -> !tosa.shape<4> + %1 = tosa.tile %0, %cst : (tensor<1x1x1x1xi1>, !tosa.shape<4>) -> tensor<1x2x2x2xi1> // NO-FOLDING-CHECK: tosa.tile return %1 : tensor<1x2x2x2xi1> } @@ -121,7 +128,8 @@ func.func @tile_i1_splat() -> (tensor<1x2x2x2xi1>) { func.func @tile_i32_splat() -> (tensor<1x2x2x2xi32>) { // CHECK: "tosa.const"() <{value = dense<2> : tensor<1x2x2x2xi32>}> %0 = "tosa.const"() <{value = dense<2> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> - %1 = tosa.tile %0 {multiples = array} : (tensor<1x1x1x1xi32>) -> tensor<1x2x2x2xi32> + %cst = tosa.const_shape { value = dense<[1, 2, 2, 2]> : tensor<4xindex> } : () -> !tosa.shape<4> + %1 = tosa.tile %0, %cst : (tensor<1x1x1x1xi32>, !tosa.shape<4>) -> tensor<1x2x2x2xi32> // NO-FOLDING-CHECK: tosa.tile return %1 : tensor<1x2x2x2xi32> } @@ -130,7 +138,8 @@ func.func @tile_i32_splat() -> (tensor<1x2x2x2xi32>) { func.func @tile_f16_splat() -> (tensor<1x2x2x2xf16>) { // CHECK: "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x2x2x2xf16>}> %0 = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xf16>}> : () -> tensor<1x1x1x1xf16> - %1 = tosa.tile %0 {multiples = array} : (tensor<1x1x1x1xf16>) -> tensor<1x2x2x2xf16> + %cst = tosa.const_shape { value = dense<[1, 2, 2, 2]> : tensor<4xindex> } : () -> !tosa.shape<4> + %1 = tosa.tile %0, %cst : (tensor<1x1x1x1xf16>, !tosa.shape<4>) -> tensor<1x2x2x2xf16> // NO-FOLDING-CHECK: tosa.tile return %1 : tensor<1x2x2x2xf16> } @@ -139,7 +148,8 @@ func.func @tile_f16_splat() -> (tensor<1x2x2x2xf16>) { func.func @tile_bf16_splat() -> (tensor<1x2x2x2xbf16>) { // CHECK: "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x2x2x2xbf16>}> %0 = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16> - %1 = tosa.tile %0 {multiples = array} : (tensor<1x1x1x1xbf16>) -> tensor<1x2x2x2xbf16> + %cst = tosa.const_shape { value = dense<[1, 2, 2, 2]> : tensor<4xindex> } : () -> !tosa.shape<4> + %1 = tosa.tile %0, %cst : (tensor<1x1x1x1xbf16>, !tosa.shape<4>) -> tensor<1x2x2x2xbf16> // NO-FOLDING-CHECK: tosa.tile return %1 : tensor<1x2x2x2xbf16> } \ No newline at end of file diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index cab30c8b89c35..6d547be857e24 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -669,8 +669,9 @@ func.func @test_slice_invalid_size() { func.func @test_tile_invalid_multiples() { %0 = tensor.empty() : tensor<4x31x31xf32> - // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}} - %1 = tosa.tile %0 {multiples = array} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32> + %cst = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1> + // expected-error@+1 {{'tosa.tile' op expect 'multiples' to have rank 3 but got 1.}} + %1 = tosa.tile %0, %cst: (tensor<4x31x31xf32>, !tosa.shape<1>) -> tensor<4x31x31xf32> return } @@ -678,8 +679,9 @@ func.func @test_tile_invalid_multiples() { func.func @test_tile_invalid_multiples_value() { %0 = tensor.empty() : tensor<4x31xf32> + %multiples = tosa.const_shape { value = dense<[2, -2]> : tensor<2xindex> } : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.tile' op expect element of 'multiples' to be positive integer or -1.}} - %1 = tosa.tile %0 {multiples = array} : (tensor<4x31xf32>) -> tensor<4x31xf32> + %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31xf32> return } @@ -687,8 +689,9 @@ func.func @test_tile_invalid_multiples_value() { func.func @test_tile_io_rank_mismatch() { %0 = tensor.empty() : tensor<4x31xf32> + %multiples = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.tile' op expect same input and output tensor rank.}} - %1 = tosa.tile %0 {multiples = array} : (tensor<4x31xf32>) -> tensor<4x31x31xf32> + %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31x31xf32> return } @@ -1041,3 +1044,25 @@ func.func @test_non_tosa_ops() { %2 = tensor.empty(%0) : tensor return } + +// ----- + +// expected-error@+1 {{invalid rank (must be >= 0): -1}} +func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> { + return %arg0 : !tosa.shape<-1> +} + +// ----- +func.func @test_const_shape() -> !tosa.shape<4> { + // expected-error@+1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}} + %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4> + return %cst : !tosa.shape<4> +} + +// ----- + +func.func @test_const_shape_value() -> !tosa.shape<5> { + // expected-error@+1 {{'tosa.const_shape' op expect number of elements in attribute value (4) to be equal to the rank (5) for the result shape type}} + %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5> + return %cst : !tosa.shape<5> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index ba8ed8a1e5f50..0fe35d88f0e73 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -95,8 +95,9 @@ func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11 // ----- // CHECK-LABEL: tile func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> { + %cst = tosa.const_shape { value = dense<[1, 1, 1, 1, 3, 1, 2]> : tensor<7xindex> } : () -> !tosa.shape<7> // expected-error@+1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = "tosa.tile"(%arg0) {multiples = array} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> + %0 = tosa.tile %arg0, %cst : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x39x21x6xf32> return %0 : tensor<1x1x1x1x39x21x6xf32> } @@ -740,4 +741,3 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) { (tensor<*xf32>) -> tensor<*xf32> return } - diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index f2e1cff72ab28..690e208af1e5f 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -562,7 +562,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { // ----- // CHECK-LABEL: tile func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> { - %0 = tosa.tile %arg0 {multiples = array} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32> + %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32> return %0 : tensor<39x21x6xf32> } @@ -692,3 +693,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> { %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>) return %0 : tensor<10xi32> } + +// ----- +// CHECK-LABEL: const_shape +func.func @test_const_shape() -> !tosa.shape<4> { + %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + return %cst : !tosa.shape<4> +} diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index b16e0c296d4a0..f253ec157c330 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -543,8 +543,10 @@ func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () { // CHECK-LABEL: @test_tile func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () { - // CHECK: tosa.tile %arg0 {multiples = array} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32> - %0 = tosa.tile %arg0 {multiples = array} : (tensor<2x3x?xi32>) -> tensor + // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x3x?xi32> + %cst = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> + %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor return } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 5c2a77ca67fd4..d3f3697903d72 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -12115,6 +12115,14 @@ gentbl_cc_library( ["-gen-dialect-defs"], "include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc", ), + ( + ["-gen-typedef-decls"], + "include/mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc", + ), + ( + ["-gen-typedef-defs"], + "include/mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc", + ), ( ["-gen-attrdef-decls"], "include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc",