From 6497e4aeacf8570107b62fa14c90ec35254e719a Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Thu, 20 Apr 2023 13:11:07 +0100 Subject: [PATCH 1/3] Implement constant clamp folding * Introduce global heuristic when to fold unary operators * Add folding constant clamps + test case --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 2 + .../Dialect/Tosa/Transforms/TosaFoldCommon.h | 10 + .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../Tosa/Transforms/TosaFoldCommon.cpp | 15 ++ .../Tosa/Transforms/TosaFoldConstantClamp.cpp | 186 ++++++++++++++++++ .../Tosa/Transforms/TosaFoldConstantRSQRT.cpp | 5 +- .../Transforms/TosaFoldConstantReciprocal.cpp | 8 +- .../TosaLayerwiseConstantFoldPass.cpp | 1 + .../test/Dialect/Tosa/constant-clamp-opt.mlir | 66 +++++++ 9 files changed, 285 insertions(+), 9 deletions(-) create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp create mode 100644 mlir/test/Dialect/Tosa/constant-clamp-opt.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 0aa3e910da0fe..a6d6da13cd2c5 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -32,6 +32,8 @@ void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantAddPatterns(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantClampPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantCastPatterns(MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h index 67d856a95492e..e699272965a9c 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h @@ -93,6 +93,16 @@ bool constantBinaryOpShouldBeFolded(TosaOp 
binaryOp, DenseElementsAttr valuesFirst, DenseElementsAttr valuesSecond); +/// Heuristic to decide when to replace a unary operation on a constant with the +/// folded value. +/// Folding operations on constants can lead to an increased memory usage +/// whenever the input cannot be replaced but a new constant is inserted. Hence, +/// this will currently only suggest folding when the memory impact is +/// negligible. +/// Takes the \p unaryOp and the constant input \p values. +/// \returns Whether folding should be applied. +bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values); + /// Function to compute the reciprocal. APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index f02a19a0da1a2..1b6a3530d6a14 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaFoldCommon.cpp TosaFoldConstantAdd.cpp TosaFoldConstantCast.cpp + TosaFoldConstantClamp.cpp TosaFoldConstantPow.cpp TosaFoldConstantReciprocal.cpp TosaFoldConstantRSQRT.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp index ef392911e158d..e40ed044aaef4 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp @@ -243,6 +243,21 @@ bool mlir::tosa::constantBinaryOpShouldBeFolded( return firstOp == secondOp && numUsers == 2; } +bool mlir::tosa::constantUnaryOpShouldBeFolded(TosaOp unaryOp, + DenseElementsAttr values) { + assert(unaryOp->getNumOperands() == 1); + auto inputOp = unaryOp->getOperand(0); + + // If the input is a splat, we don't care for the number of users + if (isa(values)) { + return true; + } + + // If this is the only use of the tensors it will be replaced and no + // additional memory is 
required. + return inputOp.hasOneUse(); +} + APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, FloatType floatTy) { auto recipAttr = FloatAttr::get(floatTy, 1.0); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp new file mode 100644 index 0000000000000..6c7f3643849f1 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp @@ -0,0 +1,186 @@ +//===- TosaFoldConstantClamp.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Clamp operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantClamp : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + static void + changeSemanticsLossless(APFloat &floatVal, + const llvm::fltSemantics *floatSemantics) { + bool losesInfo; + floatVal.convert(*floatSemantics, tosaRoundingMode, &losesInfo); + assert(!losesInfo); + } + + DenseElementsAttr applyClamp(DenseElementsAttr inputValues, + const APInt &lowerBound, const APInt &upperBound, + TensorType resultType) const { + + // Determine the width for the APInt comparison + auto comparisonWidth = + std::max(inputValues.getElementType().getIntOrFloatBitWidth(), + lowerBound.getBitWidth()); + + auto resultingIntType = 
cast(resultType.getElementType()); + + // Ensure that the value is larger than the lower bound + auto clampLower = [&lowerBound, &comparisonWidth](const APInt &val, + IntegerType type) { + auto clampedLower = llvm::APIntOps::smax( + val.sext(comparisonWidth), lowerBound.sext(comparisonWidth)); + // Make sure the output value has the correct type + assert(type.getWidth() >= clampedLower.getSignificantBits()); + return clampedLower.trunc(type.getWidth()); + }; + auto newTensor = applyElementWise( + inputValues, clampLower, resultingIntType); + + // Next, make sure the upper bound is adhered to + auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val, + IntegerType type) { + auto clampedUpper = llvm::APIntOps::smin( + val.sext(comparisonWidth), upperBound.sext(comparisonWidth)); + assert(type.getWidth() >= clampedUpper.getSignificantBits()); + return clampedUpper.trunc(type.getWidth()); + }; + newTensor = applyElementWise( + newTensor, clampUpper, resultingIntType); + + return newTensor; + } + + DenseElementsAttr applyClamp(DenseElementsAttr inputValues, + APFloat lowerBound, APFloat upperBound, + TensorType resultType) const { + auto inputValType = cast(inputValues.getElementType()); + auto inputWidth = inputValType.getWidth(); + auto bWidth = APFloat::semanticsSizeInBits(lowerBound.getSemantics()); + auto *comparisonSem = inputWidth < bWidth + ? &lowerBound.getSemantics() + : &inputValType.getFloatSemantics(); + + changeSemanticsLossless(lowerBound, comparisonSem); + changeSemanticsLossless(upperBound, comparisonSem); + + auto resultingFloatType = cast(resultType.getElementType()); + + // Ensure that the value is larger than the lower bound + auto clampLower = [&lowerBound, &comparisonSem](APFloat val, + FloatType type) { + if (val.isNaN()) { + return APFloat::getNaN(type.getFloatSemantics()); + } + changeSemanticsLossless(val, comparisonSem); + auto clampedLower = val < lowerBound ? 
lowerBound : val; + changeSemanticsLossless(clampedLower, &type.getFloatSemantics()); + return clampedLower; + }; + auto newTensor = applyElementWise( + inputValues, clampLower, resultingFloatType); + + // Next, make sure the upper bound is adhered to + auto clampUpper = [&upperBound, &comparisonSem](APFloat val, + FloatType type) { + if (val.isNaN()) { + return APFloat::getNaN(type.getFloatSemantics()); + } + changeSemanticsLossless(val, comparisonSem); + auto clampedUpper = val < upperBound ? val : upperBound; + changeSemanticsLossless(clampedUpper, &type.getFloatSemantics()); + return clampedUpper; + }; + newTensor = applyElementWise( + newTensor, clampUpper, resultingFloatType); + + return newTensor; + } + + LogicalResult matchAndRewrite(ClampOp clampOp, + PatternRewriter &rewriter) const override { + auto valsToClamp = clampOp.getInput(); + auto inputElementType = valsToClamp.getType().getElementType(); + + // Check if the input is constant + if (failed(notifyIfNoTosaDenseConstantTensor(valsToClamp, clampOp, + rewriter))) { + return failure(); + } + + if (isa(inputElementType) && + cast(inputElementType).isUnsigned()) { + return rewriter.notifyMatchFailure( + clampOp, "Currently, unsigned integer clamps are unsupported."); + } + + // Extract the tensor values + DenseElementsAttr inputValues; + matchPattern(valsToClamp, m_Constant(&inputValues)); + + if (!constantUnaryOpShouldBeFolded(clampOp, inputValues)) { + return rewriter.notifyMatchFailure( + clampOp, + "Currently, clamps will only be folded if this requires only " + "little additional memory usage."); + } + + // Apply the clamp to all values of the int/float tensor + auto resultType = clampOp.getType(); + DenseElementsAttr newTensor; + if (isa(inputElementType)) { + auto lowerBoundVal = clampOp.getMinIntAttr().getValue(); + auto upperBoundVal = clampOp.getMaxIntAttr().getValue(); + assert(lowerBoundVal.getBitWidth() == upperBoundVal.getBitWidth()); + + newTensor = + applyClamp(inputValues, lowerBoundVal, 
upperBoundVal, resultType); + } else { + assert(isa(inputElementType)); + auto lowerBoundVal = clampOp.getMinFp(); + auto upperBoundVal = clampOp.getMaxFp(); + assert(APFloat::getSizeInBits(lowerBoundVal.getSemantics()) == + APFloat::getSizeInBits(upperBoundVal.getSemantics())); + + newTensor = + applyClamp(inputValues, lowerBoundVal, upperBoundVal, resultType); + } + + rewriter.replaceOpWithNewOp(clampOp, newTensor.getType(), + newTensor); + + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantClampPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp index bc18e25961e7b..6260e1560ded1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp @@ -63,9 +63,8 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern { DenseElementsAttr inputValues; matchPattern(inputTensor, m_Constant(&inputValues)); - // Only fold splat tensors and those used only once to avoid duplicating - // them. - if (!inputTensor.hasOneUse() && !isa(inputValues)) { + // Check whether this should be folded. 
+ if (!constantUnaryOpShouldBeFolded(rsqrt, inputValues)) { return rewriter.notifyMatchFailure( rsqrt, "Currently, reciprocals will only be folded if the input " "tensor has a single user"); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp index 396c966c3259b..1bcc4ce8aee4e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp @@ -45,12 +45,8 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern { DenseElementsAttr inputValues; matchPattern(inputTensor, m_Constant(&inputValues)); - // Our transformation replaces the input tensor with the transformed tensor. - // If the input has several users we need to keep the input. This can - // result in a significantly increased memory usage, such that we currently - // refrain from applying the transformation in that case. - // Allow this only for splat values, because the amount of data is small. - if (!inputTensor.hasOneUse() && !isa(inputValues)) { + // Check whether this should be folded. 
+ if (!constantUnaryOpShouldBeFolded(recip, inputValues)) { return rewriter.notifyMatchFailure( recip, "Currently, reciprocals will only be folded if the input " "tensor has a single user"); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index be71b1ec2b37f..1f230b6b7d1e5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -53,6 +53,7 @@ struct TosaLayerwiseConstantFoldPass mlir::tosa::populateTosaFoldConstantAddPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns, enableIntCastFolding); + mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir new file mode 100644 index 0000000000000..585370a767d3c --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s + +// Int clamp + +// CHECK-LABEL: @clamp_fold_integer +func.func @clamp_fold_integer() -> tensor<3xi16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2, 0, 1{{.*}}tensor<3xi16> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[-12, 0, 5]> : tensor<3xi16>} : () -> tensor<3xi16> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = -2 : i64} + : (tensor<3xi16>) -> tensor<3xi16> + return %1 : tensor<3xi16> +} + +// CHECK-LABEL: @clamp_fold_integer_equal_lower_upper +func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> { + // CHECK: [[RES:]] 
={{.*}}tosa.const{{.*}}<17>{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[2, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 17 : i64, min_fp = 0.0 : f32, min_int = 17 : i64} + : (tensor<3xi8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + +// Float clamp + +// CHECK-LABEL: @clamp_fold_float +func.func @clamp_fold_float() -> tensor<3xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}-2.{{0*}}e+00, {{[8-9]}}.{{[0-9]*}}e-01, 1.{{0*}}e+00{{.*}}tensor<3xf16> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[-12.4, 0.9, 5.2]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + : (tensor<3xf16>) -> tensor<3xf16> + return %1 : tensor<3xf16> +} + +// CHECK-LABEL: @clamp_fold_float_infty_nan +func.func @clamp_fold_float_infty_nan() -> tensor<5xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.{{0*}}e+00, -2.{{0*}}e+00, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + : (tensor<5xf32>) -> tensor<5xf32> + return %1 : tensor<5xf32> +} + +// CHECK-LABEL: @clamp_fold_float_infinity_upper +func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, -2.{{0*}}e+00, 9.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 9.0, -0.0, 0x7FC00000]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.clamp"(%0) {max_fp = 0x7F800000 : f32, 
max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + : (tensor<5xf32>) -> tensor<5xf32> + return %1 : tensor<5xf32> +} From 18c4a2ef58774db70c007ec6a516548bb988b7de Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 21 Apr 2023 12:58:35 +0100 Subject: [PATCH 2/3] Avoid multiple calls to applyElementWise * Merge lambdas that clamp to the upper and lower bound into a single one performing both * Add tests with clamp boundaries which cannot be represented in the type of the value to be clamped --- .../Tosa/Transforms/TosaFoldConstantClamp.cpp | 62 +++++++------------ .../test/Dialect/Tosa/constant-clamp-opt.mlir | 25 ++++++++ 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp index 6c7f3643849f1..4d20d51b4ac7c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp @@ -17,7 +17,6 @@ #include "mlir/Pass/Pass.h" #include #include -#include #include #include #include @@ -47,31 +46,24 @@ struct TosaFoldConstantClamp : public OpRewritePattern { auto comparisonWidth = std::max(inputValues.getElementType().getIntOrFloatBitWidth(), lowerBound.getBitWidth()); + // Sign-extend the upper and lower bound + auto extUpperBound = upperBound.sext(comparisonWidth); + auto extLowerBound = lowerBound.sext(comparisonWidth); + // Determine the result type auto resultingIntType = cast(resultType.getElementType()); - // Ensure that the value is larger than the lower bound - auto clampLower = [&lowerBound, &comparisonWidth](const APInt &val, - IntegerType type) { - auto clampedLower = llvm::APIntOps::smax( - val.sext(comparisonWidth), lowerBound.sext(comparisonWidth)); - // Make sure the output value has the correct type - assert(type.getWidth() >= clampedLower.getSignificantBits()); - return clampedLower.trunc(type.getWidth()); + // Lambda to perform the clamp + auto 
clampUpper = [&extLowerBound, &extUpperBound, + &comparisonWidth](const APInt &val, IntegerType type) { + auto clampedUpper = + llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound); + auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound); + assert(type.getWidth() >= fullyClamped.getSignificantBits()); + return fullyClamped.trunc(type.getWidth()); }; auto newTensor = applyElementWise( - inputValues, clampLower, resultingIntType); - - // Next, make sure the upper bound is adhered to - auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val, - IntegerType type) { - auto clampedUpper = llvm::APIntOps::smin( - val.sext(comparisonWidth), upperBound.sext(comparisonWidth)); - assert(type.getWidth() >= clampedUpper.getSignificantBits()); - return clampedUpper.trunc(type.getWidth()); - }; - newTensor = applyElementWise( - newTensor, clampUpper, resultingIntType); + inputValues, clampUpper, resultingIntType); return newTensor; } @@ -91,34 +83,22 @@ struct TosaFoldConstantClamp : public OpRewritePattern { auto resultingFloatType = cast(resultType.getElementType()); - // Ensure that the value is larger than the lower bound - auto clampLower = [&lowerBound, &comparisonSem](APFloat val, - FloatType type) { + // Ensure that the value is larger than the lower bound and smaller than the + // upper bound + auto clampLower = [&lowerBound, &upperBound, + &comparisonSem](APFloat val, FloatType type) { if (val.isNaN()) { return APFloat::getNaN(type.getFloatSemantics()); } changeSemanticsLossless(val, comparisonSem); - auto clampedLower = val < lowerBound ? lowerBound : val; - changeSemanticsLossless(clampedLower, &type.getFloatSemantics()); - return clampedLower; + auto clampedUpper = val < upperBound ? val : upperBound; + auto fullyClamped = clampedUpper < lowerBound ? 
lowerBound : clampedUpper; + changeSemanticsLossless(fullyClamped, &type.getFloatSemantics()); + return fullyClamped; }; auto newTensor = applyElementWise( inputValues, clampLower, resultingFloatType); - // Next, make sure the upper bound is adhered to - auto clampUpper = [&upperBound, &comparisonSem](APFloat val, - FloatType type) { - if (val.isNaN()) { - return APFloat::getNaN(type.getFloatSemantics()); - } - changeSemanticsLossless(val, comparisonSem); - auto clampedUpper = val < upperBound ? val : upperBound; - changeSemanticsLossless(clampedUpper, &type.getFloatSemantics()); - return clampedUpper; - }; - newTensor = applyElementWise( - newTensor, clampUpper, resultingFloatType); - return newTensor; } diff --git a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir index 585370a767d3c..276e87405e695 100644 --- a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir @@ -24,6 +24,17 @@ func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> { return %1 : tensor<3xi8> } +// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type +func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64} + : (tensor<3xi8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + // Float clamp // CHECK-LABEL: @clamp_fold_float @@ -64,3 +75,17 @@ func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> { : (tensor<5xf32>) -> tensor<5xf32> return %1 : tensor<5xf32> } + +// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type +func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> { + // CHECK: 
[[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01 + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[18.32, -0.98747]> : + tensor<2xf16> + } : () -> tensor<2xf16> + %1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64} + : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +} From 44349a7af131d152267ef613f0a1159a888d0197 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 21 Apr 2023 13:08:19 +0100 Subject: [PATCH 3/3] Refactor clamp function names --- .../Tosa/Transforms/TosaFoldConstantClamp.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp index 4d20d51b4ac7c..29f249ee75b1b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp @@ -54,8 +54,8 @@ struct TosaFoldConstantClamp : public OpRewritePattern { auto resultingIntType = cast(resultType.getElementType()); // Lambda to perform the clamp - auto clampUpper = [&extLowerBound, &extUpperBound, - &comparisonWidth](const APInt &val, IntegerType type) { + auto clampFun = [&extLowerBound, &extUpperBound, + &comparisonWidth](const APInt &val, IntegerType type) { auto clampedUpper = llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound); auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound); @@ -63,7 +63,7 @@ struct TosaFoldConstantClamp : public OpRewritePattern { return fullyClamped.trunc(type.getWidth()); }; auto newTensor = applyElementWise( - inputValues, clampUpper, resultingIntType); + inputValues, clampFun, resultingIntType); return newTensor; } @@ -85,8 +85,8 @@ struct TosaFoldConstantClamp : public OpRewritePattern { // Ensure that the value is larger than the lower bound and smaller than the // upper bound - auto clampLower = 
[&lowerBound, &upperBound, - &comparisonSem](APFloat val, FloatType type) { + auto clampFun = [&lowerBound, &upperBound, &comparisonSem](APFloat val, + FloatType type) { if (val.isNaN()) { return APFloat::getNaN(type.getFloatSemantics()); } @@ -97,7 +97,7 @@ struct TosaFoldConstantClamp : public OpRewritePattern { return fullyClamped; }; auto newTensor = applyElementWise( - inputValues, clampLower, resultingFloatType); + inputValues, clampFun, resultingFloatType); return newTensor; }