diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 732f794206cd8..fb52c0502db34 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   }
 };
 
+struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::all_equal(concatOp->getUsers())) {
+      const auto concatUser = llvm::dyn_cast<tosa::ConcatOp>(
+          concatOp->getUses().begin()->getOwner());
+      if (concatUser) {
+        // Try folding the concat into its consumer before rewriting it to a
+        // tile.
+        SmallVector<Value> replacementValues;
+        auto foldResult = rewriter.tryFold(concatUser, replacementValues);
+        if (foldResult.succeeded()) {
+          if (!replacementValues.empty()) {
+            rewriter.replaceOp(concatUser, replacementValues);
+          }
+          return success();
+        }
+      }
+    }
+
+    if (!llvm::all_equal(concatOp->getOperands())) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "Requires all operands to be the same");
+    }
+    const auto concatType = dyn_cast<ShapedType>(concatOp.getType());
+    if (!concatType || !concatType.hasRank()) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "Requires concat to be ranked");
+    }
+    SmallVector<int64_t> multiplies(concatType.getRank(), 1);
+    multiplies[concatOp.getAxis()] = concatOp->getNumOperands();
+    auto tileOp = rewriter.createOrFold<tosa::TileOp>(
+        concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
+        multiplies);
+    rewriter.replaceOp(concatOp, {tileOp});
+    return success();
+  }
+};
+
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<ConcatOptimization>(context);
+  results.add<SelfConcatToTile>(context);
 }
 
 struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::ReciprocalOp> {
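For reference, the rewrite that SelfConcatToTile performs can be pictured on a small standalone snippet; the function name and shapes below are illustrative only and are not taken from the patch:

  // Before canonicalization: both concat operands are the same SSA value.
  func.func @self_concat_sketch(%x: tensor<2x3xf32>) -> tensor<2x6xf32> {
    %0 = tosa.concat %x, %x {axis = 1 : i32} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x6xf32>
    return %0 : tensor<2x6xf32>
  }
  // Expected form after --canonicalize: the concat becomes a tile whose
  // multiples are 1 everywhere except the concat axis, which receives the
  // operand count (here 2):
  //   %0 = tosa.tile %x {multiples = array<i64: 1, 2>} : (tensor<2x3xf32>) -> tensor<2x6xf32>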
@@ -611,35 +653,112 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
     llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
     llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
-
-    // Validate slice on the concatenated axis. Slicing along this
-    // axis should span only one of the inputs to the concatenate
-    // operation.
-    std::optional<Value> replaceWithSlice;
+    llvm::SmallVector<Value> requiredConcatInputs;
+    int64_t processedOriginalConcatInputSize = 0;
+    int64_t droppedConcatInputSize = 0;
     for (auto input : inputs) {
-      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      const auto inputType = dyn_cast<RankedTensorType>(input.getType());
       if (!inputType || !inputType.hasStaticShape())
         return rewriter.notifyMatchFailure(
             sliceOp, "concat input must be a static ranked tensor");
-
-      if (sliceStart[axis] >= 0 &&
-          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice = rewriter
-                               .create<tosa::SliceOp>(
-                                   sliceOp.getLoc(), sliceOp.getType(), input,
-                                   rewriter.getDenseI64ArrayAttr(sliceStart),
-                                   rewriter.getDenseI64ArrayAttr(sliceSize))
-                               .getResult();
-        break;
+      if (processedOriginalConcatInputSize <
+              (sliceStart[axis] + sliceSize[axis]) &&
+          (processedOriginalConcatInputSize + inputType.getDimSize(axis)) >
+              sliceStart[axis]) {
+        if (requiredConcatInputs.empty()) {
+          droppedConcatInputSize = processedOriginalConcatInputSize;
+        }
+        requiredConcatInputs.push_back(input);
       }
-      sliceStart[axis] -= inputType.getDimSize(axis);
+      processedOriginalConcatInputSize += inputType.getDimSize(axis);
+    }
+    if (requiredConcatInputs.size() == concatOp->getNumOperands()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "Could not reduce number of inputs to preceding concat");
+    }
+    if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "Preceding concat must have a single use"); // Do not introduce new
+                                                      // concats
+    }
+    if (requiredConcatInputs.empty()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "degenerate slice with zero sized dim in output");
     }
+    sliceStart[axis] -= droppedConcatInputSize;
+    auto newConcat = rewriter.create<tosa::ConcatOp>(
+        concatOp->getLoc(), requiredConcatInputs, axis);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newConcat,
+        rewriter.getDenseI64ArrayAttr(sliceStart),
+        rewriter.getDenseI64ArrayAttr(sliceSize));
+    rewriter.replaceOp(sliceOp, newSlice);
+    return success();
+  }
+};
+
+/// This pattern adjusts the multiples of a tile followed by a slice to only
+/// tile as much data as is required by the slice.
+struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+    auto tileOp = sliceInput.getDefiningOp<tosa::TileOp>();
+    if (!tileOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be tile operation");
+    if (!tileOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "preceding tile must have a single use"); // Do not insert
+                                                             // additional tiles
-    if (!replaceWithSlice)
+    const auto tileOpInputType =
+        dyn_cast<RankedTensorType>(tileOp->getOperand(0).getType());
+    if (!tileOpInputType || !tileOpInputType.hasStaticShape())
       return rewriter.notifyMatchFailure(
-          sliceOp, "corresponding concat input not found for slice");
+          sliceOp, "input to preceding tile op must be a static ranked tensor");
+    llvm::SmallVector<int64_t> requiredMultipliers;
+    llvm::SmallVector<int64_t> newTileStarts;
+    requiredMultipliers.reserve(tileOpInputType.getRank());
+    newTileStarts.reserve(tileOpInputType.getRank());
+    for (auto [axis, sliceStart, sliceSize] :
+         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
+      if (sliceSize <= 0) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "degenerate slice with zero sized dim");
+      }
+      const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis);
+      const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
+      const int64_t sliceSizeInFirstTile =
+          std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
+      assert(sliceSizeInFirstTile > 0);
+      const int64_t requiredMultiplierWithoutFirstTile =
+          llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize);
+      const int64_t requiredMultiplier =
+          requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0);
+      assert(requiredMultiplier <= tileOp.getMultiples()[axis]);
+      requiredMultipliers.push_back(requiredMultiplier);
+      newTileStarts.push_back(sliceOffsetInNewFirstTile);
+    }
+    if (requiredMultipliers == tileOp.getMultiples())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "could not reduce multipliers in preceding tile");
-    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    llvm::SmallVector<int64_t> newTileShape(tileOpInputType.getShape());
+    for (auto [newShape, multiplier] :
+         llvm::zip_equal(newTileShape, requiredMultipliers)) {
+      newShape *= multiplier;
+    }
+    auto newTile = rewriter.create<tosa::TileOp>(
+        tileOp->getLoc(), tileOpInputType.clone(newTileShape),
+        tileOp->getOperand(0), requiredMultipliers);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newTile,
+        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
+    rewriter.replaceOp(sliceOp, newSlice);
     return success();
   }
 };
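To make the multiplier arithmetic in TileSliceOptimization concrete, here is a hedged sketch with illustrative shapes and names that are not taken from the patch. A slice that reads rows [3, 8) of a tensor tiled 4x along dim 0 only needs two copies of the input:

  func.func @tile_slice_sketch(%x: tensor<4x5xf32>) -> tensor<5x5xf32> {
    %t = tosa.tile %x {multiples = array<i64: 4, 1>} : (tensor<4x5xf32>) -> tensor<16x5xf32>
    %s = tosa.slice %t {size = array<i64: 5, 5>, start = array<i64: 3, 0>} : (tensor<16x5xf32>) -> tensor<5x5xf32>
    return %s : tensor<5x5xf32>
  }
  // Dim 0 per the pattern: offset = 3 % 4 = 3, size in the first tile =
  // min(4 - 3, 5) = 1, remaining = 5 - 1 = 4, divideCeil(4, 4) = 1, so the
  // required multiplier is 1 + 1 = 2 and the new slice start is the offset 3.
  // Expected rewrite:
  //   %t = tosa.tile %x {multiples = array<i64: 2, 1>} : (tensor<4x5xf32>) -> tensor<8x5xf32>
  //   %s = tosa.slice %t {size = array<i64: 5, 5>, start = array<i64: 3, 0>} : (tensor<8x5xf32>) -> tensor<5x5xf32>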
@@ -647,6 +766,7 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<ConcatSliceOptimization>(context);
+  results.add<TileSliceOptimization>(context);
 }
 
 struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
@@ -1320,6 +1440,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
   bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
   if (allOnes && getInput1().getType() == getType())
     return getInput1();
+
+  if (auto inputTile = getInput1().getDefiningOp<tosa::TileOp>()) {
+    if (!inputTile->hasOneUse()) {
+      return {};
+    }
+    llvm::SmallVector<int64_t> newMultiplies{getMultiples()};
+    for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) {
+      newMultiplies[idx] *= multiplier;
+    }
+    setMultiples(newMultiplies);
+    setOperand(inputTile->getOperand(0));
+    getOperation()->setLoc(
+        FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()}));
+    return getResult();
+  }
   return {};
 }
 
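The new TileOp::fold clause merges a tile of a tile (when the inner tile has no other users) by multiplying the two multiples element-wise; a minimal sketch with illustrative shapes:

  func.func @tile_of_tile_sketch(%x: tensor<2x3xf32>) -> tensor<4x9xf32> {
    %0 = tosa.tile %x {multiples = array<i64: 2, 1>} : (tensor<2x3xf32>) -> tensor<4x3xf32>
    %1 = tosa.tile %0 {multiples = array<i64: 1, 3>} : (tensor<4x3xf32>) -> tensor<4x9xf32>
    return %1 : tensor<4x9xf32>
  }
  // Expected fold: a single tile with multiples array<i64: 2 * 1, 1 * 3>, i.e.
  //   %1 = tosa.tile %x {multiples = array<i64: 2, 3>} : (tensor<2x3xf32>) -> tensor<4x9xf32>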
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index b341a774442ba..64143f477c85d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -691,6 +691,31 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func.func @tile_fold_consecutive
+func.func @tile_fold_consecutive(%arg0: tensor<3x4xf32>) -> tensor<6x16xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<6x16xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 2, 4>} : (tensor<3x4xf32>) -> tensor<6x16xf32>
+  // CHECK: return [[VAR_0_]] : tensor<6x16xf32>
+  %0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
+  %1 = tosa.tile %0 { multiples = array<i64: 2, 2> }: (tensor<3x8xf32>) -> tensor<6x16xf32>
+  return %1 : tensor<6x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @tile_no_fold_consecutive_multi_use
+func.func @tile_no_fold_consecutive_multi_use(%arg0: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) {
+  // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 2>} : (tensor<3x4xf32>) -> tensor<3x8xf32>
+  // CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array<i64: 2, 2>} : (tensor<3x8xf32>) -> tensor<6x16xf32>
+  // CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<3x8xf32>, tensor<6x16xf32>
+  %0 = tosa.tile %arg0 { multiples = array<i64: 1, 2> }: (tensor<3x4xf32>) -> tensor<3x8xf32>
+  %1 = tosa.tile %0 { multiples = array<i64: 2, 2> }: (tensor<3x8xf32>) -> tensor<6x16xf32>
+  return %0, %1 : tensor<3x8xf32>, tensor<6x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @tile_nofold
 func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
   // CHECK: tosa.tile
@@ -804,6 +829,133 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
 
 // -----
 
+// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_start_overlap
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32>
+func.func @canonicalize_concat_slice_partial_concat_start_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
+  %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
+  return %1 : tensor<1x12x12x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_end_overlap
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32>
+func.func @canonicalize_concat_slice_partial_concat_end_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> {
+  %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 3>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
+  return %1 : tensor<1x12x12x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_all_overlap
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 4>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x12x12x4xf32>
+func.func @canonicalize_concat_slice_partial_concat_all_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> {
+  %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 4>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32>
+  return %1 : tensor<1x12x12x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_multi_use
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
+// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>
+func.func @canonicalize_concat_slice_partial_concat_multi_use(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) {
+  %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 2>, start = array<i64: 0, 0, 0, 1>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32>
+  return %0, %1 : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_concat_slice_zero_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 0>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x12x12x0xf32>
+// CHECK: }
+func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> {
+  %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 0>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32>
+  return %1 : tensor<1x12x12x0xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_tile_slice
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 10, 2, 2, 3, 1>} : (tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x24x20x30x10xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 120, 12, 10, 16, 5>, start = array<i64: 0, 0, 6, 5, 7, 5>} : (tensor<1x120x24x20x30x10xf32>) -> tensor<1x120x12x10x16x5xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16x5xf32>
func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> {
+  %0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 120, 12, 10, 16, 5>, start = array<i64: 0, 0, 18, 15, 27, 5>} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x120x12x10x16x5xf32>
+  return %1 : tensor<1x120x12x10x16x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_tile_slice_fold
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> {
+// CHECK: return [[PARAM_0_]] : tensor<1x12x12x10x10x10xf32>
+func.func @canonicalize_tile_slice_fold(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> {
+  %0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 10, 10, 10>, start = array<i64: 0, 12, 24, 20, 30, 40>} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x12x12x10x10x10xf32>
+  return %1 : tensor<1x12x12x10x10x10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_self_concat_slice
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
+// CHECK: return [[PARAM_0_]] : tensor<1x2x3x4xf32>
+func.func @canonicalize_self_concat_slice(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
+  %0 = tosa.concat %arg0, %arg0 {axis = 3 : i32} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x8xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 2, 3, 4>, start = array<i64: 0, 0, 0, 4>} : (tensor<1x2x3x8xf32>) -> tensor<1x2x3x4xf32>
+  return %1 : tensor<1x2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 0, 12, 10, 16>, start = array<i64: 0, 0, 0, 0, 0>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x0x12x10x16xf32>
+func.func @canonicalize_tile_slice_zero_dim(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> {
+  %0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 0, 12, 10, 16>, start = array<i64: 0, 0, 0, 0, 0>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32>
+  return %1 : tensor<1x0x12x10x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_tile_slice_multi_output
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array<i64: 1, 12, 12, 10, 16>, start = array<i64: 0, 0, 0, 0, 0>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32>
+// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>
+func.func @canonicalize_tile_slice_multi_output(%arg0 : tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) {
+  %0 = tosa.tile %arg0 {multiples = array<i64: 10, 10, 10, 10, 10>} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32>
+  %1 = tosa.slice %0 {size = array<i64: 1, 12, 12, 10, 16>, start = array<i64: 0, 0, 0, 0, 0>} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32>
+  return %0, %1 : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal
 func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> {
   // CHECK: %[[RSQRT:.*]] = tosa.rsqrt %arg{{.*}} : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32>
diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir
index ec54f27346c8b..409088fd0f3ec 100644
--- a/mlir/test/Dialect/Tosa/fold_concats.mlir
+++ b/mlir/test/Dialect/Tosa/fold_concats.mlir
@@ -5,10 +5,10 @@ func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
   return %0 : tensor<1x2x7x7xf32>
 }
 
-// CHECK-LABEL: func.func @single_concat(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32>
+// CHECK-LABEL: func.func @single_concat
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 1, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
+// CHECK: return [[VAR_0_]] : tensor<1x2x7x7xf32>
 // CHECK: }
 
 // -----
@@ -19,11 +19,10 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf
   return %1 : tensor<2x2x7x7xf32>
 }
 
-// CHECK-LABEL: func.func @concat_different_axis(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
-// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32>
-// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32>
+// CHECK-LABEL: func.func @concat_different_axis
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array<i64: 2, 2, 1, 1>} : (tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32>
+// CHECK: return [[VAR_0_]] : tensor<2x2x7x7xf32>
 // CHECK: }
 
 // -----
@@ -84,10 +83,9 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x
   return %2 : tensor<1x4x8x8xf32>
 }
 
-// CHECK-LABEL: func.func @partially_foldable(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
-// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
-// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
-// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32>
+// CHECK-LABEL: func.func @partially_foldable
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_1_]] {multiples = array<i64: 1, 1, 2, 1>} : (tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32>
+// CHECK: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_0_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32>
+// CHECK: return [[VAR_1_]] : tensor<1x4x8x8xf32>
 // CHECK: }