diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 9016e41ee1b7e..92240f59f9f58 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -474,8 +474,6 @@ OpFoldResult AddOp::fold(ArrayRef operands) { auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = operands[0].dyn_cast_or_null(); @@ -504,6 +502,9 @@ OpFoldResult AddOp::fold(ArrayRef operands) { if (!lhsAttr || !rhsAttr) return {}; + if (lhsTy != rhsTy) + return {}; + return binaryFolder, std::plus>(lhsAttr, rhsAttr, lhsTy); } @@ -635,8 +636,6 @@ OpFoldResult SubOp::fold(ArrayRef operands) { auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = operands[0].dyn_cast_or_null(); @@ -655,6 +654,9 @@ OpFoldResult SubOp::fold(ArrayRef operands) { if (!lhsAttr || !rhsAttr) return {}; + if (lhsTy != rhsTy) + return {}; + return binaryFolder, std::minus>(lhsAttr, rhsAttr, lhsTy); } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 08115787db58a..87f47cfc3b815 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -164,6 +164,28 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_add_zero_splat_different_shape_f32 +func.func @fold_add_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> { + %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %add = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32> + // CHECK: return %arg0 + return %add : tensor<1x10xf32> +} + +// ----- + +// CHECK-LABEL: @fold_add_zero_broadcast_arg_f32 +func.func @fold_add_zero_broadcast_arg_f32(%arg0: tensor<1x10xf32>) -> tensor<4x10xf32> { + %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<4x10xf32> + %add = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<4x10xf32> + // CHECK: %[[ADD:.+]] = "tosa.add"(%arg0, %[[ZERO]]) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK: return %[[ADD]] : tensor<4x10xf32> + return %add : tensor<4x10xf32> +} + +// ----- + // CHECK-LABEL: @fold_div_zero_lhs_i32 func.func @fold_div_zero_lhs_i32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0> : tensor} : () -> tensor @@ -350,6 +372,16 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_sub_zero_splat_different_shape_f32 +func.func @fold_sub_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> { + %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %sub = "tosa.sub"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32> + // CHECK: return %arg0 + return %sub : tensor<1x10xf32> +} + +// ----- + // CHECK-LABEL: @fold_greater_splat_f32 func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) { %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>