diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index dfd4e154c9275..c2d4e20efd364 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -622,6 +622,18 @@ OpFoldResult CastOp::fold(ArrayRef operands) { if (getInput().getType() == getType()) return getInput(); + // cast-to-iN(cast-to-iM(x)) -> cast-to-iN(x) when N <= M + if (auto cast = getInput().getDefiningOp()) { + auto intermediateElTy = cast.getType().getElementType().dyn_cast(); + auto finalElTy = getType().getElementType().dyn_cast(); + if (intermediateElTy && finalElTy && + intermediateElTy.getSignedness() == finalElTy.getSignedness() && + intermediateElTy.getWidth() >= finalElTy.getWidth()) { + getInputMutable().assign(cast.getInput()); + return getResult(); + } + } + auto operand = operands[0].dyn_cast_or_null(); if (!operand) return {}; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index a52764e2cc038..c1f87d83892dd 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -39,6 +39,32 @@ func.func @cast_nofold(%arg0: tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: @cast_fold_double +func.func @cast_fold_double(%arg0: tensor) -> tensor { + // CHECK: "tosa.cast"{{.*}} (tensor) -> tensor + %0 = "tosa.cast"(%arg0) : (tensor) -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @cast_no_fold_double1 +func.func @cast_no_fold_double1(%arg0: tensor) -> tensor { + // CHECK: "tosa.cast"{{.*}} (tensor) -> tensor + // CHECK: "tosa.cast"{{.*}} (tensor) -> tensor + %0 = "tosa.cast"(%arg0) : (tensor) -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @cast_no_fold_double2 +func.func @cast_no_fold_double2(%arg0: tensor) -> tensor { + // CHECK: "tosa.cast"{{.*}} (tensor) -> tensor + // CHECK: "tosa.cast"{{.*}} (tensor) -> tensor + %0 = "tosa.cast"(%arg0) : (tensor) -> tensor + %1 = "tosa.cast"(%0) : (tensor) -> tensor + return %1 : tensor +} + // CHECK-LABEL: @clamp_not_noop func.func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK: "tosa.clamp"