From 5d538f9719bc60768dac8922057c5ec295655308 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Tue, 13 Jun 2023 15:46:56 +0000 Subject: [PATCH] TosaToLinAlg: fix tosa.cast legalization of FP->Int for non FP32 types. --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 21 +++++++++++++++++-- .../TosaToLinalg/tosa-to-linalg.mlir | 11 ++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 342634364cc10..2e280dba469c9 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -471,16 +471,33 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { - auto intMin = rewriter.create( + Value intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); - auto intMax = rewriter.create( + Value intMax = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); + // Since F32 constants are created, we may still need to convert them to + // the correct type. + auto convertType = [&](Type ty, Value arg) { + auto argTy = arg.getType(); + bool bitExtend = + argTy.getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth(); + if (ty != argTy) { + if (!bitExtend) + arg = rewriter.create(loc, ty, arg); + else + arg = rewriter.create(loc, ty, arg); + } + return arg; + }; + intMin = convertType(srcTy, intMin); + intMax = convertType(srcTy, intMax); + auto rounded = rewriter.create(loc, args[0]); auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 6483e29e7a9c2..70d09cde7bc7f 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -270,6 +270,17 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () { // CHECK: arith.extf %0 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: %[[C_LOWEST:.+]] = arith.constant -2.14748365E+9 + // CHECK: %[[C_MAX:.+]] = arith.constant 2.14748365E+9 + // CHECK: arith.truncf %[[C_LOWEST]] : f32 to f16 + // CHECK: arith.truncf %[[C_MAX]] : f32 to f16 + // CHECK: math.roundeven + // CHECK: arith.minf + // CHECK: arith.maxf + // CHECK: arith.fptosi + %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi32> + return }