diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 895a1b00d2eef..1dc8bb9d62b04 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -29,6 +29,9 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
                                     RewritePatternSet &patterns);
+void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
+                                          RewritePatternSet &patterns,
+                                          bool enableIntCastFolding);
 void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns);
 void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 46bd7a4780e00..2cbddd1e666eb 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===//
+//===-- Passes.td - TOSA pass declarations -----------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -22,6 +22,10 @@ def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
   }];
 
   let constructor = "createTosaLayerwiseConstantFoldPass()";
+  let options = [
+    Option<"enableIntCastFolding", "enable-cast-folding-int-input", "bool",
+           "true", "Enable folding for casts from integer types">
+  ];
 }
 
 def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> {
@@ -56,7 +60,7 @@ def TosaOptionalDecompositions
     : Pass<"tosa-optional-decompositions", "func::FuncOp"> {
   let summary = "Applies Tosa operations optional decompositions";
   let description = [{
-    Pass to apply the Tosa operations decompositions 
+    Pass to apply the Tosa operations decompositions
     exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
index e695b604a501b..8f61728af922a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
@@ -27,10 +27,15 @@ using DimensionType = ArrayRef<int64_t>;
 /// Type for tensor offsets.
 using OffsetType = size_t;
 
+static constexpr llvm::RoundingMode tosaRoundingMode =
+    APFloat::rmNearestTiesToEven;
+
 /// Transform a tensor with the given transformation function.
+template <class SrcValType, class TargetValType, class TargetType>
 DenseElementsAttr applyElementWise(
     const DenseElementsAttr &toTransform,
-    const std::function<APFloat(const APFloat &, Type)> &toApply);
+    const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+    TargetType targetType);
 
 /// Apply the given transformation function on the elements of the given
 /// tensors. If the input tensors do not match \p targetType, broadcasting is
@@ -74,7 +79,7 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape,
                                 OffsetType offset);
 
 /// Function to compute the reciprocal.
-APFloat computeReciprocal(const APFloat &floatVal, Type floatTy);
+APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);
 
 } // namespace tosa
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index d75fdfd7c9d15..e15d5e9463e47 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
   TosaFoldCommon.cpp
+  TosaFoldConstantCast.cpp
   TosaFoldConstantPow.cpp
   TosaFoldConstantReciprocal.cpp
   TosaFoldConstantRSQRT.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
index 385f8aa9c4fcc..0e452a2463fcf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
@@ -23,29 +23,52 @@
 using namespace mlir;
 using namespace mlir::tosa;
 
-namespace {
-static constexpr llvm::RoundingMode reciprocalRoundingMode =
-    APFloat::rmNearestTiesToEven;
-} // namespace
-
+template <class SrcValType, class TargetValType, class TargetType>
 DenseElementsAttr mlir::tosa::applyElementWise(
     const DenseElementsAttr &toTransform,
-    const std::function<APFloat(const APFloat &, Type)> &toApply) {
-  llvm::SmallVector<APFloat> transformedValues;
+    const std::function<TargetValType(const SrcValType &, TargetType)> &toApply,
+    TargetType targetType) {
+  SmallVector<TargetValType> transformedValues;
   // We already know the amount of values we will insert, reserve space for
   // all of them to avoid dynamic resizing
   transformedValues.reserve(toTransform.getNumElements());
-  for (auto val : toTransform.getValues<APFloat>()) {
-    auto transformedVal = toApply(val, toTransform.getElementType());
+  for (auto val : toTransform.getValues<SrcValType>()) {
+    auto transformedVal = toApply(val, targetType);
     transformedValues.push_back(transformedVal);
   }
 
+  auto inShape = toTransform.getType();
+  auto outTy = inShape.cloneWith(None, targetType);
+
   // Replace the current tensor with one containing the computed values
-  auto newTensor =
-      DenseElementsAttr::get(toTransform.getType(), transformedValues);
+  auto newTensor = DenseElementsAttr::get(outTy, transformedValues);
   return newTensor;
 }
 
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APFloat, APFloat, FloatType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APFloat(const APFloat &, FloatType)> &toApply,
+    FloatType targetType);
+
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APInt, APFloat, FloatType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APFloat(const APInt &, FloatType)> &toApply,
+    FloatType targetType);
+
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APFloat, APInt, IntegerType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APInt(const APFloat &, IntegerType)> &toApply,
+    IntegerType targetType);
+
+template DenseElementsAttr
+mlir::tosa::applyElementWise<APInt, APInt, IntegerType>(
+    const DenseElementsAttr &toTransform,
+    const std::function<APInt(const APInt &, IntegerType)> &toApply,
+    IntegerType targetType);
+
 DenseElementsAttr mlir::tosa::applyElementWise(
     const DenseElementsAttr &first, const DenseElementsAttr &second,
     TensorType targetType,
@@ -182,10 +205,11 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
   return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
 }
 
-APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal, Type floatTy) {
+APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
+                                      FloatType floatTy) {
   auto recipAttr = FloatAttr::get(floatTy, 1.0);
   APFloat recip = recipAttr.getValue();
-  recip.divide(floatVal, reciprocalRoundingMode);
+  recip.divide(floatVal, tosaRoundingMode);
   return recip;
 }
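For orientation, a sketch of how a conversion callback pairs with the template parameters of applyElementWise. This is not part of the patch; the helper name intToF32 is hypothetical, and it mirrors the <APInt, APFloat, FloatType> instantiation above.

// Hypothetical converter (sketch only): widen signed integer constants to f32.
static APFloat intToF32(const APInt &in, FloatType outTy) {
  APFloat res(outTy.getFloatSemantics());
  // Round to nearest-even, the same mode as tosaRoundingMode.
  res.convertFromAPInt(in, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
  return res;
}

// Possible call site, e.g. inside a rewrite pattern:
//   DenseElementsAttr folded = applyElementWise<APInt, APFloat, FloatType>(
//       intElements, &intToF32, FloatType::getF32(ctx));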
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
new file mode 100644
index 0000000000000..e5012293727b5
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantCast.cpp
@@ -0,0 +1,213 @@
+//===- TosaFoldConstantCast.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA cast operation on constant data
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include <llvm/ADT/APFloat.h>
+#include <llvm/ADT/APInt.h>
+#include <llvm/ADT/APSInt.h>
+#include <llvm/ADT/STLExtras.h>
+#include <algorithm>
+#include <functional>
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+struct TosaFoldConstantCast : public OpRewritePattern<CastOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  static APFloat convertIntToFloat(const APInt &toConvert,
+                                   FloatType targetType) {
+    APFloat res(targetType.getFloatSemantics());
+    res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode);
+    return res;
+  }
+
+  static APFloat convertFloatToFloat(const APFloat &toConvert,
+                                     FloatType targetType) {
+    APFloat res(toConvert);
+    bool didLosePrecision;
+    res.convert(targetType.getFloatSemantics(), tosaRoundingMode,
+                &didLosePrecision);
+    return res;
+  }
+
+  static APInt convertFloatToInt(const APFloat &toConvert,
+                                 IntegerType targetType) {
+    auto targetWidth = targetType.getIntOrFloatBitWidth();
+    // Converting NaN to an integer results in an unpredictable value. Pick 0.
+    if (toConvert.isNaN()) {
+      return APInt::getZero(targetWidth);
+    }
+
+    // Make sure to properly translate booleans
+    if (targetWidth == 1) {
+      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+    }
+
+    // Use the built-in functionality of APFloats to convert to integers.
+    // The result of this conversion should be an integer which might still be
+    // outside of the target integer range.
+    auto floatSize = APFloat::getSizeInBits(toConvert.getSemantics());
+    APSInt converted(std::max(floatSize, targetWidth), targetType.isUnsigned());
+    bool ignored = false;
+    toConvert.convertToInteger(converted, APFloat::rmNearestTiesToEven,
+                               &ignored);
+    // Clip to allowed range.
+    if (targetWidth < floatSize) {
+      if (targetType.isUnsigned()) {
+        return converted.truncUSat(targetWidth);
+      }
+      return converted.truncSSat(targetWidth);
+    }
+    return converted;
+  }
+
+  static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
+    // Make sure to properly translate booleans
+    if (targetType.getWidth() == 1) {
+      return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
+    }
+    if (targetType.isUnsigned()) {
+      return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth());
+    }
+    return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth());
+  }
+
+  static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
+                                    PatternRewriter &rewriter) {
+    // This is only relevant if the input values are float
+    if (!isa<FloatType>(elements.getElementType())) {
+      return;
+    }
+    // Check if it is a float to integer conversion
+    auto resultType = location.getOutput().getType();
+    if (!isa<IntegerType>(cast<ShapedType>(resultType).getElementType())) {
+      return;
+    }
+
+    // Report encountered NaNs
+    auto checkNan = [](const APFloat &val) { return val.isNaN(); };
+    if (any_of(elements.getValues<APFloat>(), checkNan)) {
+      location->emitWarning(
+          "Float tensor is cast to integer and it contains NaN values. The "
+          "cast results in an unspecified value.");
+    }
+  }
+
+  LogicalResult matchAndRewrite(CastOp tosaCast,
+                                PatternRewriter &rewriter) const override {
+    auto inputTensor = tosaCast.getInput();
+
+    // If the input tensor is not constant, we cannot fold it.
+    if (failed(notifyIfNoTosaDenseConstantTensor(inputTensor, tosaCast,
+                                                 rewriter))) {
+      return failure();
+    }
+
+    auto fromType = inputTensor.getType().getElementType();
+    auto toType = tosaCast.getOutput().getType().getElementType();
+
+    DenseElementsAttr elements;
+    matchPattern(inputTensor, m_Constant(&elements));
+
+    // Issue a warning if we convert float -> int and NaNs are present; the
+    // result value is unspecified in that case
+    warnAboutNaNToIntCast(elements, tosaCast, rewriter);
+
+    // Only fold splat tensors and those used only once to avoid duplicating
+    // them.
+    if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
+      return rewriter.notifyMatchFailure(tosaCast,
+                                         "Currently, casts will only be folded "
+                                         "if their input has a single user");
+    }
+
+    // Report a match failure for unexpected types
+    if (!toType.isIntOrFloat() || !fromType.isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(
+          tosaCast, "Only casts from/to int/float are supported.");
+    }
+
+    auto isUnsigned = [](Type toCheck) {
+      return isa<IntegerType>(toCheck) &&
+             cast<IntegerType>(toCheck).isUnsigned();
+    };
+    auto typesToCheck = {toType, fromType};
+    if (llvm::any_of(typesToCheck, isUnsigned)) {
+      // TOSA casts currently don't support unsigned integers.
+      // To support them here, one could use APSInt instead of APInts,
+      // however, this causes trouble with `getValues` which does not support
+      // APSInts currently.
+      return rewriter.notifyMatchFailure(
+          tosaCast, "Cast folding from/to unsigned integers is not supported.");
+    }
+
+    DenseElementsAttr res;
+    if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
+      if (isa<FloatType>(fromType)) {
+        res = applyElementWise<APFloat, APInt, IntegerType>(
+            elements, &convertFloatToInt, intOutTy);
+      } else {
+        assert(isa<IntegerType>(fromType));
+        res = applyElementWise<APInt, APInt, IntegerType>(
+            elements, &convertIntToInt, intOutTy);
+      }
+    } else {
+      assert(isa<FloatType>(toType));
+      auto floatOutTy = cast<FloatType>(toType);
+      if (isa<FloatType>(fromType)) {
+        res = applyElementWise<APFloat, APFloat, FloatType>(
+            elements, &convertFloatToFloat, floatOutTy);
+      } else {
+        assert(isa<IntegerType>(fromType));
+        res = applyElementWise<APInt, APFloat, FloatType>(
+            elements, &convertIntToFloat, floatOutTy);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<ConstOp>(tosaCast, res.getType(), res);
+    return success();
+  }
+};
+
+struct TosaFoldConstantFloatCasts : TosaFoldConstantCast {
+
+  TosaFoldConstantFloatCasts(MLIRContext *ctx) : TosaFoldConstantCast(ctx) {}
+
+  LogicalResult matchAndRewrite(CastOp tosaCast,
+                                PatternRewriter &rewriter) const override {
+    if (isa<IntegerType>(tosaCast.getInput().getType().getElementType())) {
+      return rewriter.notifyMatchFailure(
+          tosaCast, "Folding casts from int is currently disabled.");
+    }
+
+    return TosaFoldConstantCast::matchAndRewrite(tosaCast, rewriter);
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantCastPatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding) {
+  if (enableIntCastFolding) {
+    patterns.add<TosaFoldConstantCast>(ctx);
+  } else {
+    patterns.add<TosaFoldConstantFloatCasts>(ctx);
+  }
+}
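The populate function above is the integration point for pipelines. A minimal sketch, not part of the patch and with a hypothetical wrapper name, of how a pipeline can request cast folding while leaving integer-input casts alone:

// Hypothetical wrapper (sketch only): collect only the float-input variant.
void collectCastFoldingPatterns(MLIRContext *ctx, RewritePatternSet &patterns) {
  mlir::tosa::populateTosaFoldConstantCastPatterns(
      ctx, patterns, /*enableIntCastFolding=*/false);
}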
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
index 5913bd2c51006..bc18e25961e7b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantRSQRT.cpp
@@ -32,7 +32,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
 
   using OpRewritePattern::OpRewritePattern;
 
-  static APFloat computeRSQRT(const APFloat &apFloatVal, Type floatTy) {
+  static APFloat computeRSQRT(const APFloat &apFloatVal, FloatType floatTy) {
     // The result for negative values (apart from zero) is always NaN
     if (apFloatVal.isNegative() && !apFloatVal.isNegZero()) {
       return APFloat::getNaN(apFloatVal.getSemantics());
@@ -72,7 +72,9 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
     }
 
     // Create a new tensor with the updated values
-    auto newTensor = applyElementWise(inputValues, &computeRSQRT);
+    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
+        inputValues, &computeRSQRT,
+        cast<FloatType>(inputValues.getElementType()));
 
     // Replace the use of the reciprocal with the transformed tensor
     rewriter.replaceOpWithNewOp<ConstOp>(rsqrt, newTensor.getType(), newTensor);
@@ -84,6 +86,7 @@ struct TosaFoldConstantRSQRT : public OpRewritePattern<RsqrtOp> {
 } // namespace
 
 void mlir::tosa::populateTosaFoldConstantRSQRTPatterns(
+    MLIRContext *ctx,
     RewritePatternSet &patterns) {
   patterns.add<TosaFoldConstantRSQRT>(ctx);
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp
index 2f7335dc0477f..396c966c3259b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantReciprocal.cpp
@@ -57,7 +57,9 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
     }
 
     // Create a new tensor with the updated values
-    auto newTensor = applyElementWise(inputValues, &computeReciprocal);
+    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
+        inputValues, &computeReciprocal,
+        cast<FloatType>(inputValues.getElementType()));
 
     // Replace the use of the reciprocal with the transformed tensor
     rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 1288b1c7ade40..6d0a3c89ed079 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -50,6 +50,8 @@ struct TosaLayerwiseConstantFoldPass
     RewritePatternSet patterns(ctx);
     auto func = getOperation();
 
+    mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
+                                                     enableIntCastFolding);
     mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir
new file mode 100644
index 0000000000000..959d5c5bae034
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-cast-opt-disable-int-folding.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --split-input-file -verify-diagnostics -tosa-layerwise-constant-fold="enable-cast-folding-int-input=false" %s | FileCheck %s
+
+// CHECK-LABEL: @cast_fold_f32_to_i32
+func.func @cast_fold_f32_to_i32() -> tensor<3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 4, 5{{.*}}tensor<3xi32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[12.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32>
+  %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi32>
+  return %1 : tensor<3xi32>
+}
+
+// CHECK-LABEL: @cast_fold_i16_to_f32
+func.func @cast_fold_i16_to_f32() -> tensor<3xf32> {
+  // CHECK: tosa.const
+  // CHECK: [[RES:]] ={{.*}}tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5]> :
+        tensor<3xi16>
+      } : () -> tensor<3xi16>
+  %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xf32>
+  return %1 : tensor<3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir
new file mode 100644
index 0000000000000..42c1d5bf6ff6a
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir
@@ -0,0 +1,244 @@
+// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// -----
+// Casts from float to int
+
+// CHECK-LABEL: @cast_fold_f32_to_i1_all_non_zero
+func.func @cast_fold_f32_to_i1_all_non_zero() -> tensor<3xi1> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true{{.*}}tensor<3xi1>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[12.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32>
+  %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi1>
+  return %1 : tensor<3xi1>
+}
+
+// CHECK-LABEL: @cast_fold_f32_to_i1
+func.func @cast_fold_f32_to_i1() -> tensor<3xi1> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[12.0, 0.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32>
+  %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi1>
+  return %1 : tensor<3xi1>
+}
+
+// CHECK-LABEL: @cast_fold_f32_to_i32
+func.func @cast_fold_f32_to_i32() -> tensor<3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 4, 5{{.*}}tensor<3xi32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value = dense<[12.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32>
+  %1 = "tosa.cast"(%0) : (tensor<3xf32>) -> tensor<3xi32>
+  return %1 : tensor<3xi32>
+}
+
+// CHECK-LABEL: @cast_fold_f32_to_i16
+func.func @cast_fold_f32_to_i16() -> tensor<5xi16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, 5, 32767, -32768{{.*}}tensor<5xi16>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> :
+        tensor<5xf32>
+      } : () -> tensor<5xf32>
+  %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xi16>
+  return %1 : tensor<5xi16>
+}
+
+// CHECK-LABEL: @cast_fold_f16_to_i32
+func.func @cast_fold_f16_to_i32() -> tensor<6xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 14, 0, 5, 277, -278{{.*}}tensor<6xi32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12.5, 13.5, 0.0, 5.0, 277.11, -277.71]> :
+        tensor<6xf16>
+      } : () -> tensor<6xf16>
+  %1 = "tosa.cast"(%0) : (tensor<6xf16>) -> tensor<6xi32>
+  return %1 : tensor<6xi32>
+}
+
+// CHECK-LABEL: @cast_fold_f32_to_i8
+func.func @cast_fold_f32_to_i8() -> tensor<5xi8> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, 5, 127, -128{{.*}}tensor<5xi8>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> :
+        tensor<5xf32>
+      } : () -> tensor<5xf32>
+  %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xi8>
+  return %1 : tensor<5xi8>
+}
+
+// CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan
+func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> {
+  // Check if infinity and zero are translated properly. Don't expect any
+  // specific value for NaN, as the integer result for NaN is unspecified.
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}32767, -32768, 0, 0, {{.*}}tensor<5xi16>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> :
+        tensor<5xf32>
+      } : () -> tensor<5xf32>
+  // expected-warning@below {{Float tensor is cast to integer and it contains NaN values.}}
+  %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xi16>
+  return %1 : tensor<5xi16>
+}
+
+// -----
+// Casts from int to int
+
+// CHECK-LABEL: @cast_fold_i16_to_i32
+func.func @cast_fold_i16_to_i32() -> tensor<3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, -5{{.*}}tensor<3xi32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5]> :
+        tensor<3xi16>
+      } : () -> tensor<3xi16>
+  %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xi32>
+  return %1 : tensor<3xi32>
+}
+
+// CHECK-LABEL: @cast_fold_i32_to_i8
+func.func @cast_fold_i32_to_i8() -> tensor<5xi8> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}12, 0, -5, -1, 1{{.*}}tensor<5xi8>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5, 511, -511]> :
+        tensor<5xi32>
+      } : () -> tensor<5xi32>
+  %1 = "tosa.cast"(%0) : (tensor<5xi32>) -> tensor<5xi8>
+  return %1 : tensor<5xi8>
+}
+
+
+// CHECK-LABEL: @cast_fold_i16_to_i1
+func.func @cast_fold_i16_to_i1() -> tensor<3xi1> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5]> :
+        tensor<3xi16>
+      } : () -> tensor<3xi16>
+  %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xi1>
+  return %1 : tensor<3xi1>
+}
+
+// -----
+// Casts from int to float
+
+// CHECK-LABEL: @cast_fold_i16_to_f32
+func.func @cast_fold_i16_to_f32() -> tensor<3xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, -5.{{0*}}e+00
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5]> :
+        tensor<3xi16>
+      } : () -> tensor<3xi16>
+  %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xf32>
+  return %1 : tensor<3xf32>
+}
+
+// CHECK-LABEL: @cast_fold_i16_to_f16
+func.func @cast_fold_i16_to_f16() -> tensor<3xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, -5.{{0*}}e+00
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5]> :
+        tensor<3xi16>
+      } : () -> tensor<3xi16>
+  %1 = "tosa.cast"(%0) : (tensor<3xi16>) -> tensor<3xf16>
+  return %1 : tensor<3xf16>
+}
+
+// CHECK-LABEL: @cast_fold_i32_to_f16
+func.func @cast_fold_i32_to_f16() -> tensor<4xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, -5.{{0*}}e+00, 0x7C00
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12, 0, -5, 2147483647]> :
+        tensor<4xi32>
+      } : () -> tensor<4xi32>
+  %1 = "tosa.cast"(%0) : (tensor<4xi32>) -> tensor<4xf16>
+  return %1 : tensor<4xf16>
+}
+
+// -----
+// Casts from float to float
+
+// CHECK-LABEL: @cast_fold_f32_to_f16
+func.func @cast_fold_f32_to_f16() -> tensor<5xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, 5.{{.*}}, 3.2{{.*}}+04, -3.2{{.*}}e+04{{.*}}tensor<5xf16>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12.0, 0.0, 5.2, 32770.11, -32770.11]> :
+        tensor<5xf32>
+      } : () -> tensor<5xf32>
+  %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xf16>
+  return %1 : tensor<5xf16>
+}
+
+// CHECK-LABEL: @cast_fold_f32_to_f16_imprecise
+func.func @cast_fold_f32_to_f16_imprecise() -> tensor<5xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9.56{{.*}}e-02, 0x7C00, 0xFC00, 0.{{0*}}e+00, -0.{{0*}}e+00{{.*}}tensor<5xf16>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[0.0956256023875237592352073,
+               346534769.23495863245, -346534769.23495863245,
+               0.000000000003, -0.000000000000001]> :
+        tensor<5xf32>
+      } : () -> tensor<5xf32>
+  %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xf16>
+  return %1 : tensor<5xf16>
+}
+
+// CHECK-LABEL: @cast_fold_f32_to_f16_infinity_zero_nan
+func.func @cast_fold_f32_to_f16_infinity_zero_nan() -> tensor<5xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7C00, 0xFC00, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7E00{{.*}}tensor<5xf16>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> :
+        tensor<5xf32>
+      } : () -> tensor<5xf32>
+  %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xf16>
+  return %1 : tensor<5xf16>
+}
+
+// CHECK-LABEL: @cast_fold_f16_to_f32
+func.func @cast_fold_f16_to_f32() -> tensor<5xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.2{{0*}}e+01, 0.{{0*}}e+00, 5.{{.*}}, 3.2{{.*}}+04, -3.2{{.*}}e+04{{.*}}tensor<5xf32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[12.0, 0.0, 5.2, 32770.11, -32770.11]> :
+        tensor<5xf16>
+      } : () -> tensor<5xf16>
+  %1 = "tosa.cast"(%0) : (tensor<5xf16>) -> tensor<5xf32>
+  return %1 : tensor<5xf32>
+}
+
+// CHECK-LABEL: @cast_fold_f16_to_f32_infinity_zero_nan
+func.func @cast_fold_f16_to_f32_infinity_zero_nan() -> tensor<5xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000, 0.{{0*}}e+00, -0.{{0*}}e+00, 0x7FC00000{{.*}}tensor<5xf32>
+  // CHECK-NOT: tosa.cast
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+        dense<[0x7C00, 0xFC00, 0.0, -0.0, 0x7E00]> :
+        tensor<5xf16>
+      } : () -> tensor<5xf16>
+  %1 = "tosa.cast"(%0) : (tensor<5xf16>) -> tensor<5xf32>
+  return %1 : tensor<5xf32>
+}
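The float-to-int expectations above, e.g. 12.5 -> 12 and 13.5 -> 14 in @cast_fold_f16_to_i32, follow from round-to-nearest-ties-to-even. A standalone sketch against LLVM's APFloat API, separate from the patch, that reproduces the rule:

#include <cassert>
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APSInt.h>

int main() {
  llvm::APSInt result(/*BitWidth=*/32, /*isUnsigned=*/false);
  bool isExact = false;
  // Ties go to the even neighbor: 12.5 rounds down, 13.5 rounds up.
  llvm::APFloat(12.5f).convertToInteger(
      result, llvm::APFloat::rmNearestTiesToEven, &isExact);
  assert(result == 12);
  llvm::APFloat(13.5f).convertToInteger(
      result, llvm::APFloat::rmNearestTiesToEven, &isExact);
  assert(result == 14);
  return 0;
}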