Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class [[nodiscard]] APInt {
/// \param implicitTrunc allow implicit truncation of non-zero/sign bits of
/// val beyond the range of numBits
APInt(unsigned numBits, uint64_t val, bool isSigned = false,
bool implicitTrunc = true)
bool implicitTrunc = false)
: BitWidth(numBits) {
if (!implicitTrunc) {
if (isSigned) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results,
? std::pow(2, operandIntAttr.getValue().getZExtValue())
: std::pow(2, operandIntAttr.getValue().getSExtValue());

APInt resultInt(bitWidth, resultVal, integerType.isSigned());
APInt resultInt(bitWidth, resultVal, integerType.isSigned(),
/*implicitTrunc*/ true);

bool isOverflow = integerType.isSigned()
? resultInt.slt(operandIntAttr.getValue())
Expand Down
54 changes: 47 additions & 7 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,43 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
return rewriter.notifyMatchFailure(
op, "RHS of predicate GreaterEqualOp is not a constant");
}

auto isCompatibleSplat = [](DenseElementsAttr a,
DenseElementsAttr b) -> bool {
if (!a.isSplat() || !b.isSplat()) {
return false;
}
if (llvm::isa<IntegerType>(a.getElementType())) {
return a.getSplatValue<APInt>() == b.getSplatValue<APInt>();

auto aAsIntegerType = dyn_cast<IntegerType>(a.getElementType());
auto bAsIntegerType = dyn_cast<IntegerType>(b.getElementType());
if (aAsIntegerType && bAsIntegerType) {
if (aAsIntegerType.getSignedness() != bAsIntegerType.getSignedness()) {
return false;
}

auto aAsAPInt = a.getSplatValue<APInt>();
auto bAsAPInt = b.getSplatValue<APInt>();

const size_t aBitWidth = aAsAPInt.getBitWidth();
const size_t bBitWidth = bAsAPInt.getBitWidth();

if (aBitWidth >= bBitWidth) {
return aAsAPInt == (bAsIntegerType.isUnsigned()
? bAsAPInt.zext(aBitWidth)
: bAsAPInt.sext(aBitWidth));
}
return (aAsIntegerType.isUnsigned()
? aAsAPInt.zext(bBitWidth)
: aAsAPInt.sext(bBitWidth)) == bAsAPInt;
}
if (llvm::isa<FloatType>(a.getElementType())) {
return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();

auto aAsFloatType = dyn_cast<FloatType>(a.getElementType());
auto bAsFloatType = dyn_cast<FloatType>(b.getElementType());
if (!aAsFloatType || aAsFloatType != bAsFloatType) {
return false;
}
return false; // Only int and float types are supported

return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();
};

auto onFalse = op.getOnFalse();
Expand Down Expand Up @@ -237,10 +262,25 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
clampFloatMax = rewriter.getFloatAttr(inputElementType, splatValue);
}
}

Value input = geq.getInput1();

// In case they do not have same bit width, insert a cast to still be able
// to do this canonicalization
const size_t geqBitWidth =
geq.getInput1().getType().getElementTypeBitWidth();
const size_t selectBitWidth = op.getType().getElementTypeBitWidth();
if (geqBitWidth != selectBitWidth) {
input = rewriter.create<tosa::CastOp>(
op->getLoc(),
geq.getInput1().getType().clone(op.getType().getElementType()),
input);
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), geq.getInput1(),
rewriter.getI64IntegerAttr(clampIntMin),
op, op.getType(), input, rewriter.getI64IntegerAttr(clampIntMin),
rewriter.getI64IntegerAttr(clampIntMax), clampFloatMin, clampFloatMax);

return success();
}
};
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1540,9 +1540,11 @@ struct TosaFoldConstantMatMul

// Convert int64_t to the correct output type.
std::vector<APInt> apintValues;
llvm::transform(values, std::back_inserter(apintValues),
[&](const int64_t &val) {
APInt apIntVal(baseType.getIntOrFloatBitWidth(), val);
llvm::transform(
values, std::back_inserter(apintValues), [&](const int64_t &val) {
APInt apIntVal(baseType.getIntOrFloatBitWidth(), val,
/*isSigned=*/true); // tosa-mlir uses signless
// instead of signed
return apIntVal;
});
return DenseElementsAttr::get(outputType, apintValues);
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1147,3 +1147,59 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) ->
return %3 : tensor<13x21x3xf32>
}

// -----

// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat1
// Mixed-width case: the comparison is done on i64 tensors while the select
// produces i8. The select keeps %arg1 when %arg0 >= 42 and otherwise yields
// the i8 splat constant 42 (constant on the false branch). The expected
// output is a cast from i64 to i8 followed by a clamp whose lower bound is 42
// and whose upper bound is the i64 maximum.
func.func @canonicalize_select_to_clamp_i64_and_i8_pat1(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8>
%0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
%1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
%3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
return %3 : tensor<13x21x3xi8>
}

// -----

// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat2
// Mixed-width case with the constant on the true branch: the comparison is
// done on i64 tensors while the select produces i8. The select yields the i8
// splat constant -42 when %arg0 >= -42 and otherwise keeps %arg1. The
// expected output is a cast from i64 to i8 followed by a clamp whose upper
// bound is -42 and whose lower bound is the i64 minimum.
func.func @canonicalize_select_to_clamp_i64_and_i8_pat2(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8>
%0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
%1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
%3 = tosa.select %2, %1, %arg1 : ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
return %3 : tensor<13x21x3xi8>
}

// -----

// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat1
// Mixed-width case in the widening direction: the comparison is done on i8
// tensors while the select produces i64. The select keeps %arg1 when
// %arg0 >= 42 and otherwise yields the i64 splat constant 42 (constant on the
// false branch). The expected output is a cast from i8 to i64 followed by a
// clamp whose lower bound is 42 and whose upper bound is the i64 maximum.
func.func @canonicalize_select_to_clamp_i8_and_i64_pat1(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64>
%0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
%1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1>
%3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
return %3 : tensor<13x21x3xi64>
}

// -----

// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat2
// Mixed-width case in the widening direction with the constant on the true
// branch: the comparison is done on i8 tensors while the select produces i64.
// The select yields the i64 splat constant -42 when %arg0 >= -42 and
// otherwise keeps %arg1. The expected output is a cast from i8 to i64
// followed by a clamp whose upper bound is -42 and whose lower bound is the
// i64 minimum.
func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64>
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64>
%0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
%1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1>
%3 = tosa.select %2, %1, %arg1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
return %3 : tensor<13x21x3xi64>
}

Loading