From 5b9793e68f578bd6a4777dbb362f0ff47bd5f8ba Mon Sep 17 00:00:00 2001 From: Dominik Montada Date: Mon, 15 May 2023 09:51:30 +0000 Subject: [PATCH 1/3] feat(TosaCanonicalizations): FXML-1981 fold consecutive concats on same axis --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 46 ++++++++- mlir/test/Dialect/Tosa/fold_concats.mlir | 93 +++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Tosa/fold_concats.mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 506d159fa2ac3..502fb3938c2f5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -57,9 +57,53 @@ struct ConcatOptimization : public OpRewritePattern { } }; +struct ConcatFolding : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConcatOp op, + PatternRewriter &rewriter) const override { + // Fold consecutive concats on the same axis into a single op. + uint64_t axis = op.getAxis(); + + // Keep track of the operands so we are able to construct a new concat + // later. Conservatively assume that we double the number of operands when + // folding + SmallVector concatOperands; + concatOperands.reserve(2 * op->getNumOperands()); + + // Find all operands that are foldable concats + bool canFold = false; + for (Value operand : op->getOperands()) { + concatOperands.emplace_back(operand); + + auto producer = dyn_cast_or_null(operand.getDefiningOp()); + if (!producer) + continue; + + // Foldable if axis is the same + if (axis != producer.getAxis()) + continue; + + // Replace the original operand with all incoming operands + canFold = true; + concatOperands.pop_back(); + llvm::append_range(concatOperands, producer->getOperands()); + } + + if (!canFold) + return rewriter.notifyMatchFailure(op, "No foldable concats found"); + + // Replace the original concat with a new one that contains the original and + // folded operands + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + concatOperands, axis); + return success(); + } +}; + void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } struct ReshapeReshapeOptimization : public OpRewritePattern { diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir new file mode 100644 index 0000000000000..c74e5bf0d2792 --- /dev/null +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + return %0 : tensor<1x2x7x7xf32> +} + +// CHECK-LABEL: func.func @single_concat( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32> +// CHECK: } + +// ----- + +func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = "tosa.concat"(%0, %0) {axis = 0} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> + return %1 : tensor<2x2x7x7xf32> +} + +// CHECK-LABEL: func.func @concat_different_axis( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) {axis = 0 : i64} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> +// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32> +// CHECK: } + +// ----- + +func.func @fold_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { + %tmp = tensor.empty() : tensor<1x1x7x7xf32> + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> + return %1 : tensor<1x4x7x7xf32> +} + +// CHECK-LABEL: func.func @fold_concats( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { +// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> +// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32> +// CHECK: } + +// ----- + +func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> { + %tmp = tensor.empty() : tensor<1x1x7x7xf32> + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = "tosa.concat"(%tmp, %0, %tmp) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> + %2 = "tosa.concat"(%1, %1) {axis = 1} : (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) -> tensor<1x8x7x7xf32> + return %2 : tensor<1x8x7x7xf32> +} + +// CHECK-LABEL: func.func @nested_fold( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> { +// CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<1x1x7x7xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]], %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> +// CHECK: return %[[VAL_2]] : tensor<1x8x7x7xf32> +// CHECK: } + +// ----- + +func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = "tosa.concat"(%arg1, %arg1) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> + return %2 : tensor<1x4x7x7xf32> +} + +// CHECK-LABEL: func.func @wide_fold( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_1]]) {axis = 1 : i64} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> +// CHECK: return %[[VAL_2]] : tensor<1x4x7x7xf32> +// CHECK: } + +// ----- + +func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>) -> tensor<1x2x8x8xf32> + %1 = "tosa.concat"(%arg1, %arg1) {axis = 2} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> + %2 = "tosa.concat"(%0, %1) {axis = 1} : (tensor<1x2x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} + +// CHECK-LABEL: func.func @partially_foldable( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { +// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_0]], %[[VAL_2]]) {axis = 1 : i64} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32> +// CHECK: } From 8e461e261e8a326f599edbfcc77703ab1dc13ea3 Mon Sep 17 00:00:00 2001 From: Dominik Montada Date: Tue, 16 May 2023 13:45:26 +0000 Subject: [PATCH 2/3] refactor(TosaOps): FXML-1981 use hasFolder = 1 for concat folding --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 80 ++++++++----------- mlir/test/Dialect/Tosa/fold_concats.mlir | 2 +- 3 files changed, 37 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index a278c257c7779..a2ebde90299cc 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1383,6 +1383,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [ ); let hasCanonicalizer = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 502fb3938c2f5..78148a9631d63 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -57,53 +57,9 @@ struct ConcatOptimization : public OpRewritePattern { } }; -struct ConcatFolding : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::ConcatOp op, - PatternRewriter &rewriter) const override { - // Fold consecutive concats on the same axis into a single op. - uint64_t axis = op.getAxis(); - - // Keep track of the operands so we are able to construct a new concat - // later. Conservatively assume that we double the number of operands when - // folding - SmallVector concatOperands; - concatOperands.reserve(2 * op->getNumOperands()); - - // Find all operands that are foldable concats - bool canFold = false; - for (Value operand : op->getOperands()) { - concatOperands.emplace_back(operand); - - auto producer = dyn_cast_or_null(operand.getDefiningOp()); - if (!producer) - continue; - - // Foldable if axis is the same - if (axis != producer.getAxis()) - continue; - - // Replace the original operand with all incoming operands - canFold = true; - concatOperands.pop_back(); - llvm::append_range(concatOperands, producer->getOperands()); - } - - if (!canFold) - return rewriter.notifyMatchFailure(op, "No foldable concats found"); - - // Replace the original concat with a new one that contains the original and - // folded operands - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - concatOperands, axis); - return success(); - } -}; - void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } struct ReshapeReshapeOptimization : public OpRewritePattern { @@ -1039,3 +995,37 @@ OpFoldResult TransposeOp::fold(ArrayRef operands) { return getInput1(); return {}; } + +OpFoldResult ConcatOp::fold(ArrayRef operands) { + // Fold consecutive concats on the same axis into a single op. + // Keep track of the operands so we are able to construct a new concat + // later. Conservatively assume that we double the number of operands when + // folding + SmallVector concatOperands; + concatOperands.reserve(2 * getNumOperands()); + + // Find all operands that are foldable concats + bool canFold = false; + for (Value operand : getOperands()) { + concatOperands.emplace_back(operand); + + auto producer = dyn_cast_or_null(operand.getDefiningOp()); + if (!producer) + continue; + + // Foldable if axis is the same + if (getAxis() != producer.getAxis()) + continue; + + // Replace the original operand with all incoming operands + canFold = true; + concatOperands.pop_back(); + llvm::append_range(concatOperands, producer->getOperands()); + } + + if (!canFold) + return {}; + + getOperation()->setOperands(concatOperands); + return getResult(); +} diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index c74e5bf0d2792..2b1cd891a33b2 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s +// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> From 5a529a49ff07ba3da6653998bf65adbfb6101c04 Mon Sep 17 00:00:00 2001 From: Dominik Montada Date: Fri, 19 May 2023 09:46:40 +0000 Subject: [PATCH 3/3] chore(TosaCanonicalizations): FXML-1981 improve wording in comment --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 78148a9631d63..adc79c7c39bbb 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1013,7 +1013,7 @@ OpFoldResult ConcatOp::fold(ArrayRef operands) { if (!producer) continue; - // Foldable if axis is the same + // Foldable if axes are the same if (getAxis() != producer.getAxis()) continue;