diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 63a138527b32e..953b2a27b7152 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -109,7 +109,7 @@ class [[nodiscard]] APInt { /// \param implicitTrunc allow implicit truncation of non-zero/sign bits of /// val beyond the range of numBits APInt(unsigned numBits, uint64_t val, bool isSigned = false, - bool implicitTrunc = true) + bool implicitTrunc = false) : BitWidth(numBits) { if (!implicitTrunc) { if (isSigned) { diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 770a390a3fe5f..9e4efbf7e71c0 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -63,7 +63,8 @@ LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results, ? std::pow(2, operandIntAttr.getValue().getZExtValue()) : std::pow(2, operandIntAttr.getValue().getSExtValue()); - APInt resultInt(bitWidth, resultVal, integerType.isSigned()); + APInt resultInt(bitWidth, resultVal, integerType.isSigned(), + /*implicitTrunc*/ true); bool isOverflow = integerType.isSigned() ? resultInt.slt(operandIntAttr.getValue()) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 5e5c39507057a..732f794206cd8 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -155,18 +155,43 @@ struct SelectToClampOptimization : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "RHS of predicate GreaterEqualOp is not a constant"); } + auto isCompatibleSplat = [](DenseElementsAttr a, DenseElementsAttr b) -> bool { if (!a.isSplat() || !b.isSplat()) { return false; } - if (llvm::isa(a.getElementType())) { - return a.getSplatValue() == b.getSplatValue(); + + auto aAsIntegerType = dyn_cast(a.getElementType()); + auto bAsIntegerType = dyn_cast(b.getElementType()); + if (aAsIntegerType && bAsIntegerType) { + if (aAsIntegerType.getSignedness() != bAsIntegerType.getSignedness()) { + return false; + } + + auto aAsAPInt = a.getSplatValue(); + auto bAsAPInt = b.getSplatValue(); + + const size_t aBitWidth = aAsAPInt.getBitWidth(); + const size_t bBitWidth = bAsAPInt.getBitWidth(); + + if (aBitWidth >= bBitWidth) { + return aAsAPInt == (bAsIntegerType.isUnsigned() + ? bAsAPInt.zext(aBitWidth) + : bAsAPInt.sext(aBitWidth)); + } + return (aAsIntegerType.isUnsigned() + ? aAsAPInt.zext(bBitWidth) + : aAsAPInt.sext(bBitWidth)) == bAsAPInt; } - if (llvm::isa(a.getElementType())) { - return a.getSplatValue() == b.getSplatValue(); + + auto aAsFloatType = dyn_cast(a.getElementType()); + auto bAsFloatType = dyn_cast(b.getElementType()); + if (!aAsFloatType || aAsFloatType != bAsFloatType) { + return false; } - return false; // Only int and float types are supported + + return a.getSplatValue() == b.getSplatValue(); }; auto onFalse = op.getOnFalse(); @@ -237,10 +262,25 @@ struct SelectToClampOptimization : public OpRewritePattern { clampFloatMax = rewriter.getFloatAttr(inputElementType, splatValue); } } + + Value input = geq.getInput1(); + + // In case they do not have same bit width, insert a cast to still be able + // to do this canonicalization + const size_t geqBitWidth = + geq.getInput1().getType().getElementTypeBitWidth(); + const size_t selectBitWidth = op.getType().getElementTypeBitWidth(); + if (geqBitWidth != selectBitWidth) { + input = rewriter.create( + op->getLoc(), + geq.getInput1().getType().clone(op.getType().getElementType()), + input); + } + rewriter.replaceOpWithNewOp( - op, op.getType(), geq.getInput1(), - rewriter.getI64IntegerAttr(clampIntMin), + op, op.getType(), input, rewriter.getI64IntegerAttr(clampIntMin), rewriter.getI64IntegerAttr(clampIntMax), clampFloatMin, clampFloatMax); + return success(); } }; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 63594d83ccc11..ea6f295ff2feb 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -1540,9 +1540,11 @@ struct TosaFoldConstantMatMul // Convert int64_t to the correct output type. std::vector apintValues; - llvm::transform(values, std::back_inserter(apintValues), - [&](const int64_t &val) { - APInt apIntVal(baseType.getIntOrFloatBitWidth(), val); + llvm::transform( + values, std::back_inserter(apintValues), [&](const int64_t &val) { + APInt apIntVal(baseType.getIntOrFloatBitWidth(), val, + /*isSigned=*/true); // tosa-mlir uses signless + // instead of signed return apIntVal; }); return DenseElementsAttr::get(outputType, apintValues); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 4e10833a775c3..b341a774442ba 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1147,3 +1147,59 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) -> return %3 : tensor<13x21x3xf32> } +// ----- + +// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat1 +func.func @canonicalize_select_to_clamp_i64_and_i8_pat1(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { +// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8> +// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> +// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8> + %0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> + %1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> + %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1> + %3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8> + return %3 : tensor<13x21x3xi8> +} + +// ----- + +// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat2 +func.func @canonicalize_select_to_clamp_i64_and_i8_pat2(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { +// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8> +// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> +// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8> + %0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> + %1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> + %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1> + %3 = tosa.select %2, %1, %arg1 : ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8> + return %3 : tensor<13x21x3xi8> +} + +// ----- + +// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat1 +func.func @canonicalize_select_to_clamp_i8_and_i64_pat1(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { +// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64> +// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> +// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64> + %0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> + %1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> + %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1> + %3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %3 : tensor<13x21x3xi64> +} + +// ----- + +// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat2 +func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { +// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64> +// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> +// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64> + %0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> + %1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> + %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1> + %3 = tosa.select %2, %1, %arg1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> + return %3 : tensor<13x21x3xi64> +} +