diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 0b4bd7d27badc..8ad2eef978323 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -38,8 +38,6 @@ class Tosa_QuantizedType params, bit signed> // Used to express accumulator results or compare results. //===----------------------------------------------------------------------===// -def Tosa_UInt8 : UI<8>; - def Tosa_Int8 : I<8>; def Tosa_Int16 : I<16>; def Tosa_Int32 : I<32>; @@ -54,9 +52,11 @@ def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8, def Tosa_Bool : I<1>; -// No unsigned unquantized int types. def Tosa_Int : AnyTypeOf<[Tosa_Bool, - Tosa_UInt8, + AnyUnsignedInteger, + AnySignlessInteger, +// TODO: For backwards compatibility, keep Tosa_SignedInt, which is actually +// a set of signless types. Tosa_SignedInt]>; def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 251be563a0c8f..75f82e737798a 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -469,7 +469,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, args.front(), zero); } - if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { + if (dstTy.isSignlessInteger() && + arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { auto intMin = rewriter.create( loc, rewriter.getF32FloatAttr( APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) @@ -487,6 +488,30 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, return rewriter.create(loc, dstTy, clamped); } + if (dstTy.isUnsignedInteger() && + arith::FPToUIOp::areCastCompatible(srcTy, dstTy)) { + auto intMin = rewriter.create( + loc, rewriter.getF32FloatAttr( + APInt::getMinValue(dstTy.getIntOrFloatBitWidth()) + .getZExtValue())); + + auto intMax = rewriter.create( + loc, rewriter.getF32FloatAttr( + APInt::getMaxValue(dstTy.getIntOrFloatBitWidth()) + .getZExtValue())); + + auto rounded = rewriter.create(loc, args[0]); + + auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter); + + auto cast = rewriter.create( + loc, rewriter.getIntegerType(dstTy.getIntOrFloatBitWidth()), clamped); + // arith is signless, so temporarily cast back to being unsigned. + return rewriter + .create(loc, dstTy, cast->getResult(0)) + .getResult(0); + } + // Casting to boolean, integers need to only be checked as not-equal to // zero. if (srcTy.isa() && dstTy.isInteger(1)) { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-i2.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-i2.mlir new file mode 100644 index 0000000000000..b9001d807194a --- /dev/null +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-i2.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics -o -| FileCheck %s + +func.func @test_cast(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK: linalg.generic + // CHECK: arith.constant -2.000000e+00 + // CHECK: arith.constant 1.000000e+00 + // CHECK: math.roundeven + // CHECK: arith.minf + // CHECK: arith.maxf + // CHECK: arith.fptosi + %1 = "tosa.cast"(%arg0) : (tensor<1xf32>) -> tensor<1xi2> + + // CHECK: linalg.generic + // CHECK: arith.sitofp + %2 = "tosa.cast"(%1) : (tensor<1xi2>) -> tensor<1xf32> + + return %2 : tensor<1xf32> +} + +// ----- diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named-i2.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named-i2.mlir new file mode 100644 index 0000000000000..0d0686dc45fa2 --- /dev/null +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named-i2.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s + +// CHECK-LABEL: @matmul +func.func @matmul(%arg0: tensor<1x5x3xi2>, %arg1: tensor<1x3x6xi2>) -> (tensor<1x5x6xi2>) { + // CHECK: [[C0:%.+]] = arith.constant 0 : i2 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i2) outs([[INIT]] : tensor<1x5x6xi2>) -> tensor<1x5x6xi2> + // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xi2>, tensor<1x3x6xi2>) outs([[FILLED]] : tensor<1x5x6xi2>) -> tensor<1x5x6xi2> + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xi2>, tensor<1x3x6xi2>) -> (tensor<1x5x6xi2>) + return %0 : tensor<1x5x6xi2> +} + +// ----- + +// CHECK-LABEL: @matmul +func.func @matmul(%arg0: tensor<1x5x3xi2>, %arg1: tensor<1x3x6xi2>) -> (tensor<1x5x6xi4>) { + // CHECK: [[C0:%.+]] = arith.constant 0 : i4 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i4) outs([[INIT]] : tensor<1x5x6xi4>) -> tensor<1x5x6xi4> + // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xi2>, tensor<1x3x6xi2>) outs([[FILLED]] : tensor<1x5x6xi4>) -> tensor<1x5x6xi4> + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xi2>, tensor<1x3x6xi2>) -> (tensor<1x5x6xi4>) + return %0 : tensor<1x5x6xi4> +} + +// ----- diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-ui3.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-ui3.mlir new file mode 100644 index 0000000000000..051a4fb5a1092 --- /dev/null +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-ui3.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s | FileCheck %s + +func.func @test_cast(%arg0: tensor<1xf32>) -> tensor<1xui3> { + // CHECK: linalg.generic + // CHECK: arith.constant 0.000000e+00 + // CHECK: arith.constant 7.000000e+00 + // CHECK: math.roundeven + // CHECK: arith.minf + // CHECK: arith.maxf + // CHECK: arith.fptoui {{.*}} : f32 to i3 + // CHECK: builtin.unrealized_conversion_cast + %1 = "tosa.cast"(%arg0) : (tensor<1xf32>) -> tensor<1xui3> + + return %1 : tensor<1xui3> +} + \ No newline at end of file