From 987f02ec869fa73ab107a2f44d02c963b8784c50 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Tue, 9 May 2023 14:19:03 +0000 Subject: [PATCH 1/5] feat(TosaCanonicalizations): add reciprocal folding for constants --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 ++ .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 93178288dfc1b..a278c257c7779 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1057,6 +1057,8 @@ def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [ let results = (outs Tosa_Tensor:$output ); + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 9016e41ee1b7e..86422e8d791f7 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -629,6 +629,30 @@ OpFoldResult MulOp::fold(ArrayRef operands) { return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift()); } +OpFoldResult ReciprocalOp::fold(ArrayRef operands) { + auto constantAttr = operands[0].dyn_cast_or_null(); + auto lhsTy = getInput1().getType().dyn_cast(); + + if (!lhsTy || !constantAttr) { + return {}; + } + + if (!constantAttr.isSplat()) { + return {}; + } + + auto floatVal = constantAttr.getSplatValue(); + + if (!floatVal.isFiniteNonZero()) { + return {}; + } + + auto recipAttr = FloatAttr::get(lhsTy.getElementType(), 1.0); + APFloat recip = recipAttr.getValue(); + recip.divide(floatVal, APFloat::rmNearestTiesToEven); + return DenseElementsAttr::get(lhsTy, recip); +} + OpFoldResult SubOp::fold(ArrayRef operands) { auto lhsTy = getInput1().getType().dyn_cast(); auto rhsTy = getInput2().getType().dyn_cast(); @@ -655,6 +679,9 @@ OpFoldResult SubOp::fold(ArrayRef operands) { if (!lhsAttr || !rhsAttr) return {}; + if (lhsTy != rhsTy) + return {}; + return binaryFolder, std::minus>(lhsAttr, rhsAttr, lhsTy); } From 3f33917b3cf7da5707b41179c85a42aeff656e77 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Tue, 9 May 2023 14:56:54 +0000 Subject: [PATCH 2/5] test(constant-op-fold.mlir): add reciprocal folding case where a 1/constant is folded into the constant --- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 08115787db58a..cf0479c3bb88e 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -306,6 +306,17 @@ func.func @fold_mul_splat_f32() -> tensor<10xf32> { // ----- +// CHECK-LABEL: @fold_reciprocal_splat_f32 +func.func @fold_reciprocal_splat_f32() -> tensor { + %half = "tosa.const"() {value = dense<0.5> : tensor} : () -> tensor + %recp = "tosa.reciprocal"(%half) : (tensor) -> tensor + // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor} + // CHECK: return %[[CST]] + return %recp : tensor +} + +// ----- + // CHECK-LABEL: @fold_sub_zero_rhs_f32 func.func @fold_sub_zero_rhs_f32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor From c366cad9e71688fd74f24cc2bce0c2830335bab0 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 11 May 2023 09:21:00 +0000 Subject: [PATCH 3/5] fix(TosaCanonicalizations.cpp): remove accidental condition statement --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 86422e8d791f7..df42fd08e5381 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -679,9 +679,6 @@ OpFoldResult SubOp::fold(ArrayRef operands) { if (!lhsAttr || !rhsAttr) return {}; - if (lhsTy != rhsTy) - return {}; - return binaryFolder, std::minus>(lhsAttr, rhsAttr, lhsTy); } From 8914fd1e358ea8b9f74c37a2698d4dd88a1b7c91 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 11 May 2023 09:22:47 +0000 Subject: [PATCH 4/5] style(TosaCanonicalizations.cpp): use free-function version of dyn_cast --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index df42fd08e5381..7f3903e007ddd 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -630,8 +630,8 @@ OpFoldResult MulOp::fold(ArrayRef operands) { } OpFoldResult ReciprocalOp::fold(ArrayRef operands) { - auto constantAttr = operands[0].dyn_cast_or_null(); - auto lhsTy = getInput1().getType().dyn_cast(); + auto constantAttr = dyn_cast_or_null(operands[0]); + auto lhsTy = dyn_cast(getInput1().getType()); if (!lhsTy || !constantAttr) { return {}; From 17885ab5c5ce5d79c4ba6ac27176d125703a0943 Mon Sep 17 00:00:00 2001 From: Christopher McGirr Date: Thu, 11 May 2023 09:33:41 +0000 Subject: [PATCH 5/5] refactor(TosaCanonicalizations.cpp): use TOSA semantics for fp --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ---- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 12 ++++++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 7f3903e007ddd..6646c5146203d 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -643,10 +643,6 @@ OpFoldResult ReciprocalOp::fold(ArrayRef operands) { auto floatVal = constantAttr.getSplatValue(); - if (!floatVal.isFiniteNonZero()) { - return {}; - } - auto recipAttr = FloatAttr::get(lhsTy.getElementType(), 1.0); APFloat recip = recipAttr.getValue(); recip.divide(floatVal, APFloat::rmNearestTiesToEven); diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index cf0479c3bb88e..85c27f40d7ace 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -317,6 +317,18 @@ func.func @fold_reciprocal_splat_f32() -> tensor { // ----- +// CHECK-LABEL: @fold_reciprocal_splat_zero_f32 +func.func @fold_reciprocal_splat_zero_f32() -> tensor { + %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor + %recp = "tosa.reciprocal"(%zero) : (tensor) -> tensor + // 0x7F800000 represents +inf as we have computed 1/0 + // CHECK: %[[CST:.*]] = "tosa.const"() {value = dense<0x7F800000> : tensor} + // CHECK: return %[[CST]] + return %recp : tensor +} + +// ----- + // CHECK-LABEL: @fold_sub_zero_rhs_f32 func.func @fold_sub_zero_rhs_f32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor