diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0ac000e2bd978..85c7f05b83d95 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1460,6 +1460,8 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [ Tosa_Tensor:$output ); + let hasFolder = 1; + let hasCanonicalizer = 1; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 97fb220b53290..8fdd6cf35500f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -494,6 +494,30 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, // Operator Folders. //===----------------------------------------------------------------------===// +static bool hasZeroSize(Type ty) { + auto ranked = dyn_cast(ty); + if (!ranked) + return false; + return any_of(ranked.getShape(), [](auto d) { return d == 0; }); +} + +OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { + /// Remove operands that have zero elements. + bool changed = false; + for (size_t i = 0; i < getInput1().size(); ) { + auto input = getInput1()[i]; + if (hasZeroSize(input.getType())) { + getInput1Mutable().erase(i); + changed = true; + } else { + ++i; + } + } + if (changed) + return getResult(); + return {}; +} + template DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy) { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 4ad6ce65655ed..e049bc164150c 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -86,6 +86,13 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { return %1 : tensor<4xi8> } +// CHECK-LABEL: @concat_fold_zero +func.func @concat_fold_zero(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}> + %0 = "tosa.concat"(%arg0, %arg1, %arg2) {axis = 1 : i64}: (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + // CHECK-LABEL: @concat_fold func.func @concat_fold(%arg0: tensor) -> tensor { // CHECK: return %arg0