diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 0aa3e910da0fe..a6d6da13cd2c5 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -32,6 +32,8 @@ void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantAddPatterns(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantClampPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantCastPatterns(MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index 67d856a95492e..e699272965a9c 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -93,6 +93,16 @@ bool constantBinaryOpShouldBeFolded(TosaOp binaryOp, DenseElementsAttr valuesFirst, DenseElementsAttr valuesSecond); +/// Heuristic to decide when to replace a unary operation on a constant with the +/// folded value. +/// Folding operations on constants can lead to an increased memory usage +/// whenever the input cannot be replaced but a new constant is inserted. Hence, +/// this will currently only suggest folding when the memory impact is +/// negligible. +/// Takes the \p unaryOp and the constant input \p values. +/// \returns Whether folding should be applied. +bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values); + /// Function to compute the reciprocal. APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index f02a19a0da1a2..1b6a3530d6a14 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaFoldCommon.cpp TosaFoldConstantAdd.cpp TosaFoldConstantCast.cpp + TosaFoldConstantClamp.cpp TosaFoldConstantPow.cpp TosaFoldConstantReciprocal.cpp TosaFoldConstantRSQRT.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index ef392911e158d..e40ed044aaef4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -243,6 +243,21 @@ bool mlir::tosa::constantBinaryOpShouldBeFolded( return firstOp == secondOp && numUsers == 2; } +bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp, + DenseElementsAttr values) { + assert(unaryOp->getNumOperands() == 1); + auto inputOp = unaryOp->getOperand(0); + + // If the input is a splat, we don't care for the number of users + if (isa(values)) { + return true; + } + + // If this is the only use of the tensors it will be replaced an no + // additional memory is required. + return inputOp.hasOneUse(); +} + APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, FloatType floatTy) { auto recipAttr = FloatAttr::get(floatTy, 1.0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp new file mode 100644 index 0000000000000..29f249ee75b1b --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp @@ -0,0 +1,166 @@ +//===- TosaFoldConstantClamp.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Clamp operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantClamp : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + static void + changeSemanticsLossless(APFloat &floatVal, + const llvm::fltSemantics *floatSemantics) { + bool losesInfo; + floatVal.convert(*floatSemantics, tosaRoundingMode, &losesInfo); + assert(!losesInfo); + } + + DenseElementsAttr applyClamp(DenseElementsAttr inputValues, + const APInt &lowerBound, const APInt &upperBound, + TensorType resultType) const { + + // Determine the width for the APInt comparison + auto comparisonWidth = + std::max(inputValues.getElementType().getIntOrFloatBitWidth(), + lowerBound.getBitWidth()); + // Sign-extend the upper and lower bound + auto extUpperBound = upperBound.sext(comparisonWidth); + auto extLowerBound = lowerBound.sext(comparisonWidth); + + // Determine the result type + auto resultingIntType = cast(resultType.getElementType()); + + // Lambda to perform the clamp + auto clampFun = [&extLowerBound, &extUpperBound, + &comparisonWidth](const APInt &val, IntegerType type) { + auto clampedUpper = + llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound); + auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound); + assert(type.getWidth() >= fullyClamped.getSignificantBits()); + return fullyClamped.trunc(type.getWidth()); + }; + auto newTensor = applyElementWise( + inputValues, clampFun, resultingIntType); + + return newTensor; + } + + DenseElementsAttr applyClamp(DenseElementsAttr inputValues, + APFloat lowerBound, APFloat upperBound, + TensorType resultType) const { + auto inputValType = cast(inputValues.getElementType()); + auto inputWidth = inputValType.getWidth(); + auto bWidth = APFloat::semanticsSizeInBits(lowerBound.getSemantics()); + auto *comparisonSem = inputWidth < bWidth + ? &lowerBound.getSemantics() + : &inputValType.getFloatSemantics(); + + changeSemanticsLossless(lowerBound, comparisonSem); + changeSemanticsLossless(upperBound, comparisonSem); + + auto resultingFloatType = cast(resultType.getElementType()); + + // Ensure that the value is larger than the lower bound and smaller than the + // upper bound + auto clampFun = [&lowerBound, &upperBound, &comparisonSem](APFloat val, + FloatType type) { + if (val.isNaN()) { + return APFloat::getNaN(type.getFloatSemantics()); + } + changeSemanticsLossless(val, comparisonSem); + auto clampedUpper = val < upperBound ? val : upperBound; + auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper; + changeSemanticsLossless(fullyClamped, &type.getFloatSemantics()); + return fullyClamped; + }; + auto newTensor = applyElementWise( + inputValues, clampFun, resultingFloatType); + + return newTensor; + } + + LogicalResult matchAndRewrite(ClampOp clampOp, + PatternRewriter &rewriter) const override { + auto valsToClamp = clampOp.getInput(); + auto inputElementType = valsToClamp.getType().getElementType(); + + // Check if the input is constant + if (failed(notifyIfNoTosaDenseConstantTensor(valsToClamp, clampOp, + rewriter))) { + return failure(); + } + + if (isa(inputElementType) && + cast(inputElementType).isUnsigned()) { + return rewriter.notifyMatchFailure( + clampOp, "Currently, unsigned integer clamps are unsupported."); + } + + // Extract the tensor values + DenseElementsAttr inputValues; + matchPattern(valsToClamp, m_Constant(&inputValues)); + + if (!constantUnaryOpShouldBeFolded(clampOp, inputValues)) { + return rewriter.notifyMatchFailure( + clampOp, + "Currently, clamps will only be folded if this requires only " + "little additional memory usage."); + } + + // Apply the clamp to all values of the int/float tensor + auto resultType = clampOp.getType(); + DenseElementsAttr newTensor; + if (isa(inputElementType)) { + auto lowerBoundVal = clampOp.getMinIntAttr().getValue(); + auto upperBoundVal = clampOp.getMaxIntAttr().getValue(); + assert(lowerBoundVal.getBitWidth() == upperBoundVal.getBitWidth()); + + newTensor = + applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType); + } else { + assert(isa(inputElementType)); + auto lowerBoundVal = clampOp.getMinFp(); + auto upperBoundVal = clampOp.getMaxFp(); + assert(APFloat::getSizeInBits(lowerBoundVal.getSemantics()) == + APFloat::getSizeInBits(upperBoundVal.getSemantics())); + + newTensor = + applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType); + } + + rewriter.replaceOpWithNewOp(clampOp, newTensor.getType(), + newTensor); + + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantClampPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp index bc18e25961e7b..6260e1560ded1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp @@ -63,9 +63,8 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern { DenseElementsAttr inputValues; matchPattern(inputTensor, m_Constant(&inputValues)); - // Only fold splat tensors and those used only once to avoid duplicating - // them. - if (!inputTensor.hasOneUse() && !isa(inputValues)) { + // Check whether this should be folded. + if (!constantUnaryOpShouldBeFolded(rsqrt, inputValues)) { return rewriter.notifyMatchFailure( rsqrt, "Currently, reciprocals will only be folded if the input " "tensor has a single user"); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp index 396c966c3259b..1bcc4ce8aee4e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp @@ -45,12 +45,8 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern { DenseElementsAttr inputValues; matchPattern(inputTensor, m_Constant(&inputValues)); - // Our transformation replaces the input tensor with the transformed tensor. - // If the input has several users we need to keep the input. This can - // result in a significantly increased memory usage, such that we currently - // refrain from applying the transformation in that case. - // Allow this only for splat values, because the amount of data is small. - if (!inputTensor.hasOneUse() && !isa(inputValues)) { + // Check whether this should be folded. + if (!constantUnaryOpShouldBeFolded(recip, inputValues)) { return rewriter.notifyMatchFailure( recip, "Currently, reciprocals will only be folded if the input " "tensor has a single user"); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index be71b1ec2b37f..1f230b6b7d1e5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -53,6 +53,7 @@ struct TosaLayerwiseConstantFoldPass mlir::tosa::populateTosaFoldConstantAddPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns, enableIntCastFolding); + mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir new file mode 100644 index 0000000000000..276e87405e695 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir @@ -0,0 +1,91 @@ +// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s + +// Int clamp + +// CHECK-LABEL: @clamp_fold_integer +func.func @clamp_fold_integer() -> tensor<3xi16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2, 0, 1{{.*}}tensor<3xi16> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[-12, 0, 5]> : tensor<3xi16>} : () -> tensor<3xi16> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = -2 : i64} + : (tensor<3xi16>) -> tensor<3xi16> + return %1 : tensor<3xi16> +} + +// CHECK-LABEL: @clamp_fold_integer_equal_lower_upper +func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<17>{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[2, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 17 : i64, min_fp = 0.0 : f32, min_int = 17 : i64} + : (tensor<3xi8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + +// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type +func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64} + : (tensor<3xi8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + +// Float clamp + +// CHECK-LABEL: @clamp_fold_float +func.func @clamp_fold_float() -> tensor<3xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2.{{0*}}e+00, {{[8-9]}}.{{[0-9]*}}e-01, 1.{{0*}}e+00{{.*}}tensor<3xf16> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[-12.4, 0.9, 5.2]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + : (tensor<3xf16>) -> tensor<3xf16> + return %1 : tensor<3xf16> +} + +// CHECK-LABEL: @clamp_fold_float_infty_nan +func.func @clamp_fold_float_infty_nan() -> tensor<5xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.{{0*}}e+00, -2.{{0*}}e+00, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + : (tensor<5xf32>) -> tensor<5xf32> + return %1 : tensor<5xf32> +} + +// CHECK-LABEL: @clamp_fold_float_infinity_upper +func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, -2.{{0*}}e+00, 9.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 9.0, -0.0, 0x7FC00000]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.clamp"(%0) {max_fp = 0x7F800000 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + : (tensor<5xf32>) -> tensor<5xf32> + return %1 : tensor<5xf32> +} + +// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type +func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01 + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[18.32, -0.98747]> : + tensor<2xf16> + } : () -> tensor<2xf16> + %1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64} + : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +}