From 0f197304a5e55e3cb272bdc82d75e7a7a34e019d Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Mon, 7 Apr 2025 16:40:31 +0100 Subject: [PATCH] Add support for folding tosa.slice with tosa.slice --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 24 ++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 75 +++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 8208caf0983b4..ea61718428477 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1390,6 +1390,30 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { } OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { + const auto tryFoldWithPrecedingSlice = [this](FoldAdaptor adaptor) { + auto precedingSliceOp = getInput1().getDefiningOp(); + if (!precedingSliceOp) + return failure(); + const auto precedingSliceStart = precedingSliceOp.getStart(); + const auto thisSliceStart = getStart(); + SmallVector newSliceStart; + newSliceStart.reserve(precedingSliceStart.size()); + for (auto [startPreceding, startThis] : + llvm::zip_equal(precedingSliceStart, thisSliceStart)) { + newSliceStart.push_back(startPreceding + startThis); + } + setOperand(precedingSliceOp->getOperand(0)); + setStart(newSliceStart); + getOperation()->setLoc( + FusedLoc::get(getContext(), {precedingSliceOp->getLoc(), getLoc()})); + return success(); + }; + + // First try folding the preceding slice, this also works if the shapes are + // dynamic + if (succeeded(tryFoldWithPrecedingSlice(adaptor))) + return getResult(); + auto inputTy = llvm::dyn_cast(getInput1().getType()); auto outputTy = llvm::dyn_cast(getType()); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 02409a3b1827d..91e6f2439f81b 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -693,6 +693,81 @@ func.func @slice_nofold(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @slice_fuse +func.func @slice_fuse(%arg0: tensor<3x4xf32>) -> tensor<1x2xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x2xf32> +// CHECK: return [[VAR_0_]] : tensor<1x2xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<2x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<2x3xf32>) -> tensor<1x2xf32> + return %1 : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_step +func.func @slice_fuse_different_step(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<1x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<1x3xf32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start +func.func @slice_fuse_different_start(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<1x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<1x3xf32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start_2 +func.func @slice_fuse_different_start_2(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<10x10xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<10x10xf32>) -> tensor<5x5xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<5x5xf32>) -> tensor<3x3xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<3x3xf32>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start_3 +func.func @slice_fuse_different_start_3(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<10x10xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<10x10xf32>) -> tensor<5x5xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<5x5xf32>) -> tensor<3x3xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<3x3xf32>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: func.func @slice_fuse_different_start_dynamic +func.func @slice_fuse_different_start_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: @tile_fold func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0