From 3739a0f6b7ffc03232af2732aa73a42beffa272d Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Tue, 9 May 2023 14:21:36 +0000 Subject: [PATCH 1/4] feat(TosaCanonicalizations.cpp): update Add and Sub folding for tensors of different shapes but who are splat and zero they can be folded --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 9016e41ee1b7e..936339ae18504 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -474,7 +474,7 @@ OpFoldResult AddOp::fold(ArrayRef operands) { auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) + if (lhsTy.getRank() != rhsTy.getRank()) return {}; auto resultETy = resultTy.getElementType(); @@ -504,6 +504,9 @@ OpFoldResult AddOp::fold(ArrayRef operands) { if (!lhsAttr || !rhsAttr) return {}; + if (lhsTy != rhsTy) + return {}; + return binaryFolder, std::plus>(lhsAttr, rhsAttr, lhsTy); } @@ -635,7 +638,7 @@ OpFoldResult SubOp::fold(ArrayRef operands) { auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) + if (lhsTy.getRank() != rhsTy.getRank()) return {}; auto resultETy = resultTy.getElementType(); @@ -655,6 +658,9 @@ OpFoldResult SubOp::fold(ArrayRef operands) { if (!lhsAttr || !rhsAttr) return {}; + if (lhsTy != rhsTy) + return {}; + return binaryFolder, std::minus>(lhsAttr, rhsAttr, lhsTy); } From a68396e4384068c0b830e0f15a3ed0e0fd508639 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Tue, 9 May 2023 15:07:41 +0000 Subject: [PATCH 2/4] test(constant-op-fold.mlir): add test cases for add and sub the cases look at tensors of zero with different shapes but who are splat tensors. --- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 08115787db58a..c548e1f3fb368 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -164,6 +164,16 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_add_zero_splat_different_shape_f32 +func.func @fold_add_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> { + %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %sub = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32> + // CHECK: return %arg0 + return %sub : tensor<1x10xf32> +} + +// ----- + // CHECK-LABEL: @fold_div_zero_lhs_i32 func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0> : tensor} : () -> tensor @@ -350,6 +360,16 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_sub_zero_splat_different_shape_f32 +func.func @fold_sub_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> { + %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %sub = "tosa.sub"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32> + // CHECK: return %arg0 + return %sub : tensor<1x10xf32> +} + +// ----- + // CHECK-LABEL: @fold_greater_splat_f32 func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32> From bee883fc8803cd3f5c3aa5bd7895e25bc8301a97 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 11 May 2023 11:50:08 +0000 Subject: [PATCH 3/4] refactor(TosaCanonicalizations.cpp): remove rank check TOSA already enforces that the ranks are the same --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 936339ae18504..92240f59f9f58 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -474,8 +474,6 @@ OpFoldResult AddOp::fold(ArrayRef operands) { auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy.getRank() != rhsTy.getRank()) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = operands[0].dyn_cast_or_null(); @@ -638,8 +636,6 @@ OpFoldResult SubOp::fold(ArrayRef operands) { auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy.getRank() != rhsTy.getRank()) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = operands[0].dyn_cast_or_null(); From eb21102e628234cca6e5118db7a23390f0970ae2 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 11 May 2023 11:51:26 +0000 Subject: [PATCH 4/4] test(constant-op-fold.mlir): broadcast of the ar we cannot replace in this case as the add causes a broadcast and different shape to the output of the operation. Input is a 1x10 shape and the output 4x10. --- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index c548e1f3fb368..87f47cfc3b815 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -167,9 +167,21 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> { // CHECK-LABEL: @fold_add_zero_splat_different_shape_f32 func.func @fold_add_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> { %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> - %sub = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32> + %add = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32> // CHECK: return %arg0 - return %sub : tensor<1x10xf32> + return %add : tensor<1x10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_broadcast_arg_f32 +func.func @fold_add_zero_broadcast_arg_f32(%arg0: tensor<1x10xf32>) -> tensor<4x10xf32> { + %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<4x10xf32> + %add = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<4x10xf32> + // CHECK: %[[ADD:.+]] = "tosa.add"(%arg0, %[[ZERO]]) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK: return %[[ADD]] : tensor<4x10xf32> + return %add : tensor<4x10xf32> } // -----