From 7c052a7c42fc6cf0731d86ca91de187a1a4d1f8c Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Tue, 2 May 2023 08:43:29 +0100 Subject: [PATCH 1/3] [mlir][tosa] Constant folding for reciprocal Add constant fold for tosa.reciprocal, which can be applied if the input is a dense constant tensor. The reciprocal is computed for every element and the result is a tensor with the same dimensions as the input tensor. As the input tensor might require a lot of memory and the folding might double the required memory, a heuristic decides when to actually apply the folding. Currently, the operation will be replaced only if the input constant is a splat (i.e. requires little memory) or has in single user (similar to the already existing fold for constant transposes). This keeps the additionally required space low. --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 2 + .../Dialect/Tosa/Transforms/TosaFoldCommon.h | 60 ++++++++ .../Dialect/Tosa/Transforms/CMakeLists.txt | 2 + .../Tosa/Transforms/TosaFoldCommon.cpp | 113 +++++++++++++++ .../Transforms/TosaFoldConstantReciprocal.cpp | 79 ++++++++++ .../TosaLayerwiseConstantFoldPass.cpp | 1 + .../Tosa/constant-reciprocal-fold.mlir | 137 ++++++++++++++++++ 7 files changed, 394 insertions(+) create mode 100644 mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp create mode 100644 mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index d6ae78196f4cb..c81f59b3d5d36 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -30,6 +30,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h new file mode 100644 index 0000000000000..ecc168b7eaf86 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -0,0 +1,60 @@ +//===- TosaFoldCommon.h - Helper Functions for Folds ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H + +#include +#include +#include +#include + +namespace mlir { +namespace tosa { + +static constexpr llvm::RoundingMode tosaRoundingMode = + APFloat::rmNearestTiesToEven; + +/// Transform a tensor with the given transformation function. +template +DenseElementsAttr applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + TargetType targetType); + +/// Function that checks if \p toCheck is a dense TOSA constant float tensor. +LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter); + +/// Function that checks if \p toCheck is a dense TOSA constant tensor. +LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter); + +/// Function that checks if the type contained in \p toCheck is float. +LogicalResult notifyIfNotFloat(TypedValue toCheck, TosaOp location, + PatternRewriter &rewriter); + +/// Heuristic to decide when to replace a unary operation on a constant with the +/// folded value. +/// Folding operations on constants can lead to an increased memory usage +/// whenever the input cannot be replaced but a new constant is inserted. Hence, +/// this will currently only suggest folding when the memory impact is +/// negligible. +/// Takes the \p unaryOp and the constant input \p values. +/// \returns Whether folding should be applied. +bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values); + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 4f5a54de0c734..05c502949761d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -2,6 +2,8 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeTransposeConv.cpp TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp + TosaFoldCommon.cpp + TosaFoldConstantReciprocal.cpp TosaFoldConstantTranspose.cpp TosaInferShapes.cpp TosaLayerwiseConstantFoldPass.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp new file mode 100644 index 0000000000000..1adce501573a5 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -0,0 +1,113 @@ +//===- TosaFoldCommon.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +template +DenseElementsAttr mlir::tosa::applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + TargetType targetType) { + SmallVector transformedValues; + // We already know the amount of values we will insert, reserve space for + // all of them to avoid dynamic resizing + transformedValues.reserve(toTransform.getNumElements()); + for (auto val : toTransform.getValues()) { + auto transformedVal = toApply(val, targetType); + transformedValues.push_back(transformedVal); + } + + auto inShape = toTransform.getType(); + auto outTy = inShape.cloneWith({}, targetType); + + // Create a new tensor containing the computed values + return DenseElementsAttr::get(outTy, transformedValues); +} + +template DenseElementsAttr +mlir::tosa::applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + FloatType targetType); + +LogicalResult +mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter) { + auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter); + if (failed(floatCheck)) { + return floatCheck; + } + return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter); +} + +LogicalResult +mlir::tosa::notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter) { + // Check whether the tensor is constant and dense + // TODO We currently ensure the tensor is dense by using the correct type for + // the bind_value, however we do not actually need this value. It would be + // nicer to only have a check here. + DenseElementsAttr tmp; + if (!matchPattern(toCheck, m_Constant(&tmp))) { + return rewriter.notifyMatchFailure(location, + "Non-const or non-dense input tensor"); + } + + // Make sure it actually is a TOSA constant (the match allows for other + // constants as well) + if (isa(toCheck.getDefiningOp())) { + return success(); + } + + return rewriter.notifyMatchFailure(location, + "The reciprocal can only be folded if " + "it operates on a TOSA constant"); +} + +LogicalResult mlir::tosa::notifyIfNotFloat(TypedValue toCheck, + TosaOp location, + PatternRewriter &rewriter) { + if (isa(toCheck.getType().getElementType())) { + return success(); + } + return rewriter.notifyMatchFailure(location, + "Unexpected input tensor type: the " + "TOSA spec only allows floats"); +} + +bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp, + DenseElementsAttr values) { + assert(unaryOp->getNumOperands() == 1); + auto inputOp = unaryOp->getOperand(0); + + // If the input is a splat, we don't care for the number of users + if (isa(values)) { + return true; + } + + // If this is the only use of the tensor it should be replaced as no + // additional memory is required + return inputOp.hasOneUse(); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp new file mode 100644 index 0000000000000..327213a186f55 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp @@ -0,0 +1,79 @@ +//===- TosaFoldConstantReciprocal.cpp -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA reciprocal operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/FloatingPointMode.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantReciprocal : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + static APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy) { + auto recipAttr = FloatAttr::get(floatTy, 1.0); + APFloat recip = recipAttr.getValue(); + recip.divide(floatVal, tosaRoundingMode); + + return recip; + } + + LogicalResult matchAndRewrite(ReciprocalOp recip, + PatternRewriter &rewriter) const override { + auto inputTensor = recip.getInput1(); + + // Check that we can apply folding + auto preCondCheck = + notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter); + if (failed(preCondCheck)) { + return preCondCheck; + } + + // Extract the tensor values + DenseElementsAttr inputValues; + matchPattern(inputTensor, m_Constant(&inputValues)); + + // Check whether this should be folded. + if (!constantUnaryOpShouldBeFolded(recip, inputValues)) { + return rewriter.notifyMatchFailure( + recip, "Currently, reciprocals will only be folded if the input " + "tensor has a single user"); + } + + // Create a new tensor with the updated values + auto newTensor = applyElementWise( + inputValues, &computeReciprocal, + cast(inputValues.getElementType())); + + // Replace the use of the reciprocal with the transformed tensor + rewriter.replaceOpWithNewOp(recip, newTensor.getType(), newTensor); + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantReciprocalPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index a217f66cd84c6..2e2d338abbe4b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass RewritePatternSet patterns(ctx); auto func = getOperation(); + mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); populateTosaOpsCanonicalizationPatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir new file mode 100644 index 0000000000000..cc71c43d53ce2 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +// CHECK-LABEL: @reciprocal_fold_single_valued +func.func @reciprocal_fold_single_valued() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.0> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_fold_splat +func.func @reciprocal_fold_splat() -> tensor<12x7xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.5{{0*}}e-01{{.*}}tensor<12x7xf32> + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<4.0> : tensor<12x7xf32>} : () -> tensor<12x7xf32> + %1 = "tosa.reciprocal"(%0) : (tensor<12x7xf32>) -> tensor<12x7xf32> + return %1 : tensor<12x7xf32> +} + +// CHECK-LABEL: @reciprocal_div_zero +func.func @reciprocal_div_zero() -> tensor { + // 0x7F800000 is the value for +infinity + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_neg_zero +func.func @reciprocal_div_neg_zero() -> tensor { + // 0xFF800000 is the value for -infinity + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0xFF800000 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<-0.0> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_nan +func.func @reciprocal_div_nan() -> tensor { + // 0x7FC00000 is the value for NAN + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7FC00000 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0x7FC00000> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_infinity +func.func @reciprocal_div_infinity() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00> + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0x7F800000> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_neg_infinity +func.func @reciprocal_div_neg_infinity() -> tensor { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<-0.{{0*}}e+00> + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<0xFF800000> : tensor} : () -> tensor + %1 = "tosa.reciprocal"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @reciprocal_div_underflow +func.func @reciprocal_div_underflow() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-0.{{0*}}e+00, 0.{{0*}}e+00 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[-6.0e+15, 6.0e+15]> : tensor<2xf16>} : () -> tensor<2xf16> + %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +} + +// CHECK-LABEL: @reciprocal_div_overflow +func.func @reciprocal_div_overflow() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00 + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[0.0000001, -0.0000001]> : tensor<2xf16>} : () -> tensor<2xf16> + %1 = "tosa.reciprocal"(%0) : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +} + +// CHECK-LABEL: @reciprocal_no_fold +// The folding optimization works only intra-procedurally, so we won't be able +// to fold anything here +func.func @reciprocal_no_fold(%arg0: tensor) -> tensor { + // CHECK: tosa.reciprocal + // CHECK-NEXT: return + %0 = "tosa.reciprocal"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @reciprocal_fold +func.func @reciprocal_fold() -> tensor<4x6xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-SAME{LITERAL}: [[5.68828249, 11.4416485, 1.6880486, 0.680272102, -0.875350117, 0.342313349], + // CHECK-SAME{LITERAL}: [-4.81231928, 0.698080301, 0.65432179, -82.6446304, -4.33651352, -0.747551739], + // CHECK-SAME{LITERAL}: [-12.4378109, 13.140605, 1.89501607, 0.885582745, 4.08830738, 1.4396776], + // CHECK-SAME{LITERAL}: [2.02880907, -1.53280187, 0.552730501, 7.15819644, 0.64495325, -0.973709881]] + // CHECK-NOT: tosa.reciprocal + // CHECK: return [[RES]] + %0 = "tosa.const"() { value = dense<[ + [ 0.1758, 0.0874, 0.5924, 1.4700, -1.1424, 2.9213], + [-0.2078, 1.4325, 1.5283, -0.0121, -0.2306, -1.3377], + [-0.0804, 0.0761, 0.5277, 1.1292, 0.2446, 0.6946], + [ 0.4929, -0.6524, 1.8092, 0.1397, 1.5505, -1.0270]]> + : tensor<4x6xf32> + } : () -> tensor<4x6xf32> + %1 = "tosa.reciprocal"(%0) : (tensor<4x6xf32>) -> tensor<4x6xf32> + return %1 : tensor<4x6xf32> +} + +// CHECK-LABEL: @reciprocal_of_const_sparse +// Sparse tensors are currently not supported +func.func @reciprocal_of_const_sparse() -> tensor<32xbf16> { + // CHECK: tosa.const + // CHECK: tosa.reciprocal + %0 = "tosa.const"() { value = sparse< + [[0], [3], [11], [17], [20], [23], [25], [30], [31]], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]> + : tensor<32xbf16> } : () -> tensor<32xbf16> + %1 = "tosa.reciprocal"(%0) : (tensor<32xbf16>) -> tensor<32xbf16> + return %1 : tensor<32xbf16> +} From 5090c363c36347f437abfb2a4ceaff58de64dce6 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 3 May 2023 08:17:19 +0100 Subject: [PATCH 2/3] Drop not very useful comment --- mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index 1adce501573a5..0c6050921a17d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -40,7 +40,6 @@ DenseElementsAttr mlir::tosa::applyElementWise( auto inShape = toTransform.getType(); auto outTy = inShape.cloneWith({}, targetType); - // Create a new tensor containing the computed values return DenseElementsAttr::get(outTy, transformedValues); } From d3b06cd44640450d859c26a36e2e9a68e638f45d Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Thu, 11 May 2023 14:09:31 +0100 Subject: [PATCH 3/3] Cleanup constant folding helpers Use two header files to distinguish more broadly useful functions and ones that are only used very locally. --- .../mlir/Dialect/Tosa/Utils/FoldUtils.h | 41 ++++++++++++++++ .../Tosa/Transforms/TosaFoldCommon.cpp | 33 +------------ .../Dialect/Tosa/Transforms/TosaFoldCommon.h | 11 ----- .../Transforms/TosaFoldConstantReciprocal.cpp | 3 +- mlir/lib/Dialect/Utils/CMakeLists.txt | 1 + mlir/lib/Dialect/Utils/FoldUtils.cpp | 48 +++++++++++++++++++ 6 files changed, 94 insertions(+), 43 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h rename mlir/{include/mlir => lib}/Dialect/Tosa/Transforms/TosaFoldCommon.h (83%) create mode 100644 mlir/lib/Dialect/Utils/FoldUtils.cpp diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h new file mode 100644 index 0000000000000..b77e6614c7b28 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/Utils/FoldUtils.h @@ -0,0 +1,41 @@ +//===- FoldUtils.h - Helper Functions for Folds -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H +#define MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H + +#include +#include + +namespace mlir { +namespace tosa { + +/// Rounding mode to be used on floating point operations that require rounding. +static constexpr llvm::RoundingMode tosaRoundingMode = + llvm::APFloat::rmNearestTiesToEven; + +/// Apply the given transformation \p toApply to every element of the tensor to +/// be transformed \p toTransform. +/// +/// Elements of \p toTransform are extracted as \p SrcValueType. +/// +/// \returns A tensor with the same size as \p toTransform, containing +/// \p TargetValueType values of type \p TargetType. +template +DenseElementsAttr applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + TargetType targetType); + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_UTILS_FOLD_UTILS_H diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index 0c6050921a17d..2864e411c3153 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -10,11 +10,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include -#include -#include +#include "TosaFoldCommon.h" +#include #include #include #include @@ -23,32 +20,6 @@ using namespace mlir; using namespace mlir::tosa; -template -DenseElementsAttr mlir::tosa::applyElementWise( - const DenseElementsAttr &toTransform, - const std::function &toApply, - TargetType targetType) { - SmallVector transformedValues; - // We already know the amount of values we will insert, reserve space for - // all of them to avoid dynamic resizing - transformedValues.reserve(toTransform.getNumElements()); - for (auto val : toTransform.getValues()) { - auto transformedVal = toApply(val, targetType); - transformedValues.push_back(transformedVal); - } - - auto inShape = toTransform.getType(); - auto outTy = inShape.cloneWith({}, targetType); - - return DenseElementsAttr::get(outTy, transformedValues); -} - -template DenseElementsAttr -mlir::tosa::applyElementWise( - const DenseElementsAttr &toTransform, - const std::function &toApply, - FloatType targetType); - LogicalResult mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h similarity index 83% rename from mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h rename to mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h index ecc168b7eaf86..912c566b1ec58 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -13,23 +13,12 @@ #define MLIR_DIALECT_TOSA_TRANSFORMS_TOSA_FOLD_COMMON_H #include -#include #include #include namespace mlir { namespace tosa { -static constexpr llvm::RoundingMode tosaRoundingMode = - APFloat::rmNearestTiesToEven; - -/// Transform a tensor with the given transformation function. -template -DenseElementsAttr applyElementWise( - const DenseElementsAttr &toTransform, - const std::function &toApply, - TargetType targetType); - /// Function that checks if \p toCheck is a dense TOSA constant float tensor. LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp index 327213a186f55..d77b4398eeb0f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp @@ -10,9 +10,10 @@ // //===----------------------------------------------------------------------===// +#include "TosaFoldCommon.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/Dialect/Tosa/Utils/FoldUtils.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt index 7d40caebe1e05..685788aeb0074 100644 --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRDialectUtils + FoldUtils.cpp IndexingUtils.cpp ReshapeOpsUtils.cpp StructuredOpsUtils.cpp diff --git a/mlir/lib/Dialect/Utils/FoldUtils.cpp b/mlir/lib/Dialect/Utils/FoldUtils.cpp new file mode 100644 index 0000000000000..fe5754cf498cc --- /dev/null +++ b/mlir/lib/Dialect/Utils/FoldUtils.cpp @@ -0,0 +1,48 @@ +//===- FoldUtils.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper functions useful for various different TOSA constant folds. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/FoldUtils.h" + +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +template +DenseElementsAttr mlir::tosa::applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + TargetType targetType) { + SmallVector transformedValues; + // We already know the amount of values we will insert, reserve space for + // all of them to avoid dynamic resizing + transformedValues.reserve(toTransform.getNumElements()); + for (auto val : toTransform.getValues()) { + auto transformedVal = toApply(val, targetType); + transformedValues.push_back(transformedVal); + } + + // Make sure that the output tensor has the expected output type + auto inShape = toTransform.getType(); + auto outTy = inShape.cloneWith({}, targetType); + + return DenseElementsAttr::get(outTy, transformedValues); +} + +template DenseElementsAttr +mlir::tosa::applyElementWise( + const DenseElementsAttr &toTransform, + const std::function &toApply, + FloatType targetType);