From d319b8ce11de26bfd65c2728170e720b70c10d20 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 16 May 2023 15:17:30 +0200 Subject: [PATCH 01/10] [mlir][tosa] Fix constant folding of tosa.mul --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++-- mlir/test/Dialect/Tosa/canonicalize.mlir | 13 +++++++++++++ mlir/test/Dialect/Tosa/invalid.mlir | 8 ++++++++ 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 015be1f8f4885..0ac000e2bd978 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1835,6 +1835,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure, // Operator: const //===----------------------------------------------------------------------===// def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure, + AllShapesMatch<["value", "output"]>, FirstAttrDerivedResultType]> { let summary = "Constant op."; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 7e161b1298211..97fb220b53290 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -647,13 +647,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { const int64_t shift = resultETy.isa() ? getShift() : 0; if (rhsTy == resultTy) { if (isSplatZero(resultETy, lhsAttr)) - return lhsAttr; + return lhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, lhsAttr, shift)) return rhs; } if (lhsTy == resultTy) { if (isSplatZero(resultETy, rhsAttr)) - return rhsAttr; + return rhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, rhsAttr, shift)) return lhs; } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 9633a9405514b..4ad6ce65655ed 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -203,6 +203,19 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> { return %1 : tensor<2x3xi32> } +// CHECK-LABEL: @mul_zero_broadcast +func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> + // CHECK-NOT: tosa.mul + %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %1 = "tosa.mul"(%arg0, %zeros) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32> + + // CHECK-NOT: tosa.mul + // CHECK: return %[[ZERO]], %[[ZERO]] + %2 = "tosa.mul"(%zeros, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32> +} + // CHECK-LABEL: @select_same_value func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index edb4bb0a873ec..e285a9de1d66d 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -143,3 +143,11 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () { %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32> return } + +// ----- + +func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { + // expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}} + %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32> + return %0 : tensor<100x100xf32> +} From 07d8cd0edadce74d7f3c75e0f052d5dbf9fd2d15 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Thu, 1 Jun 2023 07:59:25 +0000 Subject: [PATCH 02/10] Support lowering tosa.custom_op to another dialect operation. --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 3f970befa38dc..a7750a7f7518c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -506,6 +506,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } } + // tosa::CustomOp + if (auto customOp = dyn_cast(op)) { + return llvm::StringSwitch(customOp.getIdentifierAttr().str()) + .Case("atan2", rewriter.create(loc, resultTypes, args)) + .Default(nullptr); + } + (void)rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; @@ -2067,6 +2074,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns( PointwiseConverter, PointwiseConverter, PointwiseConverter, + PointwiseConverter, IdentityNConverter, ReduceConverter, ReduceConverter, From 88b3950278d233623757d8a5ca6dd31772a4742c Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Thu, 1 Jun 2023 11:26:48 +0000 Subject: [PATCH 03/10] Adds lit_tests for tosa.custom_op lowering to LinAlg. --- .../TosaToLinalg/tosa-to-linalg.mlir | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 3e654ab9c56b0..b94867a9f7e51 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1410,3 +1410,25 @@ func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, % return %0 : tensor<1x12x5x5xf32> } +// ----- + +// CHECK-LABEL: @test_custom_ops +func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () { + // CHECK: linalg.generic + // CHECK: math.atan2 + %2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + return +} + + +// ----- + +// CHECK-LABEL: @test_custom_ops_dyn +func.func @test_custom_ops_dyn(%arg0: tensor, %arg1: tensor) -> () { + // CHECK: linalg.generic + // CHECK: math.atan2 + %2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor, tensor) -> tensor + + return +} \ No newline at end of file From 4880bfccb767e0e8ffc5cab29d72f792d126bf6d Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Tue, 30 May 2023 10:43:24 -0700 Subject: [PATCH 04/10] Lowering for 'tosa.scatter' This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering: ``` func.func @tosa( %valuesIn : tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input : tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { %0 = "tosa.scatter"(%valuesIn, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>) return %0 : tensor<3x7x5xi32> } ``` translates to func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { %c0 = arith.constant 0 : index %c3 = arith.constant 3 : index %c1 = arith.constant 1 : index %c6 = arith.constant 6 : index %c2 = arith.constant 2 : index %c5 = arith.constant 5 : index %c0_0 = arith.constant 0 : index %c1_1 = arith.constant 1 : index %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) { %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) { %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32> %2 = arith.index_cast %extracted : i32 to index %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor into tensor<3x7x5xi32> scf.yield %inserted_slice : tensor<3x7x5xi32> } scf.yield %1 : tensor<3x7x5xi32> } return %0 : tensor<3x7x5xi32> } ``` We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons: - The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon). - The `tosa.scatter` and `tensor.scatter` op have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified. Reviewed By: eric-k256 Differential Revision: https://reviews.llvm.org/D151117 --- mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 73 ++++++++++++++++++- .../Conversion/TosaToSCF/TosaToSCFPass.cpp | 2 +- .../Conversion/TosaToSCF/tosa-to-scf.mlir | 30 ++++++++ 3 files changed, 102 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp index 8f10497d99c32..9139bf191fdf1 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp @@ -82,6 +82,75 @@ class IfOpConverter : public OpRewritePattern { } }; +class ScatterOpConverter : public OpRewritePattern { + static Value createTensorDim(OpBuilder &builder, Location loc, Value tensor, + int64_t dim) { + return builder.createOrFold(loc, tensor, dim); + } + + static Value createIndexConst(OpBuilder &builder, Location loc, + int64_t value) { + return builder.create(loc, value); + } + +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ScatterOp scatter, + PatternRewriter &rewriter) const final { + auto valuesIn = scatter.getValuesIn(); + auto indices = scatter.getIndices(); + auto input = scatter.getInput(); + auto loc = scatter.getLoc(); + + // N, W, C are chosen to match the TOSA spec + auto dimN = createTensorDim(rewriter, loc, input, 0); + auto dimW = createTensorDim(rewriter, loc, input, 1); + auto dimC = createTensorDim(rewriter, loc, input, 2); + + auto zero = createIndexConst(rewriter, loc, 0); + auto one = createIndexConst(rewriter, loc, 1); + + // Loop bounds + auto lbs = llvm::SmallVector(2, zero); + auto steps = llvm::SmallVector(2, one); + auto ubs = llvm::SmallVector{{dimN, dimW}}; + + auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange args) -> scf::ValueVector { + auto n = ivs[0]; + + // Read the index and cast it to index type + auto index = builder.create(loc, indices, ivs); + auto castIndex = builder.create( + loc, builder.getIndexType(), index); + + // Offset, sizes, and strides for the input tensor + auto inputOffset = llvm::to_vector(ivs); + inputOffset.push_back(zero); + + llvm::SmallVector sizes = {one, one, dimC}; + llvm::SmallVector strides = {one, one, one}; + + auto slice = builder.create( + loc, input, inputOffset, sizes, strides); + + // Insert the slice into the output accumulator tensor. + llvm::SmallVector outputOffset = {n, castIndex, zero}; + auto updated = builder.create( + loc, slice, args[0], outputOffset, sizes, strides); + + return {updated}; + }; + + auto loops = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, + ValueRange{valuesIn}, buildBody); + rewriter.replaceOp(scatter, loops.results); + + return success(); + } +}; + class WhileOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -106,6 +175,6 @@ class WhileOpConverter : public OpRewritePattern { void mlir::tosa::populateTosaToSCFConversionPatterns( RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); - patterns->add(patterns->getContext()); + patterns->add( + patterns->getContext()); } diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp index 759b730556d7a..d14535029132f 100644 --- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp +++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp @@ -37,7 +37,7 @@ struct TosaToSCF : public impl::TosaToSCFBase { RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); auto *op = getOperation(); diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir index 59931137cdf5b..4f0e29539b6e4 100644 --- a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir +++ b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir @@ -56,3 +56,33 @@ func.func @if_test(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) return %0 : tensor } + +// ----- + +// CHECK-LABEL: func @scatter_test +// CHECK-SAME: ([[VALUES_IN:%.+]]: tensor<3x7x5xi32>, [[INDICES:%.+]]: tensor<3x6xi32>, [[INPUT:%.+]]: tensor<3x6x5xi32>) +func.func @scatter_test(%values_in: tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { + + // CHECK-DAG: [[C_0:%.+]] = arith.constant 0 : index + // CHECK-DAG: [[C_1:%.+]] = arith.constant 1 : index + // CHECK-DAG: [[C_2:%.+]] = arith.constant 2 : index + // CHECK-DAG: [[C_3:%.+]] = arith.constant 3 : index + // CHECK-DAG: [[C_5:%.+]] = arith.constant 5 : index + // CHECK-DAG: [[C_6:%.+]] = arith.constant 6 : index + // CHECK-DAG: [[C_0_0:%.+]] = arith.constant 0 : index + // CHECK-DAG: [[C_1_0:%.+]] = arith.constant 1 : index + // CHECK: [[RESULT_0:%.+]] = scf.for [[ITER_VAR_0:%.+]] = [[C_0_0]] to [[C_3]] step [[C_1_0]] iter_args([[ITER_ARG_0:%.+]] = [[VALUES_IN]]) -> (tensor<3x7x5xi32>) { + // CHECK: [[RESULT_1:%.+]] = scf.for [[ITER_VAR_1:%.+]] = [[C_0_0]] to [[C_6]] step [[C_1_0]] iter_args([[ITER_ARG_1:%.+]] = [[ITER_ARG_0]]) -> (tensor<3x7x5xi32>) { + // CHECK-DAG: [[EXTRACTED:%.+]] = tensor.extract [[INDICES]][[[ITER_VAR_0]], [[ITER_VAR_1]]] : tensor<3x6xi32> + // CHECK-DAG: [[EXTRACTED_CAST:%.+]] = arith.index_cast [[EXTRACTED]] : i32 to index + // CHECK-DAG: [[EXTRACTED_SLICE:%.+]] = tensor.extract_slice [[INPUT]][[[ITER_VAR_0]], [[ITER_VAR_1]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor<3x6x5xi32> to tensor + // CHECK-DAG: [[INSERTED_SLICE:%.+]] = tensor.insert_slice [[EXTRACTED_SLICE]] into [[ITER_ARG_1]][[[ITER_VAR_0]], [[EXTRACTED_CAST]], [[C_0_0]]] [[[C_1_0]], [[C_1_0]], [[C_5]]] [[[C_1_0]], [[C_1_0]], [[C_1_0]]] : tensor into tensor<3x7x5xi32> + // CHECK: scf.yield [[INSERTED_SLICE]] : tensor<3x7x5xi32> + // CHECK: } + // CHECK: scf.yield [[RESULT_1]] : tensor<3x7x5xi32> + // CHECK: } + %0 = "tosa.scatter"(%values_in, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>) + + // CHECK: return [[RESULT_0]] : tensor<3x7x5xi32> + return %0 : tensor<3x7x5xi32> +} From 5de799cb7632a4aa84440ffeb69284cfd713e55b Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Thu, 1 Jun 2023 14:23:00 +0000 Subject: [PATCH 05/10] Generic support for legalizing tosa.custom_op into another dialect operation. --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 11 ++++++++--- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 10 ++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index a7750a7f7518c..105ee086db723 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -508,9 +508,14 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, // tosa::CustomOp if (auto customOp = dyn_cast(op)) { - return llvm::StringSwitch(customOp.getIdentifierAttr().str()) - .Case("atan2", rewriter.create(loc, resultTypes, args)) - .Default(nullptr); + // Only legalize tosa.custom_op's that are marked as implementable with + // 'linalg.generic' by looking at the 'implementation_attrs' attribute + auto implementationAttr = customOp.getImplementationAttrs(); + if (implementationAttr == "linalg.generic") { + OperationState state(loc, customOp.getIdentifierAttr(), args, + resultTypes); + return rewriter.create(state)->getResult(0); + } } (void)rewriter.notifyMatchFailure( diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index b94867a9f7e51..6483e29e7a9c2 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1414,9 +1414,12 @@ func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, % // CHECK-LABEL: @test_custom_ops func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () { + // CHECK: linalg.generic + // CHECK: math.sin // CHECK: linalg.generic // CHECK: math.atan2 - %2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.sin", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>) -> tensor<1xf32> + %3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> return } @@ -1426,9 +1429,12 @@ func.func @test_custom_ops(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> () { // CHECK-LABEL: @test_custom_ops_dyn func.func @test_custom_ops_dyn(%arg0: tensor, %arg1: tensor) -> () { + // CHECK: linalg.generic + // CHECK: math.cos // CHECK: linalg.generic // CHECK: math.atan2 - %2 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "atan2", implementation_attrs = "UNDEF"}> : (tensor, tensor) -> tensor + %2 = "tosa.custom"(%arg0) <{config = "UNDEF", identifier = "math.cos", implementation_attrs = "linalg.generic"}> : (tensor) -> tensor + %3 = "tosa.custom"(%arg0, %arg1) <{config = "UNDEF", identifier = "math.atan2", implementation_attrs = "linalg.generic"}> : (tensor, tensor) -> tensor return } \ No newline at end of file From 9b67e540d06dfd6aaec3c1fa64673733c8aedf72 Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Mon, 12 Jun 2023 15:00:55 +0200 Subject: [PATCH 06/10] TOSA: Fold concat where one argument has zero elements (#41) --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 ++ .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 24 +++++++++++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 7 ++++++ 3 files changed, 33 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0ac000e2bd978..85c7f05b83d95 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1460,6 +1460,8 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [ Tosa_Tensor:$output ); + let hasFolder = 1; + let hasCanonicalizer = 1; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 97fb220b53290..8fdd6cf35500f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -494,6 +494,30 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, // Operator Folders. //===----------------------------------------------------------------------===// +static bool hasZeroSize(Type ty) { + auto ranked = dyn_cast(ty); + if (!ranked) + return false; + return any_of(ranked.getShape(), [](auto d) { return d == 0; }); +} + +OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { + /// Remove operands that have zero elements. + bool changed = false; + for (size_t i = 0; i < getInput1().size(); ) { + auto input = getInput1()[i]; + if (hasZeroSize(input.getType())) { + getInput1Mutable().erase(i); + changed = true; + } else { + ++i; + } + } + if (changed) + return getResult(); + return {}; +} + template DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy) { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 4ad6ce65655ed..e049bc164150c 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -86,6 +86,13 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { return %1 : tensor<4xi8> } +// CHECK-LABEL: @concat_fold_zero +func.func @concat_fold_zero(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}> + %0 = "tosa.concat"(%arg0, %arg1, %arg2) {axis = 1 : i64}: (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + // CHECK-LABEL: @concat_fold func.func @concat_fold(%arg0: tensor) -> tensor { // CHECK: return %arg0 From 0749db1c9722ae0662a6c9b00b9f142c5f01a505 Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Mon, 12 Jun 2023 15:01:18 +0200 Subject: [PATCH 07/10] Some tosa verifiers (#42) * TOSA: Extend verifier to reshape of to check newShape attr * TOSA: add initial verifier for SliceOp --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 52 ++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 85c7f05b83d95..1ff274185acd1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1587,6 +1587,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 5c17d281c2ec7..f68d92a672d91 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -742,6 +742,12 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { ShapedType outputType = getType().cast(); if (inputType.hasStaticShape() && outputType.hasStaticShape()) { + if (getNewShape() != outputType.getShape()) { + return emitOpError() << "newShape attribute " << getNewShape() + << " does not match output type " + << outputType.getShape(); + } + int64_t inputElementsNum = inputType.getNumElements(); int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { @@ -749,6 +755,52 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { << " elements into " << outputElementsNum; } } + + return mlir::success(); +} + +mlir::LogicalResult tosa::SliceOp::verify() { + // TODO: Complete verification + ShapedType inputType = getInput().getType().cast(); + ShapedType outputType = getType().cast(); + + if (inputType.getRank() != outputType.getRank()) { + return emitOpError() << "rank of input (" << inputType.getRank() + << ") and output (" + << outputType.getRank() + << ") must match"; + } + + if (getSize() != outputType.getShape()) { + return emitOpError() << "size attribute " << getSize() + << " does not match output type " + << outputType.getShape(); + } + + if ((int64_t)getStart().size() != inputType.getRank()) { + return emitOpError() << "rank of start (" << getStart().size() + << ") and input (" + << inputType.getRank() + << ") must match"; + } + if ((int64_t)getSize().size() != inputType.getRank()) { + return emitOpError() << "rank of size (" << getSize().size() + << ") and input (" + << inputType.getRank() + << ") must match"; + } + + for (int i = 0; i < outputType.getRank(); ++i) { + auto dimSize = inputType.getShape()[i]; + if (dimSize != ShapedType::kDynamic && getStart()[i] + getSize()[i] > inputType.getShape()[i]) { + return emitOpError() << "start (" << getStart()[i] + << ") plus size (" + << getSize()[i] + << ") goes out of bounds of input size (" + << inputType.getShape()[i] + << ") in dimension " << i; + } + } return mlir::success(); } From 20fa0e82fc5a56cb233a8abcfd4d5108ab4858fa Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Mon, 12 Jun 2023 15:36:31 +0200 Subject: [PATCH 08/10] TorchToLinAlg: fix tosa.clamp legalization for integer types. (#43) --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 105ee086db723..342634364cc10 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -381,23 +381,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, if (isa(op) && elementTy.isa()) { auto intTy = elementTy.cast(); - int32_t min = static_cast( - op->getAttr("min_int").cast().getValue().getSExtValue()); - int32_t max = static_cast( - op->getAttr("max_int").cast().getValue().getSExtValue()); + int64_t min = + op->getAttr("min_int").cast().getValue().getSExtValue(); + int64_t max = + op->getAttr("max_int").cast().getValue().getSExtValue(); if (intTy.isUnsignedInteger()) { - min = std::max(min, 0); - max = std::min( + min = std::max(min, (int64_t)0); + max = std::min( max, APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue()); } else { - min = std::max( - min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) - .getSExtValue()); - max = std::min( - max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) - .getSExtValue()); + min = + std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) + .getSExtValue()); + max = + std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + .getSExtValue()); } auto minVal = rewriter.create( From c2367d64b775ac1e38bdccadb9380e3e88631826 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Tue, 13 Jun 2023 18:20:38 +0200 Subject: [PATCH 09/10] TosaToLinAlg: fix tosa.cast legalization of FP->Int for non FP32 types. (#45) --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 21 +++++++++++++++++-- .../TosaToLinalg/tosa-to-linalg.mlir | 11 ++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 342634364cc10..2e280dba469c9 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -471,16 +471,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto intMin = rewriter.create( + Value intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto intMax = rewriter.create( + Value intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); + // Since F32 constants are created, we may still need to convert them to + // the correct type. + auto convertType = [&](Type ty, Value arg) { + auto argTy = arg.getType(); + bool bitExtend = + argTy.getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth(); + if (ty != argTy) { + if (!bitExtend) + arg = rewriter.create(loc, ty, arg); + else + arg = rewriter.create(loc, ty, arg); + } + return arg; + }; + intMin = convertType(srcTy, intMin); + intMax = convertType(srcTy, intMax); + auto rounded = rewriter.create(loc, args[0]); auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 6483e29e7a9c2..70d09cde7bc7f 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -270,6 +270,17 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () { // CHECK: arith.extf %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: %[[C_LOWEST:.+]] = arith.constant -2.14748365E+9 + // CHECK: %[[C_MAX:.+]] = arith.constant 2.14748365E+9 + // CHECK: arith.truncf %[[C_LOWEST]] : f32 to f16 + // CHECK: arith.truncf %[[C_MAX]] : f32 to f16 + // CHECK: math.roundeven + // CHECK: arith.minf + // CHECK: arith.maxf + // CHECK: arith.fptosi + %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32> + return } From bc0e73a7d3c8a72598ff61f59726614474cac10c Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 14 Jun 2023 18:03:55 +0200 Subject: [PATCH 10/10] TOSA: Allow to transpose 7D tensors and higher --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 38766d8361167..2d04fb169deae 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1595,12 +1595,12 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [ }]; let arguments = (ins - Tosa_Tensor1Dto6D:$input1, + Tosa_Tensor:$input1, Tosa_Int32Or64Tensor:$perms ); let results = ( - outs Tosa_Tensor1Dto6D:$output + outs Tosa_Tensor:$output ); let extraClassDeclaration = [{