From 55654e920ddb05b6a60a44aa079bfe72cd4a7bfb Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 11:27:33 +0000 Subject: [PATCH 1/6] Canonicalize 'self-concats' to tile --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 42 +++++++++++++++++++ mlir/test/Dialect/Tosa/fold_concats.mlir | 29 +++++++------ 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 732f794206cd8..518254eaf6919 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern { } }; +struct SelfConcatToTile : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const override { + if (llvm::all_equal(concatOp->getUsers())) { + const auto concatUser = llvm::dyn_cast( + concatOp->getUses().begin()->getOwner()); + if (concatUser) { + // Try folding the concat into its consumer before rewriting it to a + // tile. + SmallVector replacementValues; + auto foldResult = rewriter.tryFold(concatUser, replacementValues); + if (foldResult.succeeded()) { + if (!replacementValues.empty()) { + rewriter.replaceOp(concatUser, replacementValues); + } + return success(); + } + } + } + + if (!llvm::all_equal(concatOp->getOperands())) { + return rewriter.notifyMatchFailure( + concatOp, "Requires all operands to be the same"); + } + const auto concatType = dyn_cast(concatOp.getType()); + if (!concatType || !concatType.hasRank()) { + return rewriter.notifyMatchFailure(concatOp, + "Requires concat to be ranked"); + } + SmallVector multiplies(concatType.getRank(), 1); + multiplies[concatOp.getAxis()] = concatOp->getNumOperands(); + auto tileOp = rewriter.createOrFold( + concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0), + multiplies); + rewriter.replaceOp(concatOp, {tileOp}); + return success(); + } +}; + void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(context); } struct SqrtReciprocalOptimization : public OpRewritePattern { diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index ec54f27346c8b..e77aefcbe6353 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -5,10 +5,10 @@ func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { return %0 : tensor<1x2x7x7xf32> } -// CHECK-LABEL: func.func @single_concat( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32> +// CHECK-LABEL: func.func @single_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: return [[VAR_0_]] : tensor<1x2x7x7xf32> // CHECK: } // ----- @@ -19,11 +19,11 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf return %1 : tensor<2x2x7x7xf32> } -// CHECK-LABEL: func.func @concat_different_axis( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> -// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32> +// CHECK-LABEL: func.func @concat_different_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array} : (tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> +// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32> // CHECK: } // ----- @@ -84,10 +84,9 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x return %2 : tensor<1x4x8x8xf32> } -// CHECK-LABEL: func.func @partially_foldable( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> -// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> -// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32> +// CHECK-LABEL: func.func @partially_foldable +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_1_]] {multiples = array} : (tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_0_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<1x4x8x8xf32> // CHECK: } From 358df158ca539a42a1525e17325e114f8977792f Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 12:42:55 +0000 Subject: [PATCH 2/6] Fold consecutive tosa.tiles --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 15 +++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 25 +++++++++++++++++++ mlir/test/Dialect/Tosa/fold_concats.mlir | 5 ++-- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 518254eaf6919..c754cd2de6a56 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1362,6 +1362,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) { bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; }); if (allOnes && getInput1().getType() == getType()) return getInput1(); + + if (auto inputTile = getInput1().getDefiningOp()) { + if (!inputTile->hasOneUse()) { + return {}; + } + llvm::SmallVector newMultiplies{getMultiples()}; + for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) { + newMultiplies[idx] *= multiplier; + } + setMultiples(newMultiplies); + setOperand(inputTile->getOperand(0)); + getOperation()->setLoc( + FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()})); + return getResult(); + } return {}; } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index b341a774442ba..fe20a58fe809a 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -691,6 +691,31 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // ----- +// CHECK-LABEL: func.func @tile_fold_consecutive +func.func @tile_fold_consecutive(%arg0: tensor<3x4xf32>) -> tensor<6x16xf32> { + // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<6x16xf32> { + // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<3x4xf32>) -> tensor<6x16xf32> + // CHECK: return [[VAR_0_]] : tensor<6x16xf32> + %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x8xf32> + %1 = tosa.tile %0 { multiples = array }: (tensor<3x8xf32>) -> tensor<6x16xf32> + return %1 : tensor<6x16xf32> +} + +// ----- + +// CHECK-LABEL: func.func @tile_no_fold_consecutive_multi_use +func.func @tile_no_fold_consecutive_multi_use(%arg0: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) { + // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) { + // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<3x4xf32>) -> tensor<3x8xf32> + // CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array} : (tensor<3x8xf32>) -> tensor<6x16xf32> + // CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<3x8xf32>, tensor<6x16xf32> + %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x8xf32> + %1 = tosa.tile %0 { multiples = array }: (tensor<3x8xf32>) -> tensor<6x16xf32> + return %0, %1 : tensor<3x8xf32>, tensor<6x16xf32> +} + +// ----- + // CHECK-LABEL: @tile_nofold func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> { // CHECK: tosa.tile diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index e77aefcbe6353..409088fd0f3ec 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -21,9 +21,8 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf // CHECK-LABEL: func.func @concat_different_axis // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { -// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array} : (tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> -// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32> +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> +// CHECK: return [[VAR_0_]] : tensor<2x2x7x7xf32> // CHECK: } // ----- From eb3aed527f089f8d278aaa5e3608e3d7f509e63f Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 14:03:22 +0000 Subject: [PATCH 3/6] Extend concat -> slice canonicalization to remove concat inputs if possible --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 56 ++++++----- mlir/test/Dialect/Tosa/canonicalize.mlir | 92 +++++++++++++++++++ 2 files changed, 126 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c754cd2de6a56..c35ffef61468a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -653,35 +653,47 @@ struct ConcatSliceOptimization : public OpRewritePattern { llvm::SmallVector sliceStart(sliceOp.getStart()); llvm::ArrayRef sliceSize = sliceOp.getSize(); - - // Validate slice on the concatenated axis. Slicing along this - // axis should span only one of the inputs to the concatenate - // operation. - std::optional replaceWithSlice; + llvm::SmallVector requiredConcatInputs; + int64_t processedOriginalConcatInputSize = 0; + int64_t droppedConcatInputSize = 0; for (auto input : inputs) { - auto inputType = dyn_cast(input.getType()); + const auto inputType = dyn_cast(input.getType()); if (!inputType || !inputType.hasStaticShape()) return rewriter.notifyMatchFailure( sliceOp, "concat input must be a static ranked tensor"); - - if (sliceStart[axis] >= 0 && - (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { - replaceWithSlice = rewriter - .create( - sliceOp.getLoc(), sliceOp.getType(), input, - rewriter.getDenseI64ArrayAttr(sliceStart), - rewriter.getDenseI64ArrayAttr(sliceSize)) - .getResult(); - break; + if (processedOriginalConcatInputSize < + (sliceStart[axis] + sliceSize[axis]) && + (processedOriginalConcatInputSize + inputType.getDimSize(axis)) > + sliceStart[axis]) { + if (requiredConcatInputs.empty()) { + droppedConcatInputSize = processedOriginalConcatInputSize; + } + requiredConcatInputs.push_back(input); } - sliceStart[axis] -= inputType.getDimSize(axis); + processedOriginalConcatInputSize += inputType.getDimSize(axis); } - - if (!replaceWithSlice) + if (requiredConcatInputs.size() == concatOp->getNumOperands()) { return rewriter.notifyMatchFailure( - sliceOp, "corresponding concat input not found for slice"); - - rewriter.replaceOp(sliceOp, replaceWithSlice.value()); + sliceOp, "Could not reduce number of inputs to preceding concat"); + } + if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) { + return rewriter.notifyMatchFailure( + sliceOp, + "Preceding concat must have a single use"); // Do not introduce new + // concats + } + if (requiredConcatInputs.empty()) { + return rewriter.notifyMatchFailure( + sliceOp, "degenerate slice with zero sized dim in output"); + } + sliceStart[axis] -= droppedConcatInputSize; + auto newConcat = rewriter.create(concatOp->getLoc(), + requiredConcatInputs, axis); + auto newSlice = rewriter.create( + sliceOp->getLoc(), sliceOp.getType(), newConcat, + rewriter.getDenseI64ArrayAttr(sliceStart), + rewriter.getDenseI64ArrayAttr(sliceSize)); + rewriter.replaceOp(sliceOp, newSlice); return success(); } }; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index fe20a58fe809a..5460a89a543c4 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -829,6 +829,98 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : // ----- +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_start_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_start_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %1 : tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_end_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_end_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %1 : tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_all_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x4xf32> +func.func @canonicalize_concat_slice_partial_concat_all_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32> + return %1 : tensor<1x12x12x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_multi_use +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_multi_use(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %0, %1 : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_zero_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x0xf32> +// CHECK: } +func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32> + return %1 : tensor<1x12x12x0xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32> +// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32> +func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x120x12x10x16xf32> + return %1 : tensor<1x120x12x10x16xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_multi_output +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32> +func.func @canonicalize_tile_slice_multi_output(%arg0 : tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32> + return %0, %1 : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32> +} + +// ----- + // CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> { // CHECK: %[[RSQRT:.*]] = tosa.rsqrt %arg{{.*}} : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> From 0aad4d94cb11c268ed76d8ba6a32da09be05b12a Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 17:07:23 +0000 Subject: [PATCH 4/6] Add canonicalization pattern for tile -> slice to minimize the tile multipliers --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 70 ++++++++++++++++++- mlir/test/Dialect/Tosa/canonicalize.mlir | 27 +++++-- 2 files changed, 88 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c35ffef61468a..fb52c0502db34 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -687,8 +687,8 @@ struct ConcatSliceOptimization : public OpRewritePattern { sliceOp, "degenerate slice with zero sized dim in output"); } sliceStart[axis] -= droppedConcatInputSize; - auto newConcat = rewriter.create(concatOp->getLoc(), - requiredConcatInputs, axis); + auto newConcat = rewriter.create( + concatOp->getLoc(), requiredConcatInputs, axis); auto newSlice = rewriter.create( sliceOp->getLoc(), sliceOp.getType(), newConcat, rewriter.getDenseI64ArrayAttr(sliceStart), @@ -698,9 +698,75 @@ struct ConcatSliceOptimization : public OpRewritePattern { } }; +/// This patterns adjust the multipliers of a tile followed by a slice to only +/// tile as much data as it is required by the slice +struct TileSliceOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const override { + Value sliceInput = sliceOp.getInput1(); + auto tileOp = sliceInput.getDefiningOp(); + if (!tileOp) + return rewriter.notifyMatchFailure(sliceOp, + "slice input must be tile operation"); + if (!tileOp->hasOneUse()) + return rewriter.notifyMatchFailure( + sliceOp, "preceding tile must have a single use"); // Do not insert + // additional tiles + + const auto tileOpInputType = + dyn_cast(tileOp->getOperand(0).getType()); + if (!tileOpInputType || !tileOpInputType.hasStaticShape()) + return rewriter.notifyMatchFailure( + sliceOp, "input to preceding tile op must be a static ranked tensor"); + llvm::SmallVector requiredMultipliers; + llvm::SmallVector newTileStarts; + requiredMultipliers.reserve(tileOpInputType.getRank()); + newTileStarts.reserve(tileOpInputType.getRank()); + for (auto [axis, sliceStart, sliceSize] : + llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) { + if (sliceSize <= 0) { + return rewriter.notifyMatchFailure( + sliceOp, "degenerate slice with zero sized dim"); + } + const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis); + const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize; + const int64_t sliceSizeInFirstTile = + std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize); + assert(sliceSizeInFirstTile > 0); + const int64_t requiredMultiplierWithoutFirstTile = + llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize); + const int64_t requiredMultiplier = + requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0); + assert(requiredMultiplier <= tileOp.getMultiples()[axis]); + requiredMultipliers.push_back(requiredMultiplier); + newTileStarts.push_back(sliceOffsetInNewFirstTile); + } + if (requiredMultipliers == tileOp.getMultiples()) + return rewriter.notifyMatchFailure( + sliceOp, "could not reduce multipliers in preceding tile"); + + llvm::SmallVector newTileShape(tileOpInputType.getShape()); + for (auto [newShape, multiplier] : + llvm::zip_equal(newTileShape, requiredMultipliers)) { + newShape *= multiplier; + } + auto newTile = rewriter.create( + tileOp->getLoc(), tileOpInputType.clone(newTileShape), + tileOp->getOperand(0), requiredMultipliers); + auto newSlice = rewriter.create( + sliceOp->getLoc(), sliceOp.getType(), newTile, + rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr()); + rewriter.replaceOp(sliceOp, newSlice); + return success(); + } +}; + void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(context); } struct MinToClampOptimization : public OpRewritePattern { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5460a89a543c4..2bdbd44cb31be 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -896,14 +896,27 @@ func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %ar // ----- // CHECK-LABEL: func.func @canonicalize_tile_slice -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { -// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32> -// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32> -// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32> -func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x24x20x30x10xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x120x24x20x30x10xf32>) -> tensor<1x120x12x10x16x5xf32> +// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16x5xf32> +func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x120x12x10x16x5xf32> + return %1 : tensor<1x120x12x10x16x5xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32> +// CHECK: return [[VAR_1_]] : tensor<1x0x12x10x16xf32> +func.func @canonicalize_tile_slice_zero_dim(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> - %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x120x12x10x16xf32> - return %1 : tensor<1x120x12x10x16xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32> + return %1 : tensor<1x0x12x10x16xf32> } // ----- From b87be4ffe402e38b4c5ee533907d6722d445f16f Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Wed, 12 Mar 2025 13:11:37 +0000 Subject: [PATCH 5/6] Add check for canonicalization of slice(concat %a, %a)) = %a --- mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 2bdbd44cb31be..3aae58de87518 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -908,6 +908,17 @@ func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tens // ----- +// CHECK-LABEL: func.func @canonicalize_self_concat_slice +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: return [[PARAM_0_]] : tensor<1x2x3x4xf32> +func.func @canonicalize_self_concat_slice(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %0 = tosa.concat %arg0, %arg0 {axis = 3 : i32} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x8xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x2x3x8xf32>) -> tensor<1x2x3x4xf32> + return %1 : tensor<1x2x3x4xf32> +} + +// ----- + // CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> From 2afdeee3e971a1e91714013f1a326c99737c0459 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Thu, 13 Mar 2025 08:14:28 +0000 Subject: [PATCH 6/6] Add test for complete folding of tile + slice --- mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 3aae58de87518..64143f477c85d 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -908,6 +908,17 @@ func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tens // ----- +// CHECK-LABEL: func.func @canonicalize_tile_slice_fold +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> { +// CHECK: return [[PARAM_0_]] : tensor<1x12x12x10x10x10xf32> +func.func @canonicalize_tile_slice_fold(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x12x12x10x10x10xf32> + return %1 : tensor<1x12x12x10x10x10xf32> +} + +// ----- + // CHECK-LABEL: func.func @canonicalize_self_concat_slice // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { // CHECK: return [[PARAM_0_]] : tensor<1x2x3x4xf32>