diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 3331ca4cb8643..e3e4f8e41d92a 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -399,6 +399,7 @@ def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> { ); let hasCanonicalizer = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 636e0deb18a0e..530e738f57e6a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -374,6 +374,19 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +OpFoldResult ClampOp::fold(FoldAdaptor adaptor) { + // TODO: This can generalize to any cast (or other operation) + // where the output values are within a computable range. + if (auto cast = llvm::dyn_cast_or_null(getInput().getDefiningOp())) { + if (cast.getType().getElementType().isF32() && + cast.getInput().getType().getElementType().isInteger(8) && + getMinFp().convertToFloat() <= -128.0 && + getMaxFp().convertToFloat() >= 127.0) + return getInput(); + } + return {}; +} + struct ConcatSliceOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 31227698e09bf..774e838ac9459 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -140,6 +140,31 @@ func.func @clamp_maximum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %1 : tensor<4xf32> } +// CHECK-LABEL: @cast_clamp +func.func @cast_clamp(%arg0: tensor<4xi8>) -> tensor<4xf32> { + // CHECK-NOT: tosa.clamp + %0 = "tosa.cast"(%arg0) : (tensor<4xi8>) -> tensor<4xf32> + %1 = "tosa.clamp"(%0) {min_fp = -129.0 : f32, max_fp = 200.0 : f32, min_int = -2 : i64, max_int = 4 : i64} : (tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: @cast_clamp2 +func.func @cast_clamp2(%arg0: tensor<4xi8>) -> tensor<4xf32> { + // CHECK-NOT: tosa.clamp + %0 = "tosa.cast"(%arg0) : (tensor<4xi8>) -> tensor<4xf32> + %1 = "tosa.clamp"(%0) {min_fp = -128.0 : f32, max_fp = 127.0 : f32, min_int = -2 : i64, max_int = 4 : i64} : (tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK-LABEL: @cast_clamp_negative +func.func @cast_clamp_negative(%arg0: tensor<4xi8>) -> tensor<4xf32> { + // CHECK: tosa.clamp + %0 = "tosa.cast"(%arg0) : (tensor<4xi8>) -> tensor<4xf32> + %1 = "tosa.clamp"(%0) {min_fp = -5.0 : f32, max_fp = 3.0 : f32, min_int = -2 : i64, max_int = 4 : i64} : (tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + + // CHECK-LABEL: @concat_fold_zero func.func @concat_fold_zero(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}>