From df366216692643fd31013fa71c08e1286cb33989 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 5 Apr 2023 11:13:16 +0100 Subject: [PATCH 1/3] Implement folding for constant tosa adds. * Add test case for add * Refactor: Put heuristic of whether to fold a binary operation or not in the common folding header. * Found new corner case: Fold binary operations even if the inputs have multiple uses, but only in the special case that these inputs are equal and those two are the only uses. --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 2 + .../Dialect/Tosa/Transforms/TosaFoldCommon.h | 17 +- .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../Tosa/Transforms/TosaFoldCommon.cpp | 47 +++++- .../Tosa/Transforms/TosaFoldConstantAdd.cpp | 98 ++++++++++++ .../Tosa/Transforms/TosaFoldConstantPow.cpp | 23 ++- .../TosaLayerwiseConstantFoldPass.cpp | 1 + mlir/test/Dialect/Tosa/constant-add-opt.mlir | 148 ++++++++++++++++++ mlir/test/Dialect/Tosa/constant-pow-opt.mlir | 10 ++ 9 files changed, 330 insertions(+), 17 deletions(-) create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp create mode 100644 mlir/test/Dialect/Tosa/constant-add-opt.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 3858836bb2b4b..82240ac531d5b 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantAddPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantCastPatterns(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantPowPatterns(MLIRContext *ctx, diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h 
b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index 8f61728af922a..c41044706fb52 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -40,10 +40,12 @@ DenseElementsAttr applyElementWise( /// Apply the given transformation function on the elements of the given /// tensors. If the input tensors do not match \p targetType, broadcasting is /// applied. +template DenseElementsAttr applyElementWise( const DenseElementsAttr &first, const DenseElementsAttr &second, TensorType targetType, - const std::function &toApply); + const std::function + &toApply); /// Function that checks if \p toCheck is a dense TOSA constant float tensor. LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, @@ -78,6 +80,19 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape, DimensionType toBeBroadcastedShape, OffsetType offset); +/// Heuristic to decide when to replace a binary operation on constants with the +/// folded value. +/// Folding operations on constants can lead to an increased memory usage +/// whenever none of the inputs can be replaced but a new constant that is +/// inserted. Hence, this will currently only suggest folding when the memory +/// impact is negligible. +/// The \p binaryOp and the constant values of both operands, \p valuesFirst +/// and \p valuesSecond. +/// \returns Whether folding should be applied. +bool constantBinaryOpShouldBeFolded(TosaOp binaryOp, + DenseElementsAttr valuesFirst, + DenseElementsAttr valuesSecond); + /// Function to compute the reciprocal. 
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index e15d5e9463e47..0804b6462aaac 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp TosaFoldCommon.cpp + TosaFoldConstantAdd.cpp TosaFoldConstantCast.cpp TosaFoldConstantPow.cpp TosaFoldConstantReciprocal.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index 0e452a2463fcf..10dab3f7e7fd5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -69,12 +69,14 @@ mlir::tosa::applyElementWise( const std::function &toApply, IntegerType targetType); +template DenseElementsAttr mlir::tosa::applyElementWise( const DenseElementsAttr &first, const DenseElementsAttr &second, TensorType targetType, - const std::function &toApply) { + const std::function + &toApply) { // Make sure to use the correct values in case broadcasting is required - SmallVector transformedValues; + SmallVector transformedValues; // We already know the amount of values we will insert, reserve space for // all of them to avoid dynamic resizing auto targetSize = 1; @@ -86,9 +88,9 @@ DenseElementsAttr mlir::tosa::applyElementWise( // Apply the given function to each pair of values from the input tensors. // Make sure to broadcast the offsets properly. 
- auto firstIt = first.getValues(); + auto firstIt = first.getValues(); auto firstShape = first.getType().getShape(); - auto secondIt = second.getValues(); + auto secondIt = second.getValues(); auto secondShape = second.getType().getShape(); for (auto offset = 0; offset < targetSize; offset++) { OffsetType offsetInTargetFirst = @@ -105,6 +107,17 @@ DenseElementsAttr mlir::tosa::applyElementWise( return newTensor; } +template DenseElementsAttr +mlir::tosa::applyElementWise( + const DenseElementsAttr &first, const DenseElementsAttr &second, + TensorType targetType, + const std::function &toApply); + +template DenseElementsAttr mlir::tosa::applyElementWise( + const DenseElementsAttr &first, const DenseElementsAttr &second, + TensorType targetType, + const std::function &toApply); + LogicalResult mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, @@ -205,6 +218,32 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape, return indexToOffset(toBeBroadcastedShape, indexBroadcasted); } +bool mlir::tosa::constantBinaryOpShouldBeFolded( + TosaOp binaryOp, DenseElementsAttr valuesFirst, + DenseElementsAttr valuesSecond) { + assert(binaryOp->getNumOperands() == 2); + auto firstOp = binaryOp->getOperand(0); + auto secondOp = binaryOp->getOperand(1); + + // If both tensors are splat, we don't care for the number of users + if (isa(valuesFirst) && + isa(valuesSecond)) { + return true; + } + + // If this is the only use of one of the tensors, it will be replaced and no + // additional memory is required. + if (firstOp.hasOneUse() || secondOp.hasOneUse()) { + return true; + } + + // Fold if both inputs are equal and those are the only uses. Don't fold + // otherwise. 
+ auto numUsers = + std::distance(firstOp.getUses().begin(), firstOp.getUses().end()); + return firstOp == secondOp && numUsers == 2; +} + APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, FloatType floatTy) { auto recipAttr = FloatAttr::get(floatTy, 1.0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp new file mode 100644 index 0000000000000..c5b2dc3aca487 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp @@ -0,0 +1,98 @@ +//===- TosaFoldConstantAdd.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Add operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantAdd : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + static APInt computeIntAdd(const APInt &first, const APInt &second) { + return first.sadd_sat(second); + } + + static APFloat computeFloatAdd(const APFloat &first, const APFloat &second) { + return first + second; + } + + LogicalResult matchAndRewrite(AddOp addOp, + PatternRewriter &rewriter) const override { + auto leftOp = addOp.getInput1(); + auto rightOp = addOp.getInput2(); + + auto resultType = addOp.getType(); + auto lhsElemType = leftOp.getType().getElementType(); + auto rhsElemType = rightOp.getType().getElementType(); + if 
(lhsElemType != rhsElemType) { + return rewriter.notifyMatchFailure( + addOp, "Expected type of add arguments to match."); + } + + // Check if both tensors are constant + auto rhsIsConstantCheck = + notifyIfNoTosaDenseConstantTensor(leftOp, addOp, rewriter); + if (failed(rhsIsConstantCheck)) { + return rhsIsConstantCheck; + } + auto lhsIsConstantCheck = + notifyIfNoTosaDenseConstantTensor(rightOp, addOp, rewriter); + if (failed(lhsIsConstantCheck)) { + return lhsIsConstantCheck; + } + + // Extract the tensor values + DenseElementsAttr lhsValues; + matchPattern(leftOp, m_Constant(&lhsValues)); + + DenseElementsAttr rhsValues; + matchPattern(rightOp, m_Constant(&rhsValues)); + + if (!constantBinaryOpShouldBeFolded(addOp, lhsValues, rhsValues)) { + return rewriter.notifyMatchFailure( + addOp, "Currently, adds will only be folded if this requires only " + "little additional memory usage."); + } + + DenseElementsAttr newTensor; + if (isa(lhsElemType)) { + assert(isa(rhsElemType) && + isa(resultType.getElementType())); + newTensor = applyElementWise( + lhsValues, rhsValues, resultType, &computeIntAdd); + } else { + assert(isa(lhsElemType) && isa(rhsElemType) && + isa(resultType.getElementType())); + newTensor = applyElementWise( + lhsValues, rhsValues, resultType, &computeFloatAdd); + } + rewriter.replaceOpWithNewOp(addOp, newTensor.getType(), newTensor); + + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantAddPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp index e0d3f377b9e47..6cd3d1d393dfd 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp @@ -63,6 +63,11 @@ struct TosaFoldConstantPow : public OpRewritePattern { auto baseOp = powOp.getInput1(); auto expOp = powOp.getInput2(); + if 
(baseOp.getType().getElementType() != expOp.getType().getElementType()) { + return rewriter.notifyMatchFailure( + powOp, "Expected type of pow arguments to match."); + } + // Check if both tensors are constant auto baseIsConstCheck = notifyIfNotConstantFloatTosaTensor(baseOp, powOp, rewriter); @@ -82,20 +87,14 @@ struct TosaFoldConstantPow : public OpRewritePattern { DenseElementsAttr expValues; matchPattern(expOp, m_Constant(&expValues)); - // If both tensors are splat, we don't care for the number of users - if (!isa(baseValues) || - !isa(expValues)) { - // Make sure that at least one of the constant input tensors can be - // replaced (i.e. only has a single user) - if (!baseOp.hasOneUse() && !expOp.hasOneUse()) { - return rewriter.notifyMatchFailure( - powOp, "Currently, pows will only be folded if at least one input " - "tensor only has a single user"); - } + if (!constantBinaryOpShouldBeFolded(powOp, baseValues, expValues)) { + return rewriter.notifyMatchFailure( + powOp, "Currently, pows will only be folded if this requires only " + "little additional memory usage."); } - auto newTensor = - applyElementWise(baseValues, expValues, powOp.getType(), &computePower); + auto newTensor = applyElementWise( + baseValues, expValues, powOp.getType(), &computePower); rewriter.replaceOpWithNewOp(powOp, newTensor.getType(), newTensor); return success(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index 96d3ac483472a..07d73dc0cfb05 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass RewritePatternSet patterns(ctx); auto func = getOperation(); + mlir::tosa::populateTosaFoldConstantAddPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns); 
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-add-opt.mlir b/mlir/test/Dialect/Tosa/constant-add-opt.mlir new file mode 100644 index 0000000000000..0445857f60a4e --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-add-opt.mlir @@ -0,0 +1,148 @@ +// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s + +// ----- +// Float additions + +// CHECK-LABEL: @add_fold_float +func.func @add_fold_float() -> tensor<4xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-1.5{{.*}}e+02, 1.9{{.*}}e+00, 0.{{0*}}e+00, 5.{{0*}}e+00 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17.4978, 4.9882, 0.0, -0.0]> : + tensor<4xf16> + } : () -> tensor<4xf16> + %1 = "tosa.const"() {value = + dense<[-132.7, -3.0, -0.0, 5.0]> : + tensor<4xf16> + } : () -> tensor<4xf16> + %2 = "tosa.add"(%0, %1) : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16> + return %2 : tensor<4xf16> +} + +// CHECK-LABEL: @add_fold_float_infinity_nan +func.func @add_fold_float_infinity_nan() -> tensor<6xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7FC00000 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000]> : + tensor<6xf32> + } : () -> tensor<6xf32> + %1 = "tosa.const"() {value = + dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000]> : + tensor<6xf32> + } : () -> tensor<6xf32> + %2 = "tosa.add"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32> + return %2 : tensor<6xf32> +} + +// CHECK-LABEL: @add_fold_float_overflow +func.func @add_fold_float_overflow() -> tensor<2xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() 
{value = + dense<[3.1e+38, -3.1e+38]> : + tensor<2xf32> + } : () -> tensor<2xf32> + %1 = "tosa.const"() {value = + dense<[2.1e+38, -1.1e+38]> : + tensor<2xf32> + } : () -> tensor<2xf32> + %2 = "tosa.add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} + +// ----- +// Int additions + +// CHECK-LABEL: @add_fold_int +func.func @add_fold_int() -> tensor<4xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-149, 1, 0, 5 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0, 0]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[-132, -3, 0, 5]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %2 = "tosa.add"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: @add_fold_int_overflow +func.func @add_fold_int_overflow() -> tensor<4xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2147483647, 2147483647, -2147483648, -2147483648 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[2147483647, 2147483640, -2147483648, -2147483640]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[1, 10, -1, -30]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %2 = "tosa.add"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} + +// ----- +// self-addition + +// CHECK-LABEL: @add_fold_equal_args +func.func @add_fold_equal_args() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-34, 8, 0 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0]> : + tensor<3xi32> + } : () -> tensor<3xi32> + %2 = "tosa.add"(%0, %0) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + return %2 : tensor<3xi32> +} + +// ----- +// Broadcasted additions + +// CHECK-LABEL: @add_fold_int_broadcast_simple +func.func @add_fold_int_broadcast_simple() -> tensor<3xi32> { + // 
CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-29, -8, -12 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0]> : + tensor<3xi32> + } : () -> tensor<3xi32> + %1 = "tosa.const"() {value = + dense<-12> : + tensor<1xi32> + } : () -> tensor<1xi32> + %2 = "tosa.add"(%0, %1) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> + return %2 : tensor<3xi32> +} + +// CHECK-LABEL: @add_fold_int_broadcast_complex +func.func @add_fold_int_broadcast_complex() -> tensor<3x3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-SAME{LITERAL}: [[-29, -10, -13], + // CHECK-SAME{LITERAL}: [-11, 8, 5], + // CHECK-SAME{LITERAL}: [7, 26, 23]] + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[[-17], [1], [19]]> : + tensor<3x1xi32> + } : () -> tensor<3x1xi32> + %1 = "tosa.const"() {value = + dense<[[-12, 7, 4]]> : + tensor<1x3xi32> + } : () -> tensor<1x3xi32> + %2 = "tosa.add"(%0, %1) : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32> + return %2 : tensor<3x3xi32> +} diff --git a/mlir/test/Dialect/Tosa/constant-pow-opt.mlir b/mlir/test/Dialect/Tosa/constant-pow-opt.mlir index 66d997e7b6a37..279f7206537a5 100644 --- a/mlir/test/Dialect/Tosa/constant-pow-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-pow-opt.mlir @@ -55,6 +55,16 @@ func.func @pow_fold_nan_cases() -> tensor<3xf32> { return %2 : tensor<3xf32> } +// CHECK-LABEL: @pow_fold_equal_args +func.func @pow_fold_equal_args() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.56{{0*}}e+02, 5.8 + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[4.0, 2.22]> : tensor<2xf16>} : () -> tensor<2xf16> + %2 = "tosa.pow"(%0, %0) : (tensor<2xf16>, tensor<2xf16>) -> tensor<2xf16> + return %2 : tensor<2xf16> +} + // CHECK-LABEL: @pow_fold_tensor_broadcast_exp func.func @pow_fold_tensor_broadcast_exp() -> tensor<3xf16> { // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01, 4.929690e+00, 
9.609370e+00{{.*}}tensor<3xf16> From d1f90a80346e284080018201c946c6fe53232a18 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 5 Apr 2023 15:27:37 +0100 Subject: [PATCH 2/3] Fix description --- .../mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index c41044706fb52..adfde331f7e6e 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -83,11 +83,11 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape, /// Heuristic to decide when to replace a binary operation on constants with the /// folded value. /// Folding operations on constants can lead to an increased memory usage -/// whenever none of the inputs can be replaced but a new constant that is -/// inserted. Hence, this will currently only suggest folding when the memory -/// impact is negligible. -/// The \p binaryOp and the constant values of both operands, \p valuesFirst -/// and \p valuesSecond. +/// whenever none of the inputs can be replaced but a new constant is inserted. +/// Hence, this will currently only suggest folding when the memory impact is +/// negligible. +/// Takes the \p binaryOp and the constant values of both operands, +/// \p valuesFirst and \p valuesSecond. /// \returns Whether folding should be applied. bool constantBinaryOpShouldBeFolded(TosaOp binaryOp, DenseElementsAttr valuesFirst, From 3d9e1d7225fda867bb1c84a3e938b9fd481bbfaa Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Thu, 6 Apr 2023 13:11:39 +0100 Subject: [PATCH 3/3] Remove overly general template argument Arguments of binary ops always have to have the same tensor element type, so use a single template parameter instead of two separate ones for them. 
--- .../mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h | 4 ++-- mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp | 13 ++++++------- .../Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp | 6 +++--- .../Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index adfde331f7e6e..67d856a95492e 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -40,11 +40,11 @@ DenseElementsAttr applyElementWise( /// Apply the given transformation function on the elements of the given /// tensors. If the input tensors do not match \p targetType, broadcasting is /// applied. -template +template DenseElementsAttr applyElementWise( const DenseElementsAttr &first, const DenseElementsAttr &second, TensorType targetType, - const std::function + const std::function &toApply); /// Function that checks if \p toCheck is a dense TOSA constant float tensor. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index 10dab3f7e7fd5..f37ff880d1f20 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -69,11 +69,11 @@ mlir::tosa::applyElementWise( const std::function &toApply, IntegerType targetType); -template +template DenseElementsAttr mlir::tosa::applyElementWise( const DenseElementsAttr &first, const DenseElementsAttr &second, TensorType targetType, - const std::function + const std::function &toApply) { // Make sure to use the correct values in case broadcasting is required SmallVector transformedValues; @@ -88,9 +88,9 @@ DenseElementsAttr mlir::tosa::applyElementWise( // Apply the given function to each pair of values from the input tensors. // Make sure to broadcast the offsets properly. 
- auto firstIt = first.getValues(); + auto firstIt = first.getValues(); auto firstShape = first.getType().getShape(); - auto secondIt = second.getValues(); + auto secondIt = second.getValues(); auto secondShape = second.getType().getShape(); for (auto offset = 0; offset < targetSize; offset++) { OffsetType offsetInTargetFirst = @@ -107,13 +107,12 @@ DenseElementsAttr mlir::tosa::applyElementWise( return newTensor; } -template DenseElementsAttr -mlir::tosa::applyElementWise( +template DenseElementsAttr mlir::tosa::applyElementWise( const DenseElementsAttr &first, const DenseElementsAttr &second, TensorType targetType, const std::function &toApply); -template DenseElementsAttr mlir::tosa::applyElementWise( +template DenseElementsAttr mlir::tosa::applyElementWise( const DenseElementsAttr &first, const DenseElementsAttr &second, TensorType targetType, const std::function &toApply); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp index c5b2dc3aca487..0fb6ad6f1d6d8 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp @@ -76,12 +76,12 @@ struct TosaFoldConstantAdd : public OpRewritePattern { if (isa(lhsElemType)) { assert(isa(rhsElemType) && isa(resultType.getElementType())); - newTensor = applyElementWise( - lhsValues, rhsValues, resultType, &computeIntAdd); + newTensor = applyElementWise(lhsValues, rhsValues, + resultType, &computeIntAdd); } else { assert(isa(lhsElemType) && isa(rhsElemType) && isa(resultType.getElementType())); - newTensor = applyElementWise( + newTensor = applyElementWise( lhsValues, rhsValues, resultType, &computeFloatAdd); } rewriter.replaceOpWithNewOp(addOp, newTensor.getType(), newTensor); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp index 6cd3d1d393dfd..71edbec090f9d 100644 --- 
a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp @@ -93,7 +93,7 @@ struct TosaFoldConstantPow : public OpRewritePattern { "little additional memory usage."); } - auto newTensor = applyElementWise( + auto newTensor = applyElementWise( baseValues, expValues, powOp.getType(), &computePower); rewriter.replaceOpWithNewOp(powOp, newTensor.getType(), newTensor);