From 351076f87935cc6a05d4238c68a5c07d89ab28d7 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 27 Mar 2023 10:43:50 +0100 Subject: [PATCH 1/3] Cleanup: replace comments that were not generalized, use doxygen-style comments --- .../mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h | 10 +++++----- mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index 5ab0c8bca523b..6582bc31fc92e 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -20,26 +20,26 @@ namespace mlir { namespace tosa { -// Transform a tensor with the given transformation function. +/// Transform a tensor with the given transformation function. DenseElementsAttr applyElementWise( const DenseElementsAttr &toTransform, const std::function &toApply); -/// Function that checks if arg is a dense TOSA constant float tensor +/// Function that checks if \p toCheck is a dense TOSA constant float tensor. LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, PatternRewriter &); -/// Function that checks if arg is a dense TOSA constant tensor +/// Function that checks if \p toCheck is a dense TOSA constant tensor. LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, TosaOp location, PatternRewriter &); -/// Function that checks if the contained type is float +/// Function that checks if the type contained in \p toCheck is float. LogicalResult notifyIfNotFloat(TypedValue toCheck, TosaOp location, PatternRewriter &); -/// Function to compute the reciprocal +/// Function to compute the reciprocal. APFloat computeReciprocal(const APFloat &, Type); } // namespace tosa diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index 3188acb29948e..c4e0b6e9fa2a1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -34,11 +34,11 @@ DenseElementsAttr mlir::tosa::applyElementWise( // all of them to avoid dynamic resizing transformedValues.reserve(toTransform.getNumElements()); for (auto val : toTransform.getValues()) { - auto recipVal = toApply(val, toTransform.getElementType()); - transformedValues.push_back(recipVal); + auto transformedVal = toApply(val, toTransform.getElementType()); + transformedValues.push_back(transformedVal); } - // Replace the current tensor with one containing the computed reciprocals + // Replace the current tensor with one containing the computed values auto newTensor = DenseElementsAttr::get(toTransform.getType(), transformedValues); return newTensor; From a8712f6ddfd8a562e59fd60aad8e3300c7beefab Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 27 Mar 2023 12:07:47 +0100 Subject: [PATCH 2/3] Implement folding for the power operation. Use the TOSA spec for POW ([0]+[1]). [0] https://www.mlplatform.org/tosa/tosa_spec.html#_pow [1] https://www.mlplatform.org/tosa/tosa_spec.html#_main_inference_profile tosa.pow can be applied to tensors of different shapes, in which case broadcasting is applied. Implement TOSA broadcasting helpers as specified here [2] as well. [2] https://www.mlplatform.org/tosa/tosa_spec.html#_tensor_access_helpers --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 2 + .../Dialect/Tosa/Transforms/TosaFoldCommon.h | 31 +++++ .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../Tosa/Transforms/TosaFoldCommon.cpp | 87 ++++++++++++++ .../Tosa/Transforms/TosaFoldConstantPow.cpp | 110 ++++++++++++++++++ .../TosaLayerwiseConstantFoldPass.cpp | 1 + mlir/test/Dialect/Tosa/constant-pow-opt.mlir | 92 +++++++++++++++ 7 files changed, 324 insertions(+) create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp create mode 100644 mlir/test/Dialect/Tosa/constant-pow-opt.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 47d6932c15887..895a1b00d2eef 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantPowPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantRSQRTPatterns(MLIRContext *ctx, diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index 6582bc31fc92e..3ec24ba91e7e6 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -13,6 +13,7 @@ #define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H #include +#include #include #include #include @@ -20,11 +21,24 @@ namespace mlir { namespace tosa { +/// Type that represents tensor dimensions. +using DimensionType = ArrayRef; + +/// Type for tensor offsets. +using OffsetType = size_t; + /// Transform a tensor with the given transformation function. DenseElementsAttr applyElementWise( const DenseElementsAttr &toTransform, const std::function &toApply); +/// Apply the given transformation function on the elements of the given +/// tensors. If the input tensors do not match \p targetType, broadcasting is +/// applied. +DenseElementsAttr applyElementWise( + const DenseElementsAttr &, const DenseElementsAttr &, TensorType targetType, + const std::function &toApply); + /// Function that checks if \p toCheck is a dense TOSA constant float tensor. LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, @@ -39,6 +53,23 @@ LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, LogicalResult notifyIfNotFloat(TypedValue toCheck, TosaOp location, PatternRewriter &); +/// Compute the offset in \p shape which corresponds to the given \p index. +OffsetType indexToOffset(DimensionType shape, DimensionType index); + +/// Compute the index into \p shape which corresponds to the given \p offset. +SmallVector offsetToIndex(DimensionType shape, OffsetType offset); + +/// Given an \p index into \p desiredShape, compute the corresponding index into +/// \p toBeBroadcasted. +SmallVector getBroadcastedIndex(DimensionType desiredShape, + DimensionType toBeBroadcasted, + DimensionType index); +/// Given an \p offset into \p desiredShape, compute the corresponding offset +/// into \p toBeBroadcasted. +OffsetType getBroadcastedOffset(DimensionType desiredShape, + DimensionType toBeBroadcasted, + OffsetType offset); + /// Function to compute the reciprocal. APFloat computeReciprocal(const APFloat &, Type); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index f321de297daaf..d75fdfd7c9d15 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp TosaFoldCommon.cpp + TosaFoldConstantPow.cpp TosaFoldConstantReciprocal.cpp TosaFoldConstantRSQRT.cpp TosaFoldConstantTranspose.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index c4e0b6e9fa2a1..42c4023ad05ca 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -12,7 +12,9 @@ #include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include #include +#include #include #include #include @@ -44,6 +46,42 @@ DenseElementsAttr mlir::tosa::applyElementWise( return newTensor; } +DenseElementsAttr mlir::tosa::applyElementWise( + const DenseElementsAttr &first, const DenseElementsAttr &second, + TensorType targetType, + const std::function &toApply) { + // Make sure to use the correct values in case broadcasting is required + SmallVector transformedValues; + // We already know the amount of values we will insert, reserve space for + // all of them to avoid dynamic resizing + auto targetSize = 1; + auto targetShape = targetType.getShape(); + for (const auto &dimSize : targetShape) { + targetSize *= dimSize; + } + transformedValues.reserve(targetSize); + + // Apply the given function to each pair of values from the input tensors. + // Make sure to broadcast the offsets properly. + auto firstIt = first.getValues(); + auto firstShape = first.getType().getShape(); + auto secondIt = second.getValues(); + auto secondShape = second.getType().getShape(); + for (auto offset = 0; offset < targetSize; offset++) { + OffsetType offsetInTargetFirst = + getBroadcastedOffset(targetShape, firstShape, offset); + OffsetType offsetInTargetSecond = + getBroadcastedOffset(targetShape, secondShape, offset); + auto res = + toApply(firstIt[offsetInTargetFirst], secondIt[offsetInTargetSecond]); + transformedValues.push_back(res); + } + + // Generate a tensor with the computed values. + auto newTensor = DenseElementsAttr::get(targetType, transformedValues); + return newTensor; +} + LogicalResult mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, @@ -91,6 +129,55 @@ LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue toCheck, "TOSA spec only allows floats"); } +OffsetType mlir::tosa::indexToOffset(DimensionType shape, DimensionType index) { + OffsetType offset = 0; + for (size_t i = 0; i < shape.size(); i++) { + offset = offset * shape[i] + index[i]; + } + return offset; +} + +SmallVector mlir::tosa::offsetToIndex(DimensionType shape, + OffsetType offset) { + auto rank = shape.size(); + // The rank of the index will be equal to the rank of the shape + SmallVector resultIndex; + resultIndex.reserve(rank); + // Compute all the index values from the last to the first one, reverse the + // vector afterwards as there is no convenient push_front. + for (int32_t i = rank - 1; i >= 0; i--) { + resultIndex.push_back(offset % shape[i]); + offset /= shape[i]; + } + std::reverse(resultIndex.begin(), resultIndex.end()); + return resultIndex; +} + +SmallVector +mlir::tosa::getBroadcastedIndex(DimensionType desiredShape, + DimensionType toBeBroadcasted, + DimensionType index) { + SmallVector broadCasted; + broadCasted.reserve(desiredShape.size()); + for (size_t i = 0; i < desiredShape.size(); i++) { + auto toInsert = 0; + if (toBeBroadcasted[i] == desiredShape[i]) { + toInsert = index[i]; + } + broadCasted.push_back(toInsert); + } + return broadCasted; +} + +OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape, + DimensionType toBeBroadcasted, + OffsetType offset) { + auto indexInTarget = offsetToIndex(desiredShape, offset); + auto indexBroadcasted = + getBroadcastedIndex(desiredShape, toBeBroadcasted, indexInTarget); + return indexToOffset(toBeBroadcasted, indexBroadcasted); +} + APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) { auto recipAttr = FloatAttr::get(floatTy, 1.0); APFloat recip = recipAttr.getValue(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp new file mode 100644 index 0000000000000..e0d3f377b9e47 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp @@ -0,0 +1,110 @@ +//===- TosaFoldConstantPow.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Pow operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantPow : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + static APFloat computePower(const APFloat &base, const APFloat &exp) { + // Propagate NaN + if (base.isNaN() || exp.isNaN()) { + return APFloat::getNaN(base.getSemantics()); + } + // TOSA defines 0.0**0.0 as NaN + if (base.isZero() && exp.isZero()) { + return APFloat::getNaN(base.getSemantics()); + } + // In case the value is negative, the exponent needs to be an integer + if (base.isNegative() && !base.isZero()) { + if (!exp.isInteger()) { + return APFloat::getNaN(base.getSemantics()); + } + } + + // Actually compute base**exp. Special cases for [-]infinity and [-]0 are + // already handled in accordance with the TOSA spec. + auto powFloat = std::pow(base.convertToFloat(), exp.convertToFloat()); + auto res = APFloat(powFloat); + + bool lostPrecision; + res.convert(base.getSemantics(), APFloat::rmNearestTiesToEven, + &lostPrecision); + return res; + } + + LogicalResult matchAndRewrite(PowOp powOp, + PatternRewriter &rewriter) const override { + auto baseOp = powOp.getInput1(); + auto expOp = powOp.getInput2(); + + // Check if both tensors are constant + auto baseIsConstCheck = + notifyIfNotConstantFloatTosaTensor(baseOp, powOp, rewriter); + if (failed(baseIsConstCheck)) { + return baseIsConstCheck; + } + auto expIsConstCheck = + notifyIfNotConstantFloatTosaTensor(expOp, powOp, rewriter); + if (failed(expIsConstCheck)) { + return expIsConstCheck; + } + + // Extract the tensor values + DenseElementsAttr baseValues; + matchPattern(baseOp, m_Constant(&baseValues)); + + DenseElementsAttr expValues; + matchPattern(expOp, m_Constant(&expValues)); + + // If both tensors are splat, we don't care for the number of users + if (!isa(baseValues) || + !isa(expValues)) { + // Make sure that at least one of the constant input tensors can be + // replaced (i.e. only has a single user) + if (!baseOp.hasOneUse() && !expOp.hasOneUse()) { + return rewriter.notifyMatchFailure( + powOp, "Currently, pows will only be folded if at least one input " + "tensor only has a single user"); + } + } + + auto newTensor = + applyElementWise(baseValues, expValues, powOp.getType(), &computePower); + rewriter.replaceOpWithNewOp(powOp, newTensor.getType(), newTensor); + + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantPowPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index a7bb2c19f9885..1288b1c7ade40 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass RewritePatternSet patterns(ctx); auto func = getOperation(); + mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-pow-opt.mlir b/mlir/test/Dialect/Tosa/constant-pow-opt.mlir new file mode 100644 index 0000000000000..66d997e7b6a37 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-pow-opt.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +// CHECK-LABEL: @pow_fold_tiny +func.func @pow_fold_tiny() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01{{.*}}tensor + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.0> : tensor} : () -> tensor + %1 = "tosa.const"() {value = dense<2.0> : tensor} : () -> tensor + %2 = "tosa.pow"(%0, %1) : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: @pow_fold_tensor +func.func @pow_fold_tensor() -> tensor<3xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.56{{0*}}e+02, 1.191410e+00, -3.099610e+00{{.*}}tensor<3xf16> + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[4.0, 2.22, -3.1]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "tosa.const"() {value = dense<[4.0, 0.22, 1.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %2 = "tosa.pow"(%0, %1) : (tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16> + return %2 : tensor<3xf16> +} + +// CHECK-LABEL: @pow_fold_overflow +func.func @pow_fold_overflow() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00{{.*}}tensor<2xf16> + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[65500.0, -65500.0]> : tensor<2xf16>} : () -> tensor<2xf16> + %1 = "tosa.const"() {value = dense<[2.0, 3.0]> : tensor<2xf16>} : () -> tensor<2xf16> + %2 = "tosa.pow"(%0, %1) : (tensor<2xf16>, tensor<2xf16>) -> tensor<2xf16> + return %2 : tensor<2xf16> +} + +// CHECK-LABEL: @pow_fold_underflow +func.func @pow_fold_underflow() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}[0.0{{0*}}e+00, -0.0{{0*}}e+00{{.*}}tensor<2xf16> + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[0.000001, -0.000001]> : tensor<2xf16>} : () -> tensor<2xf16> + %1 = "tosa.const"() {value = dense<[10.0, 9.0]> : tensor<2xf16>} : () -> tensor<2xf16> + %2 = "tosa.pow"(%0, %1) : (tensor<2xf16>, tensor<2xf16>) -> tensor<2xf16> + return %2 : tensor<2xf16> +} + +// CHECK-LABEL: @pow_fold_nan_cases +func.func @pow_fold_nan_cases() -> tensor<3xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0x7FC00000>{{.*}}tensor<3xf32> + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[0.0, -1.25, 0x7FC00000]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "tosa.const"() {value = dense<[0.0, 0.745, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "tosa.pow"(%0, %1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + return %2 : tensor<3xf32> +} + +// CHECK-LABEL: @pow_fold_tensor_broadcast_exp +func.func @pow_fold_tensor_broadcast_exp() -> tensor<3xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01, 4.929690e+00, 9.609370e+00{{.*}}tensor<3xf16> + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[4.0, 2.22, -3.1]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "tosa.const"() {value = dense<2.0> : tensor<1xf16>} : () -> tensor<1xf16> + %2 = "tosa.pow"(%0, %1) : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> + return %2 : tensor<3xf16> +} + +// CHECK-LABEL: @pow_fold_tensor_broadcast_base +func.func @pow_fold_tensor_broadcast_base() -> tensor<3xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.6{{0*}}e+01, 4.660160e+00, 1.166380e-01{{.*}}tensor<3xf16> + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[4.0, 2.22, -3.1]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "tosa.const"() {value = dense<2.0> : tensor<1xf16>} : () -> tensor<1xf16> + %2 = "tosa.pow"(%1, %0) : (tensor<1xf16>, tensor<3xf16>) -> tensor<3xf16> + return %2 : tensor<3xf16> +} + +// CHECK-LABEL: @pow_fold_broadcast_two_dimensions +func.func @pow_fold_broadcast_two_dimensions() -> tensor<3x3xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-SAME{LITERAL}: [[388.023529, 1.102940e+03, 2554.37329], + // CHECK-SAME{LITERAL}: [75281.1328, 538664.813, 0x4A1FF040], + // CHECK-SAME{LITERAL}: [24.2514629, 42.4044418, 66.4508896]] + // CHECK-NOT: tosa.pow + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[[4.0, 5.1, 6.2]]> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %1 = "tosa.const"() {value = dense<[[4.3], [8.1], [2.3]]> : tensor<3x1xf32>} : () -> tensor<3x1xf32> + %2 = "tosa.pow"(%0, %1) : (tensor<1x3xf32>, tensor<3x1xf32>) -> tensor<3x3xf32> + return %2 : tensor<3x3xf32> +} From 2f882995d1735d4612e46b5aaee00b4124d3ca6e Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Tue, 28 Mar 2023 08:15:03 +0100 Subject: [PATCH 3/3] Clean-up for TOSA pow folding Address review comments: * Add names to all function arguments in header * `toBeBroadcasted` -> `toBeBroadcastedShape` * Describe return values for broadcasting functions * Add shortcut for the offset computation if no broadcasting is needed --- .../Dialect/Tosa/Transforms/TosaFoldCommon.h | 21 +++++++++++-------- .../Tosa/Transforms/TosaFoldCommon.cpp | 16 ++++++++------ 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index 3ec24ba91e7e6..e695b604a501b 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -36,22 +36,23 @@ DenseElementsAttr applyElementWise( /// tensors. If the input tensors do not match \p targetType, broadcasting is /// applied. DenseElementsAttr applyElementWise( - const DenseElementsAttr &, const DenseElementsAttr &, TensorType targetType, + const DenseElementsAttr &first, const DenseElementsAttr &second, + TensorType targetType, const std::function &toApply); /// Function that checks if \p toCheck is a dense TOSA constant float tensor. LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, - PatternRewriter &); + PatternRewriter &rewriter); /// Function that checks if \p toCheck is a dense TOSA constant tensor. LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, TosaOp location, - PatternRewriter &); + PatternRewriter &rewriter); /// Function that checks if the type contained in \p toCheck is float. LogicalResult notifyIfNotFloat(TypedValue toCheck, TosaOp location, - PatternRewriter &); + PatternRewriter &rewriter); /// Compute the offset in \p shape which corresponds to the given \p index. OffsetType indexToOffset(DimensionType shape, DimensionType index); @@ -60,18 +61,20 @@ OffsetType indexToOffset(DimensionType shape, DimensionType index); SmallVector offsetToIndex(DimensionType shape, OffsetType offset); /// Given an \p index into \p desiredShape, compute the corresponding index into -/// \p toBeBroadcasted. +/// \p toBeBroadcastedShape. +/// \returns broadcasted index into \p toBeBroadcastedShape. SmallVector getBroadcastedIndex(DimensionType desiredShape, - DimensionType toBeBroadcasted, + DimensionType toBeBroadcastedShape, DimensionType index); /// Given an \p offset into \p desiredShape, compute the corresponding offset -/// into \p toBeBroadcasted. +/// into \p toBeBroadcastedShape. +/// \returns broadcasted offset into \p toBeBroadcastedShape. OffsetType getBroadcastedOffset(DimensionType desiredShape, - DimensionType toBeBroadcasted, + DimensionType toBeBroadcastedShape, OffsetType offset); /// Function to compute the reciprocal. -APFloat computeReciprocal(const APFloat &, Type); +APFloat computeReciprocal(const APFloat &floatVal, Type floatTy); } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index 42c4023ad05ca..385f8aa9c4fcc 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -12,9 +12,9 @@ #include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include #include #include +#include #include #include #include @@ -155,13 +155,13 @@ SmallVector mlir::tosa::offsetToIndex(DimensionType shape, SmallVector mlir::tosa::getBroadcastedIndex(DimensionType desiredShape, - DimensionType toBeBroadcasted, + DimensionType toBeBroadcastedShape, DimensionType index) { SmallVector broadCasted; broadCasted.reserve(desiredShape.size()); for (size_t i = 0; i < desiredShape.size(); i++) { auto toInsert = 0; - if (toBeBroadcasted[i] == desiredShape[i]) { + if (toBeBroadcastedShape[i] == desiredShape[i]) { toInsert = index[i]; } broadCasted.push_back(toInsert); @@ -170,12 +170,16 @@ mlir::tosa::getBroadcastedIndex(DimensionType desiredShape, } OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape, - DimensionType toBeBroadcasted, + DimensionType toBeBroadcastedShape, OffsetType offset) { + // Simply return the offset if the shapes are equal. + if (desiredShape.equals(toBeBroadcastedShape)) { + return offset; + } auto indexInTarget = offsetToIndex(desiredShape, offset); auto indexBroadcasted = - getBroadcastedIndex(desiredShape, toBeBroadcasted, indexInTarget); - return indexToOffset(toBeBroadcasted, indexBroadcasted); + getBroadcastedIndex(desiredShape, toBeBroadcastedShape, indexInTarget); + return indexToOffset(toBeBroadcastedShape, indexBroadcasted); } APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {