diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 93178288dfc1b..a278c257c7779 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1057,6 +1057,8 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9016e41ee1b7e..6646c5146203d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -629,6 +629,35 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
   return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
 }
 
+/// Constant-folds tosa.reciprocal when its operand is a floating-point splat.
+///
+/// Returns a splat attribute holding 1/x, computed with round-to-nearest,
+/// ties-to-even. Returns a null OpFoldResult (no fold) when the operand is not
+/// a constant splat of a ranked tensor with a floating-point element type.
+OpFoldResult ReciprocalOp::fold(ArrayRef<Attribute> operands) {
+  auto constantAttr = dyn_cast_or_null<DenseElementsAttr>(operands[0]);
+  auto lhsTy = dyn_cast<RankedTensorType>(getInput1().getType());
+
+  // Only a constant splat feeding a ranked tensor operand is folded.
+  if (!lhsTy || !constantAttr || !constantAttr.isSplat())
+    return {};
+
+  // Reading the splat as an APFloat and building the 1.0 FloatAttr below are
+  // only valid for floating-point element types; leave other inputs unfolded.
+  if (!isa<FloatType>(constantAttr.getElementType()) ||
+      !isa<FloatType>(lhsTy.getElementType()))
+    return {};
+
+  auto floatVal = constantAttr.getSplatValue<APFloat>();
+
+  // Start from 1.0 in the element type and divide in place. The status
+  // returned by divide() is not inspected, so 1/0 folds to an infinity.
+  auto recipAttr = FloatAttr::get(lhsTy.getElementType(), 1.0);
+  APFloat recip = recipAttr.getValue();
+  recip.divide(floatVal, APFloat::rmNearestTiesToEven);
+  return DenseElementsAttr::get(lhsTy, recip);
+}
+
 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
   auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 08115787db58a..85c27f40d7ace 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -306,6 +306,29 @@ func.func @fold_mul_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_reciprocal_splat_f32
+func.func @fold_reciprocal_splat_f32() -> tensor<f32> {
+  %half = "tosa.const"() {value = dense<0.5> : tensor<f32>} : () -> tensor<f32>
+  %recp = "tosa.reciprocal"(%half) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>}
+  // CHECK: return %[[CST]]
+  return %recp : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_reciprocal_splat_zero_f32
+func.func @fold_reciprocal_splat_zero_f32() -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %recp = "tosa.reciprocal"(%zero) : (tensor<f32>) -> tensor<f32>
+  // 0x7F800000 represents +inf as we have computed 1/0
+  // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>}
+  // CHECK: return %[[CST]]
+  return %recp : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_sub_zero_rhs_f32
 func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>