diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index fa3a9e1a50d23..fed20da33afd2 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1856,7 +1856,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> { //===----------------------------------------------------------------------===// // Operator: cast //===----------------------------------------------------------------------===// -def Tosa_CastOp: Tosa_Op<"cast", [Pure, +def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape, DeclareOpInterfaceMethods]> { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 37f84a1b5424d..4d77e7b07521e 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1332,10 +1332,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> { // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<8.620690e-03> : tensor<3x600x1200xf32>}> : () -> tensor<3x600x1200xf32> // CHECK: return %[[VAL_0]] : tensor<3x600x1200xf32> // CHECK: } - %0 = "tosa.const"(){ value = dense<116.0>: tensor }: () -> tensor - %1 = "tosa.cast"(%0) : (tensor) -> tensor<3x600x1200xf32> - %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32> - return %2 : tensor<3x600x1200xf32> + %0 = "tosa.const"(){ value = dense<116.0>: tensor<3x600x1200xf32> }: () -> tensor<3x600x1200xf32> + %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xf32>) -> tensor<3x600x1200xf32> + return %1 : tensor<3x600x1200xf32> } // ----- @@ -1343,10 +1342,9 @@ func.func nested @fold_reciprocal() -> tensor<3x600x1200xf32> { // CHECK-LABEL: @do_not_fold_reciprocal_int func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> { // CHECK: tosa.reciprocal - %0 = "tosa.const"(){ value = dense<11>: tensor }: () -> tensor - %1 = "tosa.cast"(%0) : (tensor) -> tensor<3x600x1200xi32> - %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> - return %2 : tensor<3x600x1200xi32> + %0 = "tosa.const"(){ value = dense<11>: tensor<3x600x1200xi32> }: () -> tensor<3x600x1200xi32> + %1 = "tosa.reciprocal"(%0): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> + return %1 : tensor<3x600x1200xi32> } // -----