From 6b9f7c78482de1a2731998674156811692f36962 Mon Sep 17 00:00:00 2001
From: Tina Jung
Date: Fri, 31 Mar 2023 15:23:01 +0100
Subject: [PATCH 1/7] Implement folding for constant TOSA casts

Fold casts on constant TOSA tensors.

* Generalize unary op element-wise function application
* Add test
---
 .../mlir/Dialect/Tosa/Transforms/Passes.h     |   2 +
 .../Dialect/Tosa/Transforms/TosaFoldCommon.h  |   9 +-
 .../Dialect/Tosa/Transforms/CMakeLists.txt    |   1 +
 .../Tosa/Transforms/TosaFoldCommon.cpp        |  50 +++-
 .../Tosa/Transforms/TosaFoldConstantCast.cpp  | 171 +++++++++++++
 .../Tosa/Transforms/TosaFoldConstantRSQRT.cpp |   7 +-
 .../Transforms/TosaFoldConstantReciprocal.cpp |   4 +-
 .../TosaLayerwiseConstantFoldPass.cpp         |   1 +
 mlir/test/Dialect/Tosa/constant-cast-opt.mlir | 242 ++++++++++++++++++
 9 files changed, 469 insertions(+), 18 deletions(-)
 create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
 create mode 100644 mlir/test/Dialect/Tosa/constant-cast-opt.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 895a1b00d2eef..3858836bb2b4b 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                     RewritePatternSet &patterns);
+void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
+                                          RewritePatternSet &patterns);
 void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns);
 void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
index e695b604a501b..8f61728af922a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
@@ -27,10 +27,15 @@ using DimensionType = ArrayRef<int64_t>;
 /// Type for tensor offsets.
 using OffsetType = size_t;
 
+static constexpr llvm::RoundingMode tosaRoundingMode =
+    APFloat::rmNearestTiesToEven;
+
 /// Transform a tensor with the given transformation function.
+template <class SrcValType, class TargetValType, class TargetType>
 DenseElementsAttr applyElementWise(
     const DenseElementsAttr &toTransform,
-    const std::function<APFloat(const APFloat &, Type)> &toApply);
+    const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+    TargetType targetType);
 
 /// Apply the given transformation function on the elements of the given
 /// tensors. If the input tensors do not match \p targetType, broadcasting is
@@ -74,7 +79,7 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape,
                                 OffsetType offset);
 
 /// Function to compute the reciprocal.
-APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
+APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index d75fdfd7c9d15..e15d5e9463e47 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
   TosaFoldCommon.cpp
+  TosaFoldConstantCast.cpp
   TosaFoldConstantPow.cpp
   TosaFoldConstantReciprocal.cpp
   TosaFoldConstantRSQRT.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
index 385f8aa9c4fcc..0e452a2463fcf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
@@ -23,29 +23,52 @@
 using namespace mlir;
 using namespace mlir::tosa;
 
-namespace {
-static constexpr llvm::RoundingMode reciprocalRoundingMode =
-    APFloat::rmNearestTiesToEven;
-} // namespace
-
+template <class SrcValType, class TargetValType, class TargetType>
 DenseElementsAttr mlir::tosa::applyElementWise(
     const DenseElementsAttr &toTransform,
-    const std::function<APFloat(const APFloat &, Type)> &toApply) {
-  llvm::SmallVector<APFloat> transformedValues;
+    const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+    TargetType targetType) {
+  SmallVector<TargetValType> transformedValues;
   // We already know the amount of values we will insert, reserve space for
   // all of them to avoid dynamic resizing
   transformedValues.reserve(toTransform.getNumElements());
-  for (auto val : toTransform.getValues<APFloat>()) {
-    auto transformedVal = toApply(val, toTransform.getElementType());
+  for (auto val : toTransform.getValues<SrcValType>()) {
+    auto transformedVal = toApply(val, targetType);
     transformedValues.push_back(transformedVal);
   }
 
+  auto inShape = toTransform.getType();
+  auto outTy = inShape.cloneWith(None, targetType);
+
   // Replace the current tensor with one containing the computed values
-  auto newTensor =
-      DenseElementsAttr::get(toTransform.getType(), transformedValues);
+  auto newTensor = DenseElementsAttr::get(outTy, transformedValues);
   return newTensor;
 }
 
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APFloat, APFloat, FloatType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APFloat(const APFloat &, FloatType)> &toApply,
+    FloatType targetType);
+
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APInt, APFloat, FloatType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APFloat(const APInt &, FloatType)> &toApply,
+    FloatType targetType);
+
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APFloat, APInt, IntegerType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APInt(const APFloat &, IntegerType)> &toApply,
+    IntegerType targetType);
+
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APInt, APInt, IntegerType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APInt(const APInt &, IntegerType)> &toApply,
+    IntegerType targetType);
+
 DenseElementsAttr mlir::tosa::applyElementWise(
     const DenseElementsAttr &first, const DenseElementsAttr &second,
     TensorType targetType,
@@ -182,10 +205,11 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
   return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
 }
 
-APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
+APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
+                                      FloatType floatTy) {
   auto recipAttr = FloatAttr::get(floatTy, 1.0);
   APFloat recip = recipAttr.getValue();
-  recip.divide(floatVal, reciprocalRoundingMode);
+  recip.divide(floatVal, tosaRoundingMode);
   return recip;
 }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
new file mode 100644
index 0000000000000..5bb49c6a6f03b
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
@@ -0,0 +1,171 @@
+//===- TosaFoldConstantCast.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA cast operation on constant data
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  static APFloat convertIntToFloat(const APInt &toConvert,
+                                   FloatType targetType) {
+    APFloat res(targetType.getFloatSemantics());
+    res.convertFromAPInt(toConvert, true, tosaRoundingMode);
+    return res;
+  }
+
+  static APFloat convertFloatToFloat(const APFloat &toConvert,
+                                     FloatType targetType) {
+    APFloat res(toConvert);
+    bool didLosePrecision;
+    res.convert(targetType.getFloatSemantics(), tosaRoundingMode,
+                &didLosePrecision);
+    return res;
+  }
+
+  static APInt convertFloatToInt(const APFloat &toConvert,
+                                 IntegerType targetType) {
+    auto targetWidth = targetType.getIntOrFloatBitWidth();
+    // Converting NaN to an integer results in an unpredictable value. Pick 0.
+    if (toConvert.isNaN()) {
+      return APInt::getZero(targetWidth);
+    }
+
+    // Make sure to properly translate booleans
+    if (targetWidth == 1) {
+      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+    }
+
+    // Use the built-in functionality of APFloats to convert to integers.
+    // The result of this conversion should be an integer which might still be
+    // outside of the target integer range.
+    auto floatSize = APFloat::getSizeInBits(toConvert.getSemantics());
+    APSInt converted(std::max(floatSize, targetWidth), false);
+    bool ignored = false;
+    toConvert.convertToInteger(converted, APFloat::rmNearestTiesToEven,
+                               &ignored);
+    // Clip to allowed range.
+    if (targetWidth < floatSize) {
+      return converted.truncSSat(targetWidth);
+    }
+    return converted;
+  }
+
+  static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
+    // Make sure to properly translate booleans
+    if (targetType.getWidth() == 1) {
+      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+    }
+    return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth());
+  }
+
+  static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
+                                    PatternRewriter &rewriter) {
+    // This is only relevant if the input values are float
+    if (!isa<FloatType>(elements.getElementType())) {
+      return;
+    }
+    // Check if it is a float to integer conversion
+    auto resultType = location.getOutput().getType();
+    if (!isa<IntegerType>(cast<TensorType>(resultType).getElementType())) {
+      return;
+    }
+
+    for (auto val : elements.getValues<APFloat>()) {
+      // Report encountered NaNs once
+      if (val.isNaN()) {
+        location->emitWarning(
+            "Float tensor is casted to integer and it contains NaN values. The "
+            "cast results in an unspecified value.");
+        return;
+      }
+    }
+  }
+
+  LogicalResult matchAndRewrite(CastOp tosaCast,
+                                PatternRewriter &rewriter) const override {
+    auto inputTensor = tosaCast.getInput();
+
+    // If the input tensor is not constant, we cannot fold it.
+    auto isDenseConst =
+        notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast, rewriter);
+    if (failed(isDenseConst)) {
+      return isDenseConst;
+    }
+
+    auto fromType = inputTensor.getType().getElementType();
+    auto toType = tosaCast.getOutput().getType().getElementType();
+
+    DenseElementsAttr elements;
+    matchPattern(inputTensor, m_Constant(&elements));
+
+    // Issue a warning if we convert float -> int and NaNs are present; the
+    // result value is unspecified in that case
+    warnAboutNaNToIntCast(elements, tosaCast, rewriter);
+
+    // Only fold splat tensors and those used only once to avoid duplicating
+    // them.
+    if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
+      return rewriter.notifyMatchFailure(tosaCast,
+                                         "Currently, casts will only be folded "
+                                         "if their input has a single user");
+    }
+
+    DenseElementsAttr res;
+    if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
+      if (isa<FloatType>(fromType)) {
+        res = applyElementWise<APFloat, APInt, IntegerType>(
+            elements, &convertFloatToInt, intOutTy);
+      } else {
+        assert(isa<IntegerType>(fromType));
+        res = applyElementWise<APInt, APInt, IntegerType>(
+            elements, &convertIntToInt, intOutTy);
+      }
+    } else {
+      assert(isa<FloatType>(toType));
+      auto floatOutTy = cast<FloatType>(toType);
+      if (isa<FloatType>(fromType)) {
+        res = applyElementWise<APFloat, APFloat, FloatType>(
+            elements, &convertFloatToFloat, floatOutTy);
+      } else {
+        assert(isa<IntegerType>(fromType));
+        res = applyElementWise<APInt, APFloat, FloatType>(
+            elements, &convertIntToFloat, floatOutTy);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<ConstOp>(tosaCast, res.getType(), res);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantCastPatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantCast>(ctx);
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
index 5913bd2c51006..bc18e25961e7b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
@@ -32,7 +32,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
 
   using OpRewritePattern::OpRewritePattern;
 
-  static APFloat computeRSQRT(const APFloat &apFloatVal, Type floatTy) {
+  static APFloat computeRSQRT(const APFloat &apFloatVal, FloatType floatTy) {
     // The result for negative values (apart from zero) is always NaN
     if (apFloatVal.isNegative() && !apFloatVal.isNegZero()) {
       return APFloat::getNaN(apFloatVal.getSemantics());
@@ -72,7 +72,9 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
     }
 
     // Create a new tensor with the updated values
-    auto newTensor = applyElementWise(inputValues, &computeRSQRT);
+    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
+        inputValues, &computeRSQRT,
+        cast<FloatType>(inputValues.getElementType()));
 
     // Replace the use of the reciprocal with the transformed tensor
     rewriter.replaceOpWithNewOp<ConstOp>(rsqrt, newTensor.getType(), newTensor);
@@ -84,6 +86,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
 } // namespace
 
 void mlir::tosa::populateTosaFoldConstantRSQRTPatterns(
+    MLIRContext *ctx,
     RewritePatternSet &patterns) {
   patterns.add<TosaFoldConstantRSQRT>(ctx);
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp
index 2f7335dc0477f..396c966c3259b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp
+++ 
b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp @@ -57,7 +57,9 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern { } // Create a new tensor with the updated values - auto newTensor = applyElementWise(inputValues, &computeReciprocal); + auto newTensor = applyElementWise( + inputValues, &computeReciprocal, + cast(inputValues.getElementType())); // Replace the use of the reciprocal with the transformed tensor rewriter.replaceOpWithNewOp(recip, newTensor.getType(), newTensor); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index 1288b1c7ade40..96d3ac483472a 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass RewritePatternSet patterns(ctx); auto func = getOperation(); + mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir new file mode 100644 index 0000000000000..7a262d37a8289 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir @@ -0,0 +1,242 @@ +// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s + +// ----- +// Casts from float to int + +// CHECK-LABEL: @cast_fold_f32_to_i1_all_none_zero +func.func @cast_fold_f32_to_i1_all_none_zero() -> tensor<3xi1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true{{.*}}tensor<3xi1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[12.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi1> + return %1 : tensor<3xi1> +} + +// CHECK-LABEL: @cast_fold_f32_to_i1 +func.func @cast_fold_f32_to_i1() -> tensor<3xi1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[12.0, 0.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi1> + return %1 : tensor<3xi1> +} + +// CHECK-LABEL: @cast_fold_f32_to_i32 +func.func @cast_fold_f32_to_i32() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 4, 5{{.*}}tensor<3xi32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[12.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi32> + return %1 : tensor<3xi32> +} + +// CHECK-LABEL: @cast_fold_f32_to_i16 +func.func @cast_fold_f32_to_i16() -> tensor<5xi16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, 5, 32767, -32768{{.*}}tensor<5xi16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xi16> + return %1 : tensor<5xi16> +} + +// CHECK-LABEL: @cast_fold_f16_to_i32 +func.func @cast_fold_f16_to_i32() -> tensor<6xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 14, 0, 5, 277, -278{{.*}}tensor<6xi32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = 
"tosa.const"() {value = + dense<[12.5, 13.5, 0.0, 5.0, 277.11, -277.71]> : + tensor<6xf16> + } : () -> tensor<6xf16> + %1 = "tosa.cast"(%0) : (tensor<6xf16>) -> tensor<6xi32> + return %1 : tensor<6xi32> +} + +// CHECK-LABEL: @cast_fold_f32_to_i8 +func.func @cast_fold_f32_to_i8() -> tensor<5xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, 5, 127, -128{{.*}}tensor<5xi8> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xi8> + return %1 : tensor<5xi8> +} + +// CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan +func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}32767, -32768, 0, 0{{.*}}tensor<5xi16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> : + tensor<5xf32> + } : () -> tensor<5xf32> + // expected-warning@below {{Float tensor is casted to integer and it contains NaN values.}} + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xi16> + return %1 : tensor<5xi16> +} + +// ----- +// Casts from int to int + +// CHECK-LABEL: @cast_fold_i16_to_i32 +func.func @cast_fold_i16_to_i32() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, -5{{.*}}tensor<3xi32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12, 0, -5]> : + tensor<3xi16> + } : () -> tensor<3xi16> + %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xi32> + return %1 : tensor<3xi32> +} + +// CHECK-LABEL: @cast_fold_i32_to_i8 +func.func @cast_fold_i32_to_i8() -> tensor<5xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, -5, -1, 1{{.*}}tensor<5xi8> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12, 0, -5, 511, -511]> : + tensor<5xi32> + } : () -> tensor<5xi32> + %1 = "tosa.cast"(%0) : (tensor<5xi32>) -> tensor<5xi8> + return %1 : tensor<5xi8> +} + + +// CHECK-LABEL: @cast_fold_i16_to_i1 +func.func @cast_fold_i16_to_i1() -> tensor<3xi1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12, 0, -5]> : + tensor<3xi16> + } : () -> tensor<3xi16> + %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xi1> + return %1 : tensor<3xi1> +} + +// ----- +// Casts from int to float + +// CHECK-LABEL: @cast_fold_i16_to_f32 +func.func @cast_fold_i16_to_f32() -> tensor<3xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, -5.{{0*}}e+00 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12, 0, -5]> : + tensor<3xi16> + } : () -> tensor<3xi16> + %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xf32> + return %1 : tensor<3xf32> +} + +// CHECK-LABEL: @cast_fold_i16_to_f16 +func.func @cast_fold_i16_to_f16() -> tensor<3xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, -5.{{0*}}e+00 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12, 0, -5]> : + tensor<3xi16> + } : () -> tensor<3xi16> + %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xf16> + return %1 : tensor<3xf16> +} + +// CHECK-LABEL: @cast_fold_i32_to_f16 +func.func @cast_fold_i32_to_f16() -> tensor<4xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, 
-5.{{0*}}e+00, 0x7C00 + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12, 0, -5, 2147483647]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.cast"(%0) : (tensor<4xi32>) -> tensor<4xf16> + return %1 : tensor<4xf16> +} + +// ----- +// Casts from float to float + +// CHECK-LABEL: @cast_fold_f32_to_f16 +func.func @cast_fold_f32_to_f16() -> tensor<5xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, 5.{{.*}}, 3.2{{.*}}+04, -3.2{{.*}}e+04{{.*}}tensor<5xf16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12.0, 0.0, 5.2, 32770.11, -32770.11]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xf16> + return %1 : tensor<5xf16> +} + +// CHECK-LABEL: @cast_fold_f32_to_f16_imprecise +func.func @cast_fold_f32_to_f16_imprecise() -> tensor<5xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9.56{{.*}}e-02, 0x7C00, 0xFC00, 0.{{0*}}e+00, -0.{{0*}}e+00{{.*}}tensor<5xf16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0.0956256023875237592352073, + 346534769.23495863245, -346534769.23495863245, + 0.000000000003, -0.000000000000001]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xf16> + return %1 : tensor<5xf16> +} + +// CHECK-LABEL: @cast_fold_f32_to_f16_infinity_zero_nan +func.func @cast_fold_f32_to_f16_infinity_zero_nan() -> tensor<5xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7E00{{.*}}tensor<5xf16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xf16> + return %1 : tensor<5xf16> +} + +// CHECK-LABEL: @cast_fold_f16_to_f32 +func.func @cast_fold_f16_to_f32() -> tensor<5xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, 5.{{.*}}, 3.2{{.*}}+04, -3.2{{.*}}e+04{{.*}}tensor<5xf32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[12.0, 0.0, 5.2, 32770.11, -32770.11]> : + tensor<5xf16> + } : () -> tensor<5xf16> + %1 = "tosa.cast"(%0) : (tensor<5xf16>) -> tensor<5xf32> + return %1 : tensor<5xf32> +} + +// CHECK-LABEL: @cast_fold_f16_to_f32 +func.func @cast_fold_f16_to_f32_infinity_zero_nan() -> tensor<5xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7C00, 0xFC00, 0.0, -0.0, 0x7E00]> : + tensor<5xf16> + } : () -> tensor<5xf16> + %1 = "tosa.cast"(%0) : (tensor<5xf16>) -> tensor<5xf32> + return %1 : tensor<5xf32> +} From 1bcd322f338583b231c6498edddc82de67490139 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 3 Apr 2023 09:09:42 +0100 Subject: [PATCH 2/7] Replace for/if with if(any_of) --- .../Tosa/Transforms/TosaFoldConstantCast.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp index 5bb49c6a6f03b..91ea4799caad8 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp @@ -95,14 +95,12 @@ struct TosaFoldConstantCast : public 
OpRewritePattern<CastOp> {
       return;
     }
 
-    for (auto val : elements.getValues<APFloat>()) {
-      // Report encountered NaNs once
-      if (val.isNaN()) {
-        location->emitWarning(
-            "Float tensor is casted to integer and it contains NaN values. The "
-            "cast results in an unspecified value.");
-        return;
-      }
+    // Report encountered NaNs
+    auto checkNan = [](const APFloat &val) { return val.isNaN(); };
+    if (any_of(elements.getValues<APFloat>(), checkNan)) {
+      location->emitWarning(
+          "Float tensor is casted to integer and it contains NaN values. The "
+          "cast results in an unspecified value.");
     }
   }
 

From 0222e16af672721482f05f11cdbdf9d5115e9639 Mon Sep 17 00:00:00 2001
From: Tina Jung
Date: Mon, 3 Apr 2023 12:34:18 +0100
Subject: [PATCH 3/7] Add information on why the conversion from NaN to int
 does not expect any specific value

---
 mlir/test/Dialect/Tosa/constant-cast-opt.mlir | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir
index 7a262d37a8289..42c1d5bf6ff6a 100644
--- a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir
+++ b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir
@@ -74,7 +74,9 @@ func.func @cast_fold_f32_to_i8() -> tensor<5xi8> {
 
 // CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan
 func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> {
-  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}32767, -32768, 0, 0{{.*}}tensor<5xi16>
+  // Check if infinity and zero are translated properly. Don't expect any
+  // specific value for NaN, as the casted int value for NaN is unspecified.
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}32767, -32768, 0, 0, {{.*}}tensor<5xi16>
   // CHECK-NOT: tosa.cast
   // CHECK: return [[RES]]
   %0 = "tosa.const"() {value =

From 4659d22ac9393c57d902d00c4b7091b9384ee293 Mon Sep 17 00:00:00 2001
From: Tina Jung
Date: Mon, 3 Apr 2023 13:44:35 +0100
Subject: [PATCH 4/7] Add more bails for unsupported cast folds

* Mark anything but int<->float casts as match failure (the spec only
  allows for those)
* Bail on unsigned casts (whenever the information is available and
  requested, use it anyway), add comment on what is required to fully
  support unsigned casts
---
 .../Tosa/Transforms/TosaFoldConstantCast.cpp  | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
index 91ea4799caad8..7359642b14d12 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
@@ -34,7 +34,7 @@ struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
   static APFloat convertIntToFloat(const APInt &toConvert,
                                    FloatType targetType) {
     APFloat res(targetType.getFloatSemantics());
-    res.convertFromAPInt(toConvert, true, tosaRoundingMode);
+    res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode);
     return res;
   }
 
@@ -64,12 +64,15 @@ struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
     // The result of this conversion should be an integer which might still be
     // outside of the target integer range.
     auto floatSize = APFloat::getSizeInBits(toConvert.getSemantics());
-    APSInt converted(std::max(floatSize, targetWidth), false);
+    APSInt converted(std::max(floatSize, targetWidth), targetType.isUnsigned());
     bool ignored = false;
     toConvert.convertToInteger(converted, APFloat::rmNearestTiesToEven,
                                &ignored);
     // Clip to allowed range.
     if (targetWidth < floatSize) {
+      if (targetType.isUnsigned()) {
+        return converted.truncUSat(targetWidth);
+      }
       return converted.truncSSat(targetWidth);
     }
     return converted;
@@ -80,6 +83,9 @@ struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
     if (targetType.getWidth() == 1) {
       return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
     }
+    if (targetType.isUnsigned()) {
+      return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth());
+    }
     return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth());
   }
 
@@ -133,6 +139,26 @@ struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
                                          "if their input has a single user");
     }
 
+    // Report a match failure for unexpected types
+    if (!toType.isIntOrFloat() || !fromType.isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(
+          tosaCast, "Only casts from/to int/float are supported.");
+    }
+
+    auto isUnsigned = [](Type toCheck) {
+      return isa<IntegerType>(toCheck) &&
+             cast<IntegerType>(toCheck).isUnsigned();
+    };
+    auto typesToCheck = {toType, fromType};
+    if (llvm::any_of(typesToCheck, isUnsigned)) {
+      // TOSA casts currently don't support unsigned integers.
+      // To support them here, one could use APSInt instead of APInts,
+      // however, this causes trouble with `getValues` which does not support
+      // APSInts currently.
+      return rewriter.notifyMatchFailure(
+          tosaCast, "Cast folding from/to unsigned integers is not supported.");
+    }
+
     DenseElementsAttr res;
     if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
       if (isa<FloatType>(fromType)) {
         res = applyElementWise<APFloat, APInt, IntegerType>(

From c27805a75b4ba6db917bb55e967d7cfc4a35d56b Mon Sep 17 00:00:00 2001
From: Tina Jung
Date: Mon, 3 Apr 2023 14:07:40 +0100
Subject: [PATCH 5/7] Cleanup

---
 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
index 7359642b14d12..0bfb4928a25e9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
@@ -115,10 +115,9 @@ struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
     auto inputTensor = tosaCast.getInput();
 
     // If the input tensor is not constant, we cannot fold it.
- auto isDenseConst = - notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast, rewriter); - if (failed(isDenseConst)) { - return isDenseConst; + if (failed(notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast, + rewriter))) { + return failure(); } auto fromType = inputTensor.getType().getElementType(); From 77cb13359004a48b23ea9fcbd17989900a42e935 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 5 Apr 2023 13:30:25 +0100 Subject: [PATCH 6/7] Remove unused header --- mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp index 0bfb4928a25e9..94a537f7d169d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include From 12b5f8d10d0ba77764676488bebc3a948dcf8d44 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 5 Apr 2023 13:31:07 +0100 Subject: [PATCH 7/7] Add command line flag to disable folding of integer casts --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 3 ++- .../mlir/Dialect/Tosa/Transforms/Passes.td | 8 +++++-- .../Tosa/Transforms/TosaFoldConstantCast.cpp | 24 +++++++++++++++++-- .../TosaLayerwiseConstantFoldPass.cpp | 3 ++- ...constant-cast-opt-disable-int-folding.mlir | 24 +++++++++++++++++++ 5 files changed, 56 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 3858836bb2b4b..1dc8bb9d62b04 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -30,7 +30,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx, void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantCastPatterns(MLIRContext *ctx, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + bool enableIntCastFolding); void populateTosaFoldConstantPowPatterns(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 46bd7a4780e00..2cbddd1e666eb 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -1,4 +1,4 @@ -//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===// +//===-- Passes.td - TOSA pass declarations -----------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -22,6 +22,10 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
   }];
 
   let constructor = "createTosaLayerwiseConstantFoldPass()";
+  let options = [
+    Option<"enableIntCastFolding", "enable-cast-folding-int-input", "bool",
+           "true", "Enable folding for casts from integer types">
+  ];
 }
 
 def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
@@ -56,7 +60,7 @@ def TosaOptionalDecompositions
   : Pass<"tosa-optional-decompositions", "func::FuncOp"> {
   let summary = "Applies Tosa operations optional decompositions";
   let description = [{
-    Pass to apply the Tosa operations decompositions 
+    Pass to apply the Tosa operations decompositions
     exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
   }];
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
index 94a537f7d169d..e5012293727b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 #include
 
@@ -185,9 +186,28 @@ struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
   }
 };
 
+struct TosaFoldConstantFloatCasts : TosaFoldConstantCast {
+
+  TosaFoldConstantFloatCasts(MLIRContext *ctx) : TosaFoldConstantCast(ctx) {}
+
+  LogicalResult matchAndRewrite(CastOp tosaCast,
+                                PatternRewriter &rewriter) const override {
+    if (isa<IntegerType>(tosaCast.getInput().getType().getElementType())) {
+      return rewriter.notifyMatchFailure(
+          tosaCast, "Folding casts from int is currently disabled.");
+    }
+
+    return TosaFoldConstantCast::matchAndRewrite(tosaCast, rewriter);
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaFoldConstantCastPatterns(
-    MLIRContext *ctx, RewritePatternSet &patterns) {
-  patterns.add<TosaFoldConstantCast>(ctx);
+    MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding) {
+  if (enableIntCastFolding) {
+    patterns.add<TosaFoldConstantCast>(ctx);
+  } else {
+    patterns.add<TosaFoldConstantFloatCasts>(ctx);
+  }
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 96d3ac483472a..6d0a3c89ed079 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -50,7 +50,8 @@ struct TosaLayerwiseConstantFoldPass
     RewritePatternSet patterns(ctx);
     auto func = getOperation();
 
-    mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns);
+    mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
+                                                     enableIntCastFolding);
     mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir
new file mode 100644
index 0000000000000..959d5c5bae034
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --split-input-file -verify-diagnostics -tosa-layerwise-constant-fold="enable-cast-folding-int-input=false" %s | FileCheck %s
+
+// CHECK-LABEL: @cast_fold_f32_to_i32
+func.func @cast_fold_f32_to_i32() -> tensor<3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 4, 5{{.*}}tensor<3xi32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[12.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32>
+  %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi32>
+  return %1 : tensor<3xi32>
+}
+
+// CHECK-LABEL: @cast_fold_i16_to_f32
+func.func @cast_fold_i16_to_f32() -> tensor<3xf32> {
+  // CHECK: tosa.const
+  // CHECK: [[RES:]] ={{.*}}tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                      dense<[12, 0, -5]> :
+                      tensor<3xi16>
+                      } : () -> tensor<3xi16>
+  %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xf32>
+  return %1 : tensor<3xf32>
+}
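
Note (not part of the patches above): for downstream consumers, a minimal sketch of how the new populate function can be wired into a custom pass, assuming the final signature from patch 7 (ctx, patterns, enableIntCastFolding). `ExampleCastFoldPass` is a hypothetical name introduced only for illustration; `PassWrapper` and `applyPatternsAndFoldGreedily` are the standard MLIR APIs, not something added by this series.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical standalone pass that only folds constant casts, with
// integer-input folding disabled (same effect as the new pass option
// enable-cast-folding-int-input=false).
struct ExampleCastFoldPass
    : mlir::PassWrapper<ExampleCastFoldPass,
                        mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    // Register only the float-input cast folding patterns from this series.
    mlir::tosa::populateTosaFoldConstantCastPatterns(
        ctx, patterns, /*enableIntCastFolding=*/false);
    // Apply the patterns greedily to the function body.
    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```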