From d59e6b9ba28b12868c519b79bb7ab9b8384ddd0b Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Fri, 14 Feb 2025 15:04:08 +0000 Subject: [PATCH 01/31] Fix problem where the shape of the insert shape was calculated incorrectly --- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 7 ++++++- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 7 ++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 8044405645d44..e1983712b3faf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -218,10 +218,15 @@ struct LinalgOpTilingInterface })); OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); + SmallVector allShapeSizes = + linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc()); + SmallVector sizeBounds = + mlir::affine::makeComposedFoldedMultiResultAffineApply( + b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes); SliceParameters sliceParams = computeSliceParameters( b, loc, outOperand->get(), sizes, linalgOp.getMatchingIndexingMap(outOperand), offsets, - /*ubs*/ {}, subShapeSizes, true); + /*ubs*/ sizeBounds, subShapeSizes, true); resultOffsets = sliceParams.offsets; resultSizes = sliceParams.sizes; return success(); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 846c2064d87b4..68148eaeef730 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1905,9 +1905,10 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, SmallVector> resultSizes( totalNumResultsOfConsumer); for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) { - if (failed(tiledConsumerOp.getResultTilePosition( - rewriter, idx, iterDomainOffsets, iterDomainSizes, - 
resultOffsets[idx], resultSizes[idx]))) { + if (failed(cast(clonedConsumerOp) + .getResultTilePosition(rewriter, idx, iterDomainOffsets, + iterDomainSizes, resultOffsets[idx], + resultSizes[idx]))) { return rewriter.notifyMatchFailure( tiledConsumerOp, "can't get result domain position from iter domain position"); From b39a27c31ba0d1a2fd506fca9ea69cfc4a63d1c6 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 17 Feb 2025 08:05:15 +0000 Subject: [PATCH 02/31] Revert "Revert "Merge pull request #407 from Xilinx/matthias.fix_non_monotonic_slice_params"" This reverts commit 018230a3783457a6e7ae39900b89e4646e90bda4. --- mlir/include/mlir/IR/AffineExpr.h | 5 +++++ mlir/include/mlir/IR/AffineMap.h | 4 ++++ mlir/lib/IR/AffineExpr.cpp | 36 +++++++++++++++++++++++++++++++ mlir/lib/IR/AffineMap.cpp | 5 +++++ 4 files changed, 50 insertions(+) diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index a93e74b449cee..28d00f1299f2f 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -110,6 +110,11 @@ class AffineExpr { /// floordiv, ceildiv, and mod is only allowed w.r.t constants. bool isPureAffine() const; + /// Returns true if this expression is monotonicically increasing with respect + /// to the AffineDimExprs, i.e. increasing the value of any AffineDimExpr will + /// never decrease the value of the result. + bool isMonotonicallyIncreasing() const; + /// Returns the greatest known integral divisor of this affine expression. The /// result is always positive. int64_t getLargestKnownDivisor() const; diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index e30950bbf292d..b9b57612d912d 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -382,6 +382,10 @@ class AffineMap { /// Returns true if the AffineMap represents a symbol-less permutation map. bool isPermutation() const; + // Returns true if every result is monotonically increasing. 
+ // See AffineExpr::isMonotonicallyIncreasing(). + bool isComponentWiseMonotonicallyIncreasing() const; + /// Returns the map consisting of the `resultPos` subset. AffineMap getSubMap(ArrayRef resultPos) const; diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 2291d64c50a56..d1ec15048758c 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -239,6 +239,42 @@ bool AffineExpr::isPureAffine() const { llvm_unreachable("Unknown AffineExpr"); } +static bool isNonNegativeConstant(AffineExpr expr) { + auto constant = dyn_cast(expr); + return constant && constant.getValue() >= 0; +} + +bool AffineExpr::isMonotonicallyIncreasing() const { + switch (getKind()) { + case AffineExprKind::SymbolId: + case AffineExprKind::DimId: + case AffineExprKind::Constant: + return true; + case AffineExprKind::Add: { + auto op = llvm::cast(*this); + return op.getLHS().isMonotonicallyIncreasing() && + op.getRHS().isMonotonicallyIncreasing(); + } + case AffineExprKind::Mul: { + // One operand must be a non-negative constant. + auto op = llvm::cast(*this); + return op.getLHS().isMonotonicallyIncreasing() && + op.getRHS().isMonotonicallyIncreasing() && + (isNonNegativeConstant(op.getLHS()) || + isNonNegativeConstant(op.getRHS())); + } + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + auto op = llvm::cast(*this); + return op.getLHS().isMonotonicallyIncreasing() && + isNonNegativeConstant(op.getRHS()); + } + case AffineExprKind::Mod: + return false; + } + llvm_unreachable("Unknown AffineExpr"); +} + // Returns the greatest known integral divisor of this affine expression. 
int64_t AffineExpr::getLargestKnownDivisor() const { AffineBinaryOpExpr binExpr(nullptr); diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index ea3c0723b0775..408d75d87adb4 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -651,6 +651,11 @@ bool AffineMap::isPermutation() const { return isProjectedPermutation(); } +bool AffineMap::isComponentWiseMonotonicallyIncreasing() const { + return all_of(getResults(), + [](auto expr) { return expr.isMonotonicallyIncreasing(); }); +} + AffineMap AffineMap::getSubMap(ArrayRef resultPos) const { SmallVector exprs; exprs.reserve(resultPos.size()); From 9406825d8a545c95d5734052a52acfe8c34bfb8d Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 17 Feb 2025 08:06:03 +0000 Subject: [PATCH 03/31] Revert change on the TileUsingInterface.cpp --- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 68148eaeef730..846c2064d87b4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1905,10 +1905,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter, SmallVector> resultSizes( totalNumResultsOfConsumer); for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) { - if (failed(cast(clonedConsumerOp) - .getResultTilePosition(rewriter, idx, iterDomainOffsets, - iterDomainSizes, resultOffsets[idx], - resultSizes[idx]))) { + if (failed(tiledConsumerOp.getResultTilePosition( + rewriter, idx, iterDomainOffsets, iterDomainSizes, + resultOffsets[idx], resultSizes[idx]))) { return rewriter.notifyMatchFailure( tiledConsumerOp, "can't get result domain position from iter domain position"); From 0c37ff2701e909668f5e1c19d76b0ad725ff7f9e Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 17 Feb 2025 09:58:30 +0000 
Subject: [PATCH 04/31] Check for monotonic functions --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 26 ++++++++------ mlir/test/Dialect/Linalg/tile-tensors.mlir | 41 ++++++++++++++++++++++ 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 8e898904d87c2..dc2e3971d28bd 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -56,19 +56,24 @@ namespace { // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] // struct TileCheck : public AffineExprVisitor { - TileCheck(ArrayRef tileSizes, ArrayRef sizeBounds) - : tileSizes(tileSizes), sizeBounds(sizeBounds) {} + TileCheck(ArrayRef tileSizes, ArrayRef sizeBounds, + bool isMonotonicallyIncreasing) + : tileSizes(tileSizes), sizeBounds(sizeBounds), + isMonotonicallyIncreasing(isMonotonicallyIncreasing) {} void visitDimExpr(AffineDimExpr expr) { unsigned pos = expr.getPosition(); - // This dimension is tiled if the tile size is larger than zero and not - // equal to its domain size (if statically known). - std::optional tileSize = getConstantIntValue(tileSizes[pos]); - if (tileSize && !sizeBounds.empty()) { - std::optional sizeBound = getConstantIntValue(sizeBounds[pos]); - if (sizeBound && *sizeBound == *tileSize) { - return; + // If the expression is non monotonic, this dimension is tiled if the tile + // size is larger than zero and not equal to its domain size (if statically + // known). 
+ if (!isMonotonicallyIncreasing) { + std::optional tileSize = getConstantIntValue(tileSizes[pos]); + if (tileSize && !sizeBounds.empty()) { + std::optional sizeBound = getConstantIntValue(sizeBounds[pos]); + if (sizeBound && *sizeBound == *tileSize) { + return; + } } } @@ -84,6 +89,7 @@ struct TileCheck : public AffineExprVisitor { bool isTiled = false; ArrayRef tileSizes; ArrayRef sizeBounds; + bool isMonotonicallyIncreasing; }; } // namespace @@ -92,7 +98,7 @@ static bool isTiled(AffineExpr expr, ArrayRef tileSizes, ArrayRef sizeBounds) { if (!expr) return false; - TileCheck t(tileSizes, sizeBounds); + TileCheck t(tileSizes, sizeBounds, expr.isMonotonicallyIncreasing()); t.visit(expr); return t.isTiled; } diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index 883eb732b2aa6..169f5302d4e37 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -199,3 +199,44 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#identity = affine_map<(d0, d1) -> (d0, d1)> +#identity1 = affine_map<(d0, d1) -> (d0 mod 3, d1)> + +// CHECK-LABEL: func @tile_monotonic_outer_dim +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x10xf32> +func.func @tile_monotonic_outer_dim(%in: tensor<4x10xf32>) -> tensor<4x10xf32> { + %empty = tensor.empty() : tensor<4x10xf32> + %1 = linalg.generic {indexing_maps = [#identity, #identity1], iterator_types = ["parallel", "parallel"]} + ins(%in : tensor<4x10xf32>) outs(%empty : tensor<4x10xf32>) { + ^bb1(%a: f32, %b: f32): + linalg.yield %a : f32 + } -> tensor<4x10xf32> + + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C4_1:.+]] = arith.constant 4 : index + // CHECK: %[[C5:.+]] = arith.constant 5 : index + // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[C4]] step %[[C4_1]] iter_args(%[[ARG1:.+]] = %[[OUT:.+]]) -> (tensor<4x10xf32>) { + // CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C5]] 
iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<4x10xf32>) { + // CHECK: %[[INSLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> + // CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> + // CHECK: %[[RES:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[INSLICE]] : tensor<4x5xf32>) outs(%[[OUTSLICE]] : tensor<4x5xf32>) { + // CHECK: ^bb0(%in: f32, %out: f32): + // CHECK: linalg.yield %in : f32 + // CHECK: } -> tensor<4x5xf32> + // CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[RES]] into %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<4x10xf32> + // CHECK: scf.yield %[[INSERT_SLICE]] : tensor<4x10xf32> + // CHECK: } + + return %1 : tensor<4x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [4, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} From e273ea6f224cdab862b603cc030cb335b7943fcd Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 17 Feb 2025 10:10:50 +0000 Subject: [PATCH 05/31] Update tests --- mlir/test/Dialect/Linalg/tile-tensors.mlir | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index 169f5302d4e37..34125727201eb 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -177,9 +177,14 @@ func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> { %0 = tensor.dim %arg0, %c0 : tensor<7xf32> %empty = tensor.empty() : 
tensor<7xf32> - // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32> - // CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) { - // CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[OUT:.*]] = tensor.empty() : tensor<7xf32> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index + // CHECK-DAG: %[[C7_1:.*]] = arith.constant 7 : index + // CHECK: scf.for %[[IV0:.+]] = %[[C0]] to %[[C7]] step %[[C7_1]] iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) { + // CHECK: tensor.extract_slice %[[ARG0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32> + // CHECK: tensor.extract_slice %[[TC0]][%[[IV0]]] [7] [1] : tensor<7xf32> to tensor<7xf32> %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0 mod 4)>, affine_map<(d0) -> (d0)>], @@ -220,7 +225,7 @@ func.func @tile_monotonic_outer_dim(%in: tensor<4x10xf32>) -> tensor<4x10xf32> { // CHECK: %[[C5:.+]] = arith.constant 5 : index // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[C4]] step %[[C4_1]] iter_args(%[[ARG1:.+]] = %[[OUT:.+]]) -> (tensor<4x10xf32>) { // CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<4x10xf32>) { - // CHECK: %[[INSLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> + // CHECK: %[[INSLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> // CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> // CHECK: %[[RES:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[INSLICE]] : tensor<4x5xf32>) outs(%[[OUTSLICE]] : tensor<4x5xf32>) { // CHECK: ^bb0(%in: f32, %out: f32): From 
aeda3eeb1dbab73d3b3d578f08930f8a172b6eab Mon Sep 17 00:00:00 2001 From: josel-amd Date: Tue, 18 Feb 2025 15:58:30 +0100 Subject: [PATCH 06/31] [mlir][linalg] Remove `computeStaticLoopSizes` (#124778) (#475) `computeStaticLoopSizes()` is functionally identical to `getStaticLoopRanges()`. Replace all uses of `computeStaticLoopSizes()` by `getStaticLoopRanges()` and remove the former. --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 5 ----- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 13 ------------- mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp | 2 +- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 0a404194569c2..1e21f781bc4e5 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -850,11 +850,6 @@ def LinalgStructuredInterface /// `createFlatListOfOperandDims`. SmallVector createLoopRanges(OpBuilder &b, Location loc); - /// Compute the static loop sizes necessary to vectorize the computation. - /// This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandStaticDims`. 
- SmallVector computeStaticLoopSizes(); - /// Returns the value that expresses the shape of the output in terms of /// shape of the input operands where possible LogicalResult reifyResultShapes(OpBuilder &b, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index bd77965194b27..3b705f64bfe40 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1086,19 +1086,6 @@ SmallVector LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { return res; } -SmallVector LinalgOp::computeStaticLoopSizes() { - AffineMap map = getLoopsToShapesMap(); - unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); - SmallVector allShapeSizes = createFlatListOfOperandStaticDims(); - SmallVector res(numDims, 0); - for (unsigned idx = 0; idx < numRes; ++idx) { - auto result = map.getResult(idx); - if (auto d = dyn_cast(result)) - res[d.getPosition()] = allShapeSizes[idx]; - } - return res; -} - /// Visitor to check if any of the given set of positions from AffineDimExprs /// are used within an AffineExpr. struct HasAffineDimExprVisitor diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index 2e6079e1402e1..b53180b5cf7c3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -130,7 +130,7 @@ class FoldConstantBase : public OpInterfaceRewritePattern { return failure(); } - SmallVector loopBounds = linalgOp.computeStaticLoopSizes(); + SmallVector loopBounds = linalgOp.getStaticLoopRanges(); int64_t numElements = outputType.getNumElements(); // Use APInt/APFloat instead of Attribute here for constructing the output. 
From fe9d73c8101c3368d691ac604e48e1c067558318 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 19 Feb 2025 10:11:41 +0000 Subject: [PATCH 07/31] Remove ununsed function --- mlir/include/mlir/IR/AffineMap.h | 4 ---- mlir/lib/IR/AffineMap.cpp | 5 ----- 2 files changed, 9 deletions(-) diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index b9b57612d912d..e30950bbf292d 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -382,10 +382,6 @@ class AffineMap { /// Returns true if the AffineMap represents a symbol-less permutation map. bool isPermutation() const; - // Returns true if every result is monotonically increasing. - // See AffineExpr::isMonotonicallyIncreasing(). - bool isComponentWiseMonotonicallyIncreasing() const; - /// Returns the map consisting of the `resultPos` subset. AffineMap getSubMap(ArrayRef resultPos) const; diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 408d75d87adb4..ea3c0723b0775 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -651,11 +651,6 @@ bool AffineMap::isPermutation() const { return isProjectedPermutation(); } -bool AffineMap::isComponentWiseMonotonicallyIncreasing() const { - return all_of(getResults(), - [](auto expr) { return expr.isMonotonicallyIncreasing(); }); -} - AffineMap AffineMap::getSubMap(ArrayRef resultPos) const { SmallVector exprs; exprs.reserve(resultPos.size()); From 3c940451cf20e9bfbbe0a880464e960f77810be8 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 18 Feb 2025 02:15:44 +0000 Subject: [PATCH 08/31] When writing PDLL patterns it is often assumed that some basic checks are executed before constraints are called, but this is not always the case, as operations can be reordered in PDLInterp if there is no dependency between them. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For example: Pdll pattern: ``` let someOp = op(input: Value (resTypes: TypeRange); let someResult = someConstraint(inputAxis); ``` If SomeOp requires axis to have a valid value, it is easy to (wrongly) assume that someConstraint always gets called with a not-null inputAxis. This is not correct. The linearized PDLInterp (pseudo-code) could be the following: ``` %0 = pdl_interp.get_attribute "axis" of %arg0 %1 = pdl_interp.apply_constraint “someConstraint”(%0) pdl_interp.is_not_null(%0) pdl_interp.check_operation_name of %arg0 is "someDialect.SomeOp" ``` Note that here someConstraint can be called with a null attribute. This commit changes the prioritization of predicates, so that constraints are run after other predicates. ``` %0 = pdl_interp.get_attribute "axis" of %arg0 pdl_interp.is_not_null(%0) pdl_interp.check_operation_name of %arg0 is "someDialect.SomeOp" %1 = pdl_interp.apply_constraint “someConstraint”(%0) ``` This ensures that null or operation name checks are run before constraints. This is closer to the mental model when writing PDLL patterns and should make it less likely to run into bugs caused by assuming implicit checks for not null. --- .../PDLToPDLInterp/PredicateTree.cpp | 16 ++++++++++---- .../pdl-to-pdl-interp-matcher.mlir | 21 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 99bf73de59e37..5d249a37d50fd 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -768,17 +768,25 @@ struct OrderedPredicate { /// model. bool operator<(const OrderedPredicate &rhs) const { // Sort by: + // * not being a constraint. 
Rationale: When writing constraints, it is + // sometimes assumed that checks for null or operation names are executed + // before the constraint. As there is no dependency between these + // operations, this is not always guaranteed, which can lead to bugs if the + // constraint is not checking inputs for null itself. By ordering + // constraints to the end, it is assured that implicit checks are run + // before them. // * higher first and secondary order sums // * lower depth // * lower position dependency // * lower predicate dependency // * lower tie breaking ID auto *rhsPos = rhs.position; - return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), + return std::make_tuple(!isa(question), primary, + secondary, rhsPos->getOperationDepth(), rhsPos->getKind(), rhs.question->getKind(), rhs.id) > - std::make_tuple(rhs.primary, rhs.secondary, - position->getOperationDepth(), position->getKind(), - question->getKind(), id); + std::make_tuple(!isa(rhs.question), rhs.primary, + rhs.secondary, position->getOperationDepth(), + position->getKind(), question->getKind(), id); } }; diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index aeb4c7233d1ff..42fadbc3a6a90 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -488,6 +488,27 @@ module @predicate_ordering { } } +// ----- + +// CHECK-LABEL: module @predicate_ordering_attr +module @predicate_ordering_attr { + // Check that the result is checked for null first, before applying the + // constraint. 
+ + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[RESULT:.*]] = pdl_interp.get_attribute "attr" of %[[ROOT]] + // CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]] + // CHECK: pdl_interp.apply_constraint "constraint" + + + pdl.pattern : benefit(1) { + %attr = attribute + pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute) + pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute) + %root = operation "foo.op" {"attr" = %attr} + rewrite %root with "rewriter" + } +} // ----- From b3be25a89107b72f249ce5cfa3ea9baefa8dcdb3 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Sat, 22 Feb 2025 21:24:26 +0000 Subject: [PATCH 09/31] Do not print unnecessary newlines if attributes are elided --- mlir/lib/IR/AsmPrinter.cpp | 3 ++- mlir/test/IR/mlir-newline-after-attr.mlir | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 51fe49a3c0bf8..16844bd12c413 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2791,7 +2791,8 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, SmallString<16> separator = StringRef(", "); if (printerFlags.getNewlineAfterAttrLimit() && - attrs.size() > *printerFlags.getNewlineAfterAttrLimit()) { + std::distance(filteredAttrs.begin(), filteredAttrs.end()) > + *printerFlags.getNewlineAfterAttrLimit()) { // Increase indent to match the visually match the "{ " below. 
// currentIndent += 2; diff --git a/mlir/test/IR/mlir-newline-after-attr.mlir b/mlir/test/IR/mlir-newline-after-attr.mlir index d35eac21a5152..047a9257d563c 100644 --- a/mlir/test/IR/mlir-newline-after-attr.mlir +++ b/mlir/test/IR/mlir-newline-after-attr.mlir @@ -29,3 +29,6 @@ // CHECK-NEXT: ], "test.op"() {foo.dense_attr = dense<1> : tensor<3xi32>, foo.second_attr = dense<2> : tensor<3xi32>, Operands = [{foo.vect_attr_1_start = dense<0> : vector<3xindex>, foo.vect_attr_1_end = dense<0> : vector<3xindex>, foo.vect_attr_1_count = dense<1> : vector<3xindex>, foo.vect_attr_2_start = dense<0> : vector<3xindex>, foo.vect_attr_2_end = dense<0> : vector<3xindex>, foo.vect_attr_2_count = dense<1> : vector<3xindex>}, {foo.vect_attr_1_start = dense<0> : vector<3xindex>, foo.vect_attr_1_end = dense<0> : vector<3xindex>, foo.vect_attr_1_count = dense<1> : vector<3xindex>, foo.vect_attr_2_start = dense<0> : vector<3xindex>, foo.vect_attr_2_end = dense<0> : vector<3xindex>, foo.vect_attr_2_count = dense<1> : vector<3xindex>}]} : () -> () +// const_shape skips over shape attr when printing. 
Check that we do not insert unnecessary newlines +// CHECK{LITERAL}: shape.const_shape {foo.second_attr = dense<2> : tensor<3xi32>, foo.third_attr = dense<2> : tensor<3xi32>}[1, 1, 1] : tensor<3xindex> +"shape.const_shape"() {shape = dense<1> : tensor<3xindex>, foo.second_attr = dense<2> : tensor<3xi32>, foo.third_attr = dense<2> : tensor<3xi32>} : () -> (tensor<3xindex>) \ No newline at end of file From 8f479a46c6d376d140557a717cc654d9fc57786b Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 4 Mar 2025 09:23:43 +0000 Subject: [PATCH 10/31] Update comment and message about folding cast where the input has multiple uses --- mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index ea6f295ff2feb..106a9791b309f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -981,11 +981,11 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { warnAboutNaNToIntCast(elements, tosaCast, rewriter); // Only fold splat tensors and those used only once to avoid duplicating - // them. + // them and increasing memory consumption. 
if (!inputTensor.hasOneUse() && !isa(elements)) { - return rewriter.notifyMatchFailure(tosaCast, - "Currently, casts will only be folded " - "if its input only has a single user"); + return rewriter.notifyMatchFailure( + tosaCast, "Currently, casts will only be folded " + "if its input only has a single user or is a splat value."); } // Report a match failure for unexpected types From b40461d875ded0206f3acc2a8a1bd7e79dc34363 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 4 Mar 2025 15:36:09 +0000 Subject: [PATCH 11/31] Compute affine expression bounds --- mlir/include/mlir/Analysis/AffineExprBounds.h | 91 ++++++++ mlir/lib/Analysis/AffineExprBounds.cpp | 183 ++++++++++++++++ mlir/lib/Analysis/CMakeLists.txt | 1 + .../Analysis/test-affine-expr-bounds.mlir | 207 ++++++++++++++++++ mlir/test/lib/IR/CMakeLists.txt | 1 + .../lib/IR/TestAffineExpressionBounds.cpp | 190 ++++++++++++++++ mlir/tools/mlir-opt/mlir-opt.cpp | 2 + 7 files changed, 675 insertions(+) create mode 100644 mlir/include/mlir/Analysis/AffineExprBounds.h create mode 100644 mlir/lib/Analysis/AffineExprBounds.cpp create mode 100644 mlir/test/Analysis/test-affine-expr-bounds.mlir create mode 100644 mlir/test/lib/IR/TestAffineExpressionBounds.cpp diff --git a/mlir/include/mlir/Analysis/AffineExprBounds.h b/mlir/include/mlir/Analysis/AffineExprBounds.h new file mode 100644 index 0000000000000..b88ed3c46a9b7 --- /dev/null +++ b/mlir/include/mlir/Analysis/AffineExprBounds.h @@ -0,0 +1,91 @@ +//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines an analysis of affine expressions to compute their +// ranges (lower/upper bounds) in a given context. 
+// +//===----------------------------------------------------------------------===// +#ifndef MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H +#define MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H + +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; + +/// This visitor computes the bounds of affine expressions, using as context the +/// bounds of the dimensions of the expression. +/// +/// Example: +/// Given bounds 0 <= d0 <= 99 and 0 <= d1 <= 199, we can compute the bounds +/// of the following expression: +/// lb(2 * d0 + 3 * d1) = 0 +/// ub(2 * d0 + 3 * d1) = 795 +/// +/// * The bounds given in the context are inclusive, and the bounds returned +/// are also inclusive. +/// * If bounds are not available for a dimension, std::nullopt can be used +/// instead. The bounds of an expression that involves it will be std::nullopt. +/// * Limitations: +/// - Parametric expressions (using symbols) are not supported. +/// - Unsigned FloorDiv is currently not supported. +class AffineExprBoundsVisitor + : public AffineExprVisitor { +public: + /// Initialize the context (bounds) with APInt. All bounds must have the same + /// signedness and bit width. + AffineExprBoundsVisitor(ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, + bool boundsSigned, uint64_t bitWidth, + MLIRContext *context); + + /// Initialize the context (bounds) with 64-bit signed integers. This allows + /// directly mapping index-type values such as Linalg op bounds, which are + /// represented as int64_t. + AffineExprBoundsVisitor(ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, + MLIRContext *context); + + /// Get the upper bound of \p expr using the context bounds. 
+ std::optional getUpperBound(AffineExpr expr); + std::optional getIndexUpperBound(AffineExpr expr); + + /// Get the lower bound of \p expr using the context bounds. + std::optional getLowerBound(AffineExpr expr); + std::optional getIndexLowerBound(AffineExpr expr); + + // These methods are directly called by the AffineExprVisitor base class. + LogicalResult visitMulExpr(AffineBinaryOpExpr expr); + LogicalResult visitAddExpr(AffineBinaryOpExpr expr); + LogicalResult visitDimExpr(AffineDimExpr expr); + LogicalResult visitSymbolExpr(AffineSymbolExpr expr); + LogicalResult visitConstantExpr(AffineConstantExpr expr); + LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr); + LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr); + LogicalResult visitModExpr(AffineBinaryOpExpr expr); + +private: + bool boundsSigned = true; + uint64_t bitWidth = 64; + void + inferBinOpRange(AffineBinaryOpExpr expr, + std::function)> + opInference); + + /// Bounds that have been computed for subexpressions are memoized and reused. + llvm::DenseMap lb; + llvm::DenseMap ub; +}; + +#endif // MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H diff --git a/mlir/lib/Analysis/AffineExprBounds.cpp b/mlir/lib/Analysis/AffineExprBounds.cpp new file mode 100644 index 0000000000000..0c0f8b622736d --- /dev/null +++ b/mlir/lib/Analysis/AffineExprBounds.cpp @@ -0,0 +1,183 @@ +//===- AffineExprBounds.cpp - Compute bounds of affine expressions --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an analysis of affine expressions to compute their +// ranges (lower/upper bounds) in a given context. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Analysis/AffineExprBounds.h" + +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +#include + +using namespace mlir; + +AffineExprBoundsVisitor::AffineExprBoundsVisitor( + ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, bool boundsSigned, + uint64_t bitWidth, MLIRContext *context) + : boundsSigned(boundsSigned), bitWidth(bitWidth) { + assert(constLowerBounds.size() == constUpperBounds.size()); + for (unsigned i = 0; i < constLowerBounds.size(); i++) { + if (constLowerBounds[i].has_value()) { + lb[getAffineDimExpr(i, context)] = constLowerBounds[i].value(); + } + if (constUpperBounds[i].has_value()) { + ub[getAffineDimExpr(i, context)] = constUpperBounds[i].value(); + } + } +} + +AffineExprBoundsVisitor::AffineExprBoundsVisitor( + ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, MLIRContext *context) { + assert(constLowerBounds.size() == constUpperBounds.size()); + // Convert int64_ts to APInts. + for (unsigned i = 0; i < constLowerBounds.size(); i++) { + if (constLowerBounds[i].has_value()) { + lb[getAffineDimExpr(i, context)] = + APInt(64, constLowerBounds[i].value(), /*isSigned=*/true); + } + if (constUpperBounds[i].has_value()) { + ub[getAffineDimExpr(i, context)] = + APInt(64, constUpperBounds[i].value(), /*isSigned=*/true); + } + } +} + +std::optional AffineExprBoundsVisitor::getUpperBound(AffineExpr expr) { + // Use memoized bound if available. + auto i = ub.find(expr); + if (i != ub.end()) { + return i->second; + } + // Compute the bound otherwise. + if (failed(walkPostOrder(expr))) { + return std::nullopt; + } + return ub[expr]; +} + +std::optional AffineExprBoundsVisitor::getLowerBound(AffineExpr expr) { + // Use memoized bound if available. 
+ auto i = lb.find(expr); + if (i != lb.end()) { + return i->second; + } + // Compute the bound otherwise. + if (failed(walkPostOrder(expr))) { + return std::nullopt; + } + return lb[expr]; +} + +std::optional +AffineExprBoundsVisitor::getIndexUpperBound(AffineExpr expr) { + std::optional apIntResult = getUpperBound(expr); + if (!apIntResult) + return std::nullopt; + + return apIntResult->getSExtValue(); +} + +std::optional +AffineExprBoundsVisitor::getIndexLowerBound(AffineExpr expr) { + std::optional apIntResult = getLowerBound(expr); + if (!apIntResult) + return std::nullopt; + + return apIntResult->getSExtValue(); +} + +ConstantIntRanges getRange(APInt lb, APInt ub, bool boundsSigned) { + return ConstantIntRanges::range(lb, ub, boundsSigned); +} + +/// Wrapper around the intrange::infer* functions that infers the range of +/// binary operations on two ranges. +void AffineExprBoundsVisitor::inferBinOpRange( + AffineBinaryOpExpr expr, + std::function)> opInference) { + ConstantIntRanges lhsRange = + getRange(lb[expr.getLHS()], ub[expr.getLHS()], boundsSigned); + ConstantIntRanges rhsRange = + getRange(lb[expr.getRHS()], ub[expr.getRHS()], boundsSigned); + ConstantIntRanges result = opInference({lhsRange, rhsRange}); + + lb[expr] = (boundsSigned) ? result.smin() : result.umin(); + ub[expr] = (boundsSigned) ? result.smax() : result.umax(); +} + +// Visitor method overrides. 
+LogicalResult AffineExprBoundsVisitor::visitMulExpr(AffineBinaryOpExpr expr) { + inferBinOpRange(expr, [](ArrayRef ranges) { + return intrange::inferMul(ranges); + }); + return success(); +} +LogicalResult AffineExprBoundsVisitor::visitAddExpr(AffineBinaryOpExpr expr) { + inferBinOpRange(expr, [](ArrayRef ranges) { + return intrange::inferAdd(ranges); + }); + return success(); +} +LogicalResult +AffineExprBoundsVisitor::visitCeilDivExpr(AffineBinaryOpExpr expr) { + inferBinOpRange( + expr, [boundsSigned = boundsSigned](ArrayRef ranges) { + if (boundsSigned) { + return intrange::inferCeilDivS(ranges); + } + return intrange::inferCeilDivU(ranges); + }); + return success(); +} +LogicalResult +AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) { + // There is no inferFloorDivU in the intrange library. We only offer + // computation of bounds for signed floordiv operations. + if (boundsSigned) { + inferBinOpRange(expr, [](ArrayRef ranges) { + return intrange::inferFloorDivS(ranges); + }); + return success(); + } + return failure(); +} +LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) { + inferBinOpRange( + expr, [boundsSigned = boundsSigned](ArrayRef ranges) { + if (boundsSigned) { + return intrange::inferRemS(ranges); + } + return intrange::inferRemU(ranges); + }); + return success(); +} +LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) { + if (lb.find(expr) == lb.end() || ub.find(expr) == ub.end()) { + return failure(); + } + return success(); +} +LogicalResult AffineExprBoundsVisitor::visitSymbolExpr(AffineSymbolExpr expr) { + return failure(); +} +LogicalResult +AffineExprBoundsVisitor::visitConstantExpr(AffineConstantExpr expr) { + APInt apIntVal = + APInt(bitWidth, static_cast(expr.getValue()), boundsSigned); + lb[expr] = apIntVal; + ub[expr] = apIntVal; + return success(); +} diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 609cb34309829..9462471a367a0 
100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -21,6 +21,7 @@ set(LLVM_OPTIONAL_SOURCES add_subdirectory(Presburger) add_mlir_library(MLIRAnalysis + AffineExprBounds.cpp AliasAnalysis.cpp CallGraph.cpp DataFlowFramework.cpp diff --git a/mlir/test/Analysis/test-affine-expr-bounds.mlir b/mlir/test/Analysis/test-affine-expr-bounds.mlir new file mode 100644 index 0000000000000..e4af66f1b8d13 --- /dev/null +++ b/mlir/test/Analysis/test-affine-expr-bounds.mlir @@ -0,0 +1,207 @@ +// RUN: mlir-opt -test-affine-expr-bounds --mlir-print-local-scope --allow-unregistered-dialect --verify-diagnostics %s | FileCheck %s + +func.func @test_compute_affine_expr_bounds() { + // Add + + // CHECK: "test.add"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.add"() {affine_map = affine_map<(d0) -> (d0 + 1)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.sub_const"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 1 + "test.sub_const"() {affine_map = affine_map<(d0) -> (d0 - 1)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.sub_dim"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 1 + "test.sub_dim"() {affine_map = affine_map<(d0) -> (1 - d0)>, lbs = [0], ubs = [2]} : () -> () + + // Mul + + // CHECK: "test.mul"() + // CHECK-SAME: expr_lb = 10 + // CHECK-SAME: expr_ub = 15 + "test.mul"() {affine_map = affine_map<(d0) -> (5 * d0)>, lbs = [2], ubs = [3]} : () -> () + + // CHECK: "test.mul_neg"() + // CHECK-SAME: expr_lb = -15 + // CHECK-SAME: expr_ub = -10 + "test.mul_neg"() {affine_map = affine_map<(d0) -> (-5 * d0)>, lbs = [2], ubs = [3]} : () -> () + + // Mod + + // CHECK: "test.mod_basic"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 2 + "test.mod_basic"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.mod_wrap_around_by_range"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrap_around_by_range"() 
{affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [0], ubs = [7]} : () -> () + + // CHECK: "test.mod_wrap_around_by_sum"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrap_around_by_sum"() {affine_map = affine_map<(d0) -> ((d0 + 3) mod 5)>, lbs = [0], ubs = [3]} : () -> () + + // CHECK: "test.mod_not_wrapping_around"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.mod_not_wrapping_around"() {affine_map = affine_map<(d0) -> (((d0 + 12) mod 11) mod 5)>, lbs = [0], ubs = [2]} : () -> () + + // FloorDiv + + // CHECK: "test.floordiv_basic"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 1 + "test.floordiv_basic"() {affine_map = affine_map<(d0) -> (d0 floordiv 16)>, lbs = [0], ubs = [31]} : () -> () + + // CHECK: "test.floordiv_not_stepping"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 1 + "test.floordiv_not_stepping"() {affine_map = affine_map<(d0) -> (d0 floordiv 16)>, lbs = [16], ubs = [31]} : () -> () + + // CHECK: "test.floordiv_stepping_by_sum"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 2 + "test.floordiv_stepping_by_sum"() {affine_map = affine_map<(d0) -> ((d0 + 1) floordiv 16)>, lbs = [16], ubs = [31]} : () -> () + + // CHECK: "test.floordiv_neg_factor"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 0 + "test.floordiv_neg_factor"() {affine_map = affine_map<(d0) -> (d0 floordiv -8)>, lbs = [0], ubs = [8]} : () -> () + + // CHECK: "test.floordiv_neg_factor_not_stepping"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = -1 + "test.floordiv_neg_factor_not_stepping"() {affine_map = affine_map<(d0) -> (d0 floordiv -8)>, lbs = [1], ubs = [8]} : () -> () + + // CHECK: "test.floordiv_neg_range"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = -1 + "test.floordiv_neg_range"() {affine_map = affine_map<(d0) -> (d0 floordiv 8)>, lbs = [-8], ubs = [-1]} : () -> () + + // CeilDiv + + // CHECK: "test.ceildiv_basic"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: 
expr_ub = 1 + "test.ceildiv_basic"() {affine_map = affine_map<(d0) -> (d0 ceildiv 16)>, lbs = [0], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_not_stepping"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 1 + "test.ceildiv_not_stepping"() {affine_map = affine_map<(d0) -> (d0 ceildiv 16)>, lbs = [1], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_stepping_by_sum"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 2 + "test.ceildiv_stepping_by_sum"() {affine_map = affine_map<(d0) -> ((d0 + 1) ceildiv 16)>, lbs = [1], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_neg_factor"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 0 + "test.ceildiv_neg_factor"() {affine_map = affine_map<(d0) -> (d0 ceildiv -16)>, lbs = [1], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_neg_factor_not_stepping"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 0 + "test.ceildiv_neg_factor_not_stepping"() {affine_map = affine_map<(d0) -> (d0 ceildiv -16)>, lbs = [0], ubs = [15]} : () -> () + + // CHECK: "test.ceildiv_neg_range"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 0 + "test.ceildiv_neg_range"() {affine_map = affine_map<(d0) -> (d0 ceildiv 16)>, lbs = [-16], ubs = [-1]} : () -> () + + return +} + +// ----- + +func.func @test_bounds_unsigned() { + // CHECK: "test.unsigned"() + // CHECK-SAME: expr_lb = 0 : ui8 + // CHECK-SAME: expr_ub = 255 : ui8 + "test.unsigned"() {affine_map = affine_map<(d0) -> (d0)>, lbs = [0 : ui8], ubs = [255 : ui8]} : () -> () + + // CHECK: "test.unsigned_wrapping"() + // CHECK-SAME: expr_lb = 0 : ui8 + // CHECK-SAME: expr_ub = 255 : ui8 + "test.unsigned_wrapping"() {affine_map = affine_map<(d0) -> (d0 + 2)>, lbs = [253 : ui8], ubs = [255 : ui8]} : () -> () + + // CHECK: "test.unsigned_wrap_full"() + // CHECK-SAME: expr_lb = 0 : ui8 + // CHECK-SAME: expr_ub = 4 : ui8 + "test.unsigned_wrap_full"() {affine_map = affine_map<(d0) -> (d0 + 5)>, lbs = [251 : ui8], ubs = [255 : ui8]} : () -> () + + 
return +} + +// ----- + +func.func @test_unsigned_floordiv() { + // Result should be lb = 1, ub = 1, but we're missing an unsigned floordiv computation. + // expected-error @+1 {{Failed to compute bounds}} + "test.unsigned_floordiv"() {affine_map = affine_map<(d0) -> (d0 floordiv 128)>, lbs = [129 : ui8], ubs = [129 : ui8]} : () -> () + +} + +// ----- + +func.func @test_bounds_signed() { + // CHECK: "test.signed"() + // CHECK-SAME: expr_lb = -1 : i8 + // CHECK-SAME: expr_ub = 0 : i8 + "test.signed"() {affine_map = affine_map<(d0) -> (d0 floordiv 16)>, lbs = [-1 : i8], ubs = [0 : i8]} : () -> () + + // CHECK: "test.signed_wrapping"() + // CHECK-SAME: expr_lb = -128 : i8 + // CHECK-SAME: expr_ub = 127 : i8 + "test.signed_wrapping"() {affine_map = affine_map<(d0) -> (d0 + 3)>, lbs = [124 : i8], ubs = [127 : i8]} : () -> () + + // CHECK: "test.signed_wrap_full"() + // CHECK-SAME: expr_lb = -128 : i8 + // CHECK-SAME: expr_ub = -125 : i8 + "test.signed_wrap_full"() {affine_map = affine_map<(d0) -> (d0 + 4)>, lbs = [124 : i8], ubs = [127 : i8]} : () -> () + + return +} + +// ----- + +func.func @test_dynamic_lb_basic() { + // expected-error @+1 {{Failed to compute bounds}} + "test.dynamic_lb_basic"() {affine_map = affine_map<(d0) -> (d0)>, lbs = ["?"], ubs = [1]} : () -> () + return +} + +// ----- + +func.func @test_dynamic_ub_basic() { + // expected-error @+1 {{Failed to compute bounds}} + "test.dynamic_ub_basic"() {affine_map = affine_map<(d0) -> (d0)>, lbs = [0], ubs = ["?"]} : () -> () + return +} + +// ----- + +func.func @test_dynamic_lb_unused() { + // CHECK: "test.dynamic_lb_unused"() + // CHECK-SAME: expr_lb = 14 + // CHECK-SAME: expr_ub = 16 + "test.dynamic_lb_unused"() {affine_map = affine_map<(d0, d1) -> (d1 + 2)>, lbs = ["?", 12], ubs = [1, 14]} : () -> () + return +} + +// ----- + +func.func @test_dynamic_ub_unused() { + // CHECK: "test.dynamic_ub_unused"() + // CHECK-SAME: expr_lb = 14 + // CHECK-SAME: expr_ub = 16 + "test.dynamic_ub_unused"() {affine_map = 
affine_map<(d0, d1) -> (d1 + 2)>, lbs = [0, 12], ubs = ["?", 14]} : () -> () + return +} diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt index 01297ad0a1148..9fe2ba0c610ef 100644 --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR + TestAffineExpressionBounds.cpp TestAffineWalk.cpp TestBytecodeRoundtrip.cpp TestBuiltinAttributeInterfaces.cpp diff --git a/mlir/test/lib/IR/TestAffineExpressionBounds.cpp b/mlir/test/lib/IR/TestAffineExpressionBounds.cpp new file mode 100644 index 0000000000000..dc0b49d32130e --- /dev/null +++ b/mlir/test/lib/IR/TestAffineExpressionBounds.cpp @@ -0,0 +1,190 @@ +//===- TestAffineExpressionBounds.cpp - Test affine expression bounds --=====// +//----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineExprBounds.h" + +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" + +#include "TestDialect.h" + +using namespace mlir; + +namespace { + +struct TestAffineExpressionBounds + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineExpressionBounds) + + StringRef getArgument() const final { return "test-affine-expr-bounds"; } + StringRef getDescription() const final { + return "Test simplify affine expression simplication"; + } + + FailureOr>> + getBound(Operation *op, StringRef boundType, bool *resultSigned, + uint64_t *resultWidth, bool optional = 
false) { + SmallVector> result; + + bool isSigned = false; + uint64_t width = 0; + + auto dict = op->getAttrDictionary(); + if (!dict) { + return op->emitError("No dictionary found"); + } + + auto bounds = dict.getNamed(boundType); + if (!bounds) { + if (!optional) { + return op->emitError(llvm::formatv("No {} attribute found", boundType)); + } + return failure(); + } + + auto boundsValue = cast(bounds->getValue()); + + for (auto v : boundsValue) { + if (auto value = dyn_cast(v)) { + if (width == 0) { + isSigned = (value.getType().isSignedInteger() || + value.getType().isSignlessInteger()); + width = value.getType().getIntOrFloatBitWidth(); + } else if (isSigned != (value.getType().isSignedInteger() || + value.getType().isSignlessInteger())) { + return op->emitError("Mixed signedness in bounds"); + } else if (width != value.getType().getIntOrFloatBitWidth()) { + return op->emitError("Mixed width in bounds"); + } + result.push_back(value.getValue()); + } else if (auto value = dyn_cast(v)) { + if (value.getValue() == "?") { + result.push_back(std::nullopt); + } else { + return op->emitError("Unknown string value found"); + } + } else { + return op->emitError("Non-integer or string value found in bounds"); + } + } + + *resultSigned = isSigned; + *resultWidth = width; + + return result; + } + + FailureOr getAffineExpr(Operation *op) { + auto dict = op->getAttrDictionary(); + if (!dict) { + return op->emitError("No dictionary found"); + } + auto affineMap = dict.getNamed("affine_map"); + if (!affineMap) { + return op->emitError("No affine_map attribute found"); + } + auto mapAttr = dyn_cast(affineMap->getValue()); + if (!mapAttr) { + return op->emitError("Invalid affine_map attribute found"); + } + + auto map = mapAttr.getAffineMap(); + if (map.getNumResults() != 1) { + return op->emitError("Invalid number of affine_map results"); + } + + return map.getResult(0); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + IRRewriter 
rewriter(func.getContext()); + + func.walk([&](Operation *op) { + if (op->getDialect() != + op->getContext()->getLoadedDialect()) { + return; + } + + auto expr = getAffineExpr(op); + bool ubSigned, lbSigned; + uint64_t ubWidth, lbWidth; + auto ubs = getBound(op, "ubs", &ubSigned, &ubWidth); + auto lbs = getBound(op, "lbs", &lbSigned, &lbWidth); + + if (failed(expr) || failed(ubs) || failed(lbs)) { + return; + } + + if (ubs->size() != lbs->size()) { + op->emitError("Mismatched number of bounds"); + return; + } + if (ubWidth != lbWidth && + !((ubWidth == 0 && lbWidth > 0) || (ubWidth > 0 && lbWidth == 0))) { + op->emitError("Mismatched width in bounds"); + return; + } + bool signCheck = + !(ubWidth == 0 && lbWidth > 0) && !(ubWidth > 0 && lbWidth == 0); + if (signCheck && (ubSigned != lbSigned)) { + op->emitError("Mixed signedness in bounds"); + return; + } + + uint64_t width = (ubWidth == 0) ? lbWidth : ubWidth; + + AffineExprBoundsVisitor visitor(*lbs, *ubs, lbSigned, width, + &getContext()); + auto exprLB = visitor.getLowerBound(*expr); + auto exprUB = visitor.getUpperBound(*expr); + + if (!exprLB || !exprUB) { + op->emitError("Failed to compute bounds"); + return; + } + + auto namedAttrList = mlir::NamedAttrList{rewriter.getDictionaryAttr( + {rewriter.getNamedAttr( + "expr_lb", + IntegerAttr::get( + IntegerType::get( + &getContext(), width, + (lbSigned) ? IntegerType::SignednessSemantics::Signless + : IntegerType::SignednessSemantics::Unsigned), + *exprLB)), + rewriter.getNamedAttr( + "expr_ub", + IntegerAttr::get( + IntegerType::get( + &getContext(), width, + (ubSigned) ? 
IntegerType::SignednessSemantics::Signless + : IntegerType::SignednessSemantics::Unsigned), + *exprUB))})}; + op->setAttrs(namedAttrList); + }); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestAffineExpressionBounds() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 002c3900056de..0e061aa8d1798 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -75,6 +75,7 @@ void registerInliner(); void registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); +void registerTestAffineExpressionBounds(); void registerTestAffineLoopParametricTilingPass(); void registerTestAliasAnalysisPass(); void registerTestArithEmulateWideIntPass(); @@ -212,6 +213,7 @@ void registerTestPasses() { mlir::test::registerMemRefBoundCheck(); mlir::test::registerPatternsTestPass(); mlir::test::registerSimpleParametricTilingPass(); + mlir::test::registerTestAffineExpressionBounds(); mlir::test::registerTestAffineLoopParametricTilingPass(); mlir::test::registerTestAliasAnalysisPass(); mlir::test::registerTestArithEmulateWideIntPass(); From a8c0978f373998611fe86a6c7bd1b1fbf77e1ad2 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 4 Mar 2025 16:30:04 +0000 Subject: [PATCH 12/31] Add width for int64_t interface, remove one warning --- mlir/include/mlir/Analysis/AffineExprBounds.h | 8 ++++---- mlir/lib/Analysis/AffineExprBounds.cpp | 7 +++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Analysis/AffineExprBounds.h b/mlir/include/mlir/Analysis/AffineExprBounds.h index b88ed3c46a9b7..9ffd7227d973f 100644 --- a/mlir/include/mlir/Analysis/AffineExprBounds.h +++ b/mlir/include/mlir/Analysis/AffineExprBounds.h @@ -78,10 +78,10 @@ class AffineExprBoundsVisitor private: bool boundsSigned; uint64_t bitWidth; - void - inferBinOpRange(AffineBinaryOpExpr expr, 
- std::function)> - opInference); + void inferBinOpRange( + AffineBinaryOpExpr expr, + const std::function)> + &opInference); /// Bounds that have been computed for subexpressions are memoized and reused. llvm::DenseMap lb; diff --git a/mlir/lib/Analysis/AffineExprBounds.cpp b/mlir/lib/Analysis/AffineExprBounds.cpp index 0c0f8b622736d..92a63d0004687 100644 --- a/mlir/lib/Analysis/AffineExprBounds.cpp +++ b/mlir/lib/Analysis/AffineExprBounds.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "llvm/ADT/APInt.h" #include @@ -40,7 +41,8 @@ AffineExprBoundsVisitor::AffineExprBoundsVisitor( AffineExprBoundsVisitor::AffineExprBoundsVisitor( ArrayRef> constLowerBounds, - ArrayRef> constUpperBounds, MLIRContext *context) { + ArrayRef> constUpperBounds, MLIRContext *context) + : boundsSigned(true), bitWidth(64) { assert(constLowerBounds.size() == constUpperBounds.size()); // Convert int64_ts to APInts. for (unsigned i = 0; i < constLowerBounds.size(); i++) { @@ -107,7 +109,8 @@ ConstantIntRanges getRange(APInt lb, APInt ub, bool boundsSigned) { /// binary operations on two ranges. 
void AffineExprBoundsVisitor::inferBinOpRange( AffineBinaryOpExpr expr, - std::function)> opInference) { + const std::function)> + &opInference) { ConstantIntRanges lhsRange = getRange(lb[expr.getLHS()], ub[expr.getLHS()], boundsSigned); ConstantIntRanges rhsRange = From 3e3f092bc2bc5e6bc316b9a1d667005f5c2710b7 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Thu, 6 Mar 2025 07:17:24 -0700 Subject: [PATCH 13/31] Ensure 0 <= x mod N < N semantics --- mlir/lib/Analysis/AffineExprBounds.cpp | 26 ++++++++++++++----- .../Analysis/test-affine-expr-bounds.mlir | 10 +++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Analysis/AffineExprBounds.cpp b/mlir/lib/Analysis/AffineExprBounds.cpp index 92a63d0004687..b71cfe4721323 100644 --- a/mlir/lib/Analysis/AffineExprBounds.cpp +++ b/mlir/lib/Analysis/AffineExprBounds.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineExprBounds.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" @@ -158,13 +159,24 @@ AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) { return failure(); } LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) { - inferBinOpRange( - expr, [boundsSigned = boundsSigned](ArrayRef ranges) { - if (boundsSigned) { - return intrange::inferRemS(ranges); - } - return intrange::inferRemU(ranges); - }); + // Only support integers >= 1 as RHS. + auto rhsConst = dyn_cast(expr.getRHS()); + if (!rhsConst || rhsConst.getValue() < 1) + return failure(); + + inferBinOpRange(expr, [boundsSigned = + boundsSigned](ArrayRef ranges) { + // Mod must return a value between 0 and N-1. + // Computing (N + (expr mod N)) mod N is guaranteed to yield a result in + // this range. 
+ if (boundsSigned) { + auto rhs = ranges[1]; + auto lhs = ranges[0]; + return intrange::inferRemS( + {intrange::inferAdd({intrange::inferRemS({lhs, rhs}), rhs}), rhs}); + } + return intrange::inferRemU(ranges); + }); return success(); } LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) { diff --git a/mlir/test/Analysis/test-affine-expr-bounds.mlir b/mlir/test/Analysis/test-affine-expr-bounds.mlir index e4af66f1b8d13..03115760a29d0 100644 --- a/mlir/test/Analysis/test-affine-expr-bounds.mlir +++ b/mlir/test/Analysis/test-affine-expr-bounds.mlir @@ -52,6 +52,16 @@ func.func @test_compute_affine_expr_bounds() { // CHECK-SAME: expr_ub = 3 "test.mod_not_wrapping_around"() {affine_map = affine_map<(d0) -> (((d0 + 12) mod 11) mod 5)>, lbs = [0], ubs = [2]} : () -> () + // CHECK: "test.mod_neg"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.mod_neg"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [-4], ubs = [-2]} : () -> () + + // CHECK: "test.mod_wrapping_by_zero"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrapping_by_zero"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [-2], ubs = [1]} : () -> () + // FloorDiv // CHECK: "test.floordiv_basic"() From 212603ad682232f1ee3d9008d99843f746ce447c Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Wed, 5 Mar 2025 14:54:51 +0000 Subject: [PATCH 14/31] Add tosa.cast folding for unsigned integers --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 9 ++ .../Dialect/Tosa/Transforms/TosaFolders.cpp | 64 ++++++++----- mlir/test/Dialect/Tosa/constant-cast-opt.mlir | 92 +++++++++++++++++++ 3 files changed, 140 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0b467d04350e5..780291939ba20 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1877,6 +1877,15 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, | 
signed 16 to float | int16 | float | | float 32 to float 64 | float32 | float64 | | float 64 to float 32 | float64 | float32 | + + AMD extensions: + | signed to unsigned | signed | unsigned| + | unsigned to signed | unsigned| signed | + | unsigned to float | unsigned| float | + - unsigned to signed integer and signed to unsigned integer: + wrap on overflow + - unsigned to float: + uses llvm's int to float conversion with TOSA rounding mode }]; let arguments = (ins diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 106a9791b309f..4932ce87d57b7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -110,16 +110,34 @@ DenseElementsAttr applyElementWise( // We already know the amount of values we will insert, reserve space for // all of them to avoid dynamic resizing transformedValues.reserve(toTransform.getNumElements()); - for (auto val : toTransform.getValues()) { - auto transformedVal = toApply(val, targetType); - transformedValues.push_back(transformedVal); + if constexpr (std::is_same_v) { + for (auto val : toTransform.getValues()) { + auto transformedVal = + toApply(APSInt(val, toTransform.getElementType().isUnsignedInteger()), + targetType); + transformedValues.push_back(transformedVal); + } + } else { + for (auto val : toTransform.getValues()) { + auto transformedVal = toApply(val, targetType); + transformedValues.push_back(transformedVal); + } } // Make sure that the output tensor has the expected output type auto inShape = toTransform.getType(); auto outTy = inShape.cloneWith({}, targetType); - return DenseElementsAttr::get(outTy, transformedValues); + if constexpr (std::is_same_v) { + SmallVector transformedValuesAPInt; + transformedValuesAPInt.reserve(transformedValues.size()); + for (APSInt val : transformedValues) { + transformedValuesAPInt.emplace_back(val); + } + return DenseElementsAttr::get(outTy,
transformedValuesAPInt); + } else { + return DenseElementsAttr::get(outTy, transformedValues); + } } template DenseElementsAttr applyElementWise( @@ -881,10 +899,10 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { using TosaFoldConstantBase::TosaFoldConstantBase; - static APFloat convertIntToFloat(const APInt &toConvert, + static APFloat convertIntToFloat(const APSInt &toConvert, FloatType targetType) { APFloat res(targetType.getFloatSemantics()); - res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode); + res.convertFromAPInt(toConvert, toConvert.isSigned(), tosaRoundingMode); return res; } @@ -928,15 +946,14 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { return converted; } - static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) { + static APSInt convertIntToInt(const APSInt &toConvert, + IntegerType targetType) { // Make sure to properly translate booleans if (targetType.getWidth() == 1) { - return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); - } - if (targetType.isUnsigned()) { - return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth()); + return APSInt(toConvert.isZero() ? 
APInt::getZero(1) + : APInt::getAllOnes(1)); } - return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth()); + return toConvert.extOrTrunc(targetType.getIntOrFloatBitWidth()); } static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location, @@ -994,20 +1011,17 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { tosaCast, "Only casts from/to int/float are supported."); } - auto isUnsigned = [](Type toCheck) { - return isa(toCheck) && - cast(toCheck).isUnsigned(); - }; - auto typesToCheck = {toType, fromType}; - if (llvm::any_of(typesToCheck, isUnsigned)) { + // TOSA spec does not allow casts from/to unsigned, but we partially do, to + // enable the folding of lowered qdq nodes + if (isa(fromType) && isa(toType) && + cast(toType).isUnsigned()) { // TOSA casts currently don't support unsigned integers. - // To support them by here, one could use APSInt instead of APInts, - // however, this causes trouble with `getValues` which does not support - // APSInts currently. 
+ // Casting float to unsigned int would need a decision about how to handle + // negative floats return rewriter.notifyMatchFailure( - tosaCast, "Cast folding from/to unsigned integers is not supported."); + tosaCast, + "Cast folding from float to unsigned integers is not supported."); } - DenseElementsAttr res; if (auto intOutTy = dyn_cast(toType)) { if (isa(fromType)) { @@ -1015,7 +1029,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { elements, &convertFloatToInt, intOutTy); } else { assert(isa(fromType)); - res = applyElementWise( + res = applyElementWise( elements, &convertIntToInt, intOutTy); } } else { @@ -1026,7 +1040,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { elements, &convertFloatToFloat, floatOutTy); } else { assert(isa(fromType)); - res = applyElementWise( + res = applyElementWise( elements, &convertIntToFloat, floatOutTy); } } diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir index 75339eacb67d5..74421a6ab8ba9 100644 --- a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir @@ -71,6 +71,20 @@ func.func @cast_fold_f32_to_i8() -> tensor<5xi8> { return %1 : tensor<5xi8> } +// CHECK-LABEL: @cast_fold_f32_to_ui8 +// COM: Do not fold casts from floats to uint +func.func @cast_fold_f32_to_ui8() -> tensor<5xui8> { + // CHECK: tosa.const + // CHECK-NOT: tensor<5xui8> + // CHECK: tosa.cast + %0 = "tosa.const"() {value = + dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xui8> + return %1 : tensor<5xui8> +} + // CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> { // Check if infinity and zero are translated properly. 
Don't expect any @@ -116,6 +130,71 @@ func.func @cast_fold_i32_to_i8() -> tensor<5xi8> { return %1 : tensor<5xi8> } +// CHECK-LABEL: @cast_fold_i8_to_ui8 +func.func @cast_fold_i8_to_ui8() -> tensor<3xui8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 251{{.*}}tensor<3xui8> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, -5]> : + tensor<3xi8> + } : () -> tensor<3xi8> + %1 = "tosa.cast"(%0) : (tensor<3xi8>) -> tensor<3xui8> + return %1 : tensor<3xui8> +} + +// CHECK-LABEL: @cast_fold_ui8_to_i8 +func.func @cast_fold_ui8_to_i8() -> tensor<3xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, -6{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + +// CHECK-LABEL: @cast_fold_ui8_to_i16 +func.func @cast_fold_ui8_to_i16() -> tensor<3xi16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 250{{.*}}tensor<3xi16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi16> + return %1 : tensor<3xi16> +} + +// CHECK-LABEL: @cast_fold_ui8_to_i1 +func.func @cast_fold_ui8_to_i1() -> tensor<3xi1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi1> + return %1 : tensor<3xi1> +} + +// CHECK-LABEL: @cast_fold_ui8_to_ui1 +func.func @cast_fold_ui8_to_ui1() -> tensor<3xui1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xui1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 
0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xui1> + return %1 : tensor<3xui1> +} + // CHECK-LABEL: @cast_fold_i16_to_i1 func.func @cast_fold_i16_to_i1() -> tensor<3xi1> { @@ -172,6 +251,19 @@ func.func @cast_fold_i32_to_f16() -> tensor<4xf16> { return %1 : tensor<4xf16> } +// CHECK-LABEL: @cast_fold_ui8_to_f32 +func.func @cast_fold_ui8_to_f32() -> tensor<4xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0.000000e+00, 1.000000e+00, 4.000000e+00, 2.550000e+02{{.*}}tensor<4xf32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0, 1, 4, 255]> : + tensor<4xui8> + } : () -> tensor<4xui8> + %1 = "tosa.cast"(%0) : (tensor<4xui8>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + // ----- // Casts from float to float From 55654e920ddb05b6a60a44aa079bfe72cd4a7bfb Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 11:27:33 +0000 Subject: [PATCH 15/31] Canonicalize 'self-concats' to tile --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 42 +++++++++++++++++++ mlir/test/Dialect/Tosa/fold_concats.mlir | 29 +++++++------ 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 732f794206cd8..518254eaf6919 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern { } }; +struct SelfConcatToTile : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const override { + if (llvm::all_equal(concatOp->getUsers())) { + const auto concatUser = llvm::dyn_cast( + concatOp->getUses().begin()->getOwner()); + if (concatUser) { + // Try folding the concat into its consumer before rewriting it to a + // tile. 
+ SmallVector replacementValues; + auto foldResult = rewriter.tryFold(concatUser, replacementValues); + if (foldResult.succeeded()) { + if (!replacementValues.empty()) { + rewriter.replaceOp(concatUser, replacementValues); + } + return success(); + } + } + } + + if (!llvm::all_equal(concatOp->getOperands())) { + return rewriter.notifyMatchFailure( + concatOp, "Requires all operands to be the same"); + } + const auto concatType = dyn_cast(concatOp.getType()); + if (!concatType || !concatType.hasRank()) { + return rewriter.notifyMatchFailure(concatOp, + "Requires concat to be ranked"); + } + SmallVector multiplies(concatType.getRank(), 1); + multiplies[concatOp.getAxis()] = concatOp->getNumOperands(); + auto tileOp = rewriter.createOrFold( + concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0), + multiplies); + rewriter.replaceOp(concatOp, {tileOp}); + return success(); + } +}; + void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(context); } struct SqrtReciprocalOptimization : public OpRewritePattern { diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index ec54f27346c8b..e77aefcbe6353 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -5,10 +5,10 @@ func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { return %0 : tensor<1x2x7x7xf32> } -// CHECK-LABEL: func.func @single_concat( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32> +// CHECK-LABEL: func.func @single_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : 
(tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: return [[VAR_0_]] : tensor<1x2x7x7xf32> // CHECK: } // ----- @@ -19,11 +19,11 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf return %1 : tensor<2x2x7x7xf32> } -// CHECK-LABEL: func.func @concat_different_axis( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> -// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32> +// CHECK-LABEL: func.func @concat_different_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array} : (tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> +// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32> // CHECK: } // ----- @@ -84,10 +84,9 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x return %2 : tensor<1x4x8x8xf32> } -// CHECK-LABEL: func.func @partially_foldable( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> -// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> -// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32> +// CHECK-LABEL: func.func @partially_foldable +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x2x4x8xf32>) -> 
tensor<1x4x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_1_]] {multiples = array} : (tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_0_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return [[VAR_1_]] : tensor<1x4x8x8xf32> // CHECK: } From 358df158ca539a42a1525e17325e114f8977792f Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 12:42:55 +0000 Subject: [PATCH 16/31] Fold consecutive tosa.tiles --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 15 +++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 25 +++++++++++++++++++ mlir/test/Dialect/Tosa/fold_concats.mlir | 5 ++-- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 518254eaf6919..c754cd2de6a56 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1362,6 +1362,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) { bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; }); if (allOnes && getInput1().getType() == getType()) return getInput1(); + + if (auto inputTile = getInput1().getDefiningOp()) { + if (!inputTile->hasOneUse()) { + return {}; + } + llvm::SmallVector newMultiplies{getMultiples()}; + for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) { + newMultiplies[idx] *= multiplier; + } + setMultiples(newMultiplies); + setOperand(inputTile->getOperand(0)); + getOperation()->setLoc( + FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()})); + return getResult(); + } return {}; } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index b341a774442ba..fe20a58fe809a 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ 
b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -691,6 +691,31 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // ----- +// CHECK-LABEL: func.func @tile_fold_consecutive +func.func @tile_fold_consecutive(%arg0: tensor<3x4xf32>) -> tensor<6x16xf32> { + // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<6x16xf32> { + // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<3x4xf32>) -> tensor<6x16xf32> + // CHECK: return [[VAR_0_]] : tensor<6x16xf32> + %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x8xf32> + %1 = tosa.tile %0 { multiples = array }: (tensor<3x8xf32>) -> tensor<6x16xf32> + return %1 : tensor<6x16xf32> +} + +// ----- + +// CHECK-LABEL: func.func @tile_no_fold_consecutive_multi_use +func.func @tile_no_fold_consecutive_multi_use(%arg0: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) { + // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) { + // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<3x4xf32>) -> tensor<3x8xf32> + // CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array} : (tensor<3x8xf32>) -> tensor<6x16xf32> + // CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<3x8xf32>, tensor<6x16xf32> + %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x8xf32> + %1 = tosa.tile %0 { multiples = array }: (tensor<3x8xf32>) -> tensor<6x16xf32> + return %0, %1 : tensor<3x8xf32>, tensor<6x16xf32> +} + +// ----- + // CHECK-LABEL: @tile_nofold func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> { // CHECK: tosa.tile diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index e77aefcbe6353..409088fd0f3ec 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -21,9 +21,8 @@ func.func @concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf // CHECK-LABEL: func.func 
@concat_different_axis // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { -// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: [[VAR_1_:%.+]] = tosa.tile [[VAR_0_]] {multiples = array} : (tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> -// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32> +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> +// CHECK: return [[VAR_0_]] : tensor<2x2x7x7xf32> // CHECK: } // ----- From eb3aed527f089f8d278aaa5e3608e3d7f509e63f Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 14:03:22 +0000 Subject: [PATCH 17/31] Extend concat -> slice canonicalization to remove concat inputs if possible --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 56 ++++++----- mlir/test/Dialect/Tosa/canonicalize.mlir | 92 +++++++++++++++++++ 2 files changed, 126 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c754cd2de6a56..c35ffef61468a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -653,35 +653,47 @@ struct ConcatSliceOptimization : public OpRewritePattern { llvm::SmallVector sliceStart(sliceOp.getStart()); llvm::ArrayRef sliceSize = sliceOp.getSize(); - - // Validate slice on the concatenated axis. Slicing along this - // axis should span only one of the inputs to the concatenate - // operation. 
- std::optional replaceWithSlice; + llvm::SmallVector requiredConcatInputs; + int64_t processedOriginalConcatInputSize = 0; + int64_t droppedConcatInputSize = 0; for (auto input : inputs) { - auto inputType = dyn_cast(input.getType()); + const auto inputType = dyn_cast(input.getType()); if (!inputType || !inputType.hasStaticShape()) return rewriter.notifyMatchFailure( sliceOp, "concat input must be a static ranked tensor"); - - if (sliceStart[axis] >= 0 && - (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { - replaceWithSlice = rewriter - .create( - sliceOp.getLoc(), sliceOp.getType(), input, - rewriter.getDenseI64ArrayAttr(sliceStart), - rewriter.getDenseI64ArrayAttr(sliceSize)) - .getResult(); - break; + if (processedOriginalConcatInputSize < + (sliceStart[axis] + sliceSize[axis]) && + (processedOriginalConcatInputSize + inputType.getDimSize(axis)) > + sliceStart[axis]) { + if (requiredConcatInputs.empty()) { + droppedConcatInputSize = processedOriginalConcatInputSize; + } + requiredConcatInputs.push_back(input); } - sliceStart[axis] -= inputType.getDimSize(axis); + processedOriginalConcatInputSize += inputType.getDimSize(axis); } - - if (!replaceWithSlice) + if (requiredConcatInputs.size() == concatOp->getNumOperands()) { return rewriter.notifyMatchFailure( - sliceOp, "corresponding concat input not found for slice"); - - rewriter.replaceOp(sliceOp, replaceWithSlice.value()); + sliceOp, "Could not reduce number of inputs to preceding concat"); + } + if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) { + return rewriter.notifyMatchFailure( + sliceOp, + "Preceding concat must have a single use"); // Do not introduce new + // concats + } + if (requiredConcatInputs.empty()) { + return rewriter.notifyMatchFailure( + sliceOp, "degenerate slice with zero sized dim in output"); + } + sliceStart[axis] -= droppedConcatInputSize; + auto newConcat = rewriter.create(concatOp->getLoc(), + requiredConcatInputs, axis); + auto newSlice = 
rewriter.create( + sliceOp->getLoc(), sliceOp.getType(), newConcat, + rewriter.getDenseI64ArrayAttr(sliceStart), + rewriter.getDenseI64ArrayAttr(sliceSize)); + rewriter.replaceOp(sliceOp, newSlice); return success(); } }; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index fe20a58fe809a..5460a89a543c4 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -829,6 +829,98 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : // ----- +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_start_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_start_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %1 : tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_end_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : 
(tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_end_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %1 : tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_all_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x4xf32> +func.func @canonicalize_concat_slice_partial_concat_all_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32> + return %1 : tensor<1x12x12x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_multi_use +// CHECK-SAME: ([[PARAM_0_:%.+]]: 
tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_multi_use(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %0, %1 : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_zero_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x0xf32> +// CHECK: } +func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, 
tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32> + return %1 : tensor<1x12x12x0xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32> +// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32> +func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x120x12x10x16xf32> + return %1 : tensor<1x120x12x10x16xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_multi_output +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32> +func.func @canonicalize_tile_slice_multi_output(%arg0 : tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) 
-> tensor<1x12x12x10x16xf32> + return %0, %1 : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32> +} + +// ----- + // CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> { // CHECK: %[[RSQRT:.*]] = tosa.rsqrt %arg{{.*}} : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> From 0aad4d94cb11c268ed76d8ba6a32da09be05b12a Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 11 Mar 2025 17:07:23 +0000 Subject: [PATCH 18/31] Add canonicalization pattern for tile -> slice to minimize the tile multipliers --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 70 ++++++++++++++++++- mlir/test/Dialect/Tosa/canonicalize.mlir | 27 +++++-- 2 files changed, 88 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c35ffef61468a..fb52c0502db34 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -687,8 +687,8 @@ struct ConcatSliceOptimization : public OpRewritePattern { sliceOp, "degenerate slice with zero sized dim in output"); } sliceStart[axis] -= droppedConcatInputSize; - auto newConcat = rewriter.create(concatOp->getLoc(), - requiredConcatInputs, axis); + auto newConcat = rewriter.create( + concatOp->getLoc(), requiredConcatInputs, axis); auto newSlice = rewriter.create( sliceOp->getLoc(), sliceOp.getType(), newConcat, rewriter.getDenseI64ArrayAttr(sliceStart), @@ -698,9 +698,75 @@ struct ConcatSliceOptimization : public OpRewritePattern { } }; +/// This pattern adjusts the multipliers of a tile followed by a slice to only +/// tile as much data as is required by the slice. +struct TileSliceOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const override { + Value sliceInput =
sliceOp.getInput1(); + auto tileOp = sliceInput.getDefiningOp(); + if (!tileOp) + return rewriter.notifyMatchFailure(sliceOp, + "slice input must be tile operation"); + if (!tileOp->hasOneUse()) + return rewriter.notifyMatchFailure( + sliceOp, "preceding tile must have a single use"); // Do not insert + // additional tiles + + const auto tileOpInputType = + dyn_cast(tileOp->getOperand(0).getType()); + if (!tileOpInputType || !tileOpInputType.hasStaticShape()) + return rewriter.notifyMatchFailure( + sliceOp, "input to preceding tile op must be a static ranked tensor"); + llvm::SmallVector requiredMultipliers; + llvm::SmallVector newTileStarts; + requiredMultipliers.reserve(tileOpInputType.getRank()); + newTileStarts.reserve(tileOpInputType.getRank()); + for (auto [axis, sliceStart, sliceSize] : + llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) { + if (sliceSize <= 0) { + return rewriter.notifyMatchFailure( + sliceOp, "degenerate slice with zero sized dim"); + } + const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis); + const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize; + const int64_t sliceSizeInFirstTile = + std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize); + assert(sliceSizeInFirstTile > 0); + const int64_t requiredMultiplierWithoutFirstTile = + llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize); + const int64_t requiredMultiplier = + requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0); + assert(requiredMultiplier <= tileOp.getMultiples()[axis]); + requiredMultipliers.push_back(requiredMultiplier); + newTileStarts.push_back(sliceOffsetInNewFirstTile); + } + if (requiredMultipliers == tileOp.getMultiples()) + return rewriter.notifyMatchFailure( + sliceOp, "could not reduce multipliers in preceding tile"); + + llvm::SmallVector newTileShape(tileOpInputType.getShape()); + for (auto [newShape, multiplier] : + llvm::zip_equal(newTileShape, requiredMultipliers)) { + newShape *= 
multiplier; + } + auto newTile = rewriter.create( + tileOp->getLoc(), tileOpInputType.clone(newTileShape), + tileOp->getOperand(0), requiredMultipliers); + auto newSlice = rewriter.create( + sliceOp->getLoc(), sliceOp.getType(), newTile, + rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr()); + rewriter.replaceOp(sliceOp, newSlice); + return success(); + } +}; + void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(context); } struct MinToClampOptimization : public OpRewritePattern { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5460a89a543c4..2bdbd44cb31be 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -896,14 +896,27 @@ func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %ar // ----- // CHECK-LABEL: func.func @canonicalize_tile_slice -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { -// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<1x120x24x20x30xf32> -// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x120x24x20x30xf32>) -> tensor<1x120x12x10x16xf32> -// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16xf32> -func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x120x12x10x16xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x24x20x30x10xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x120x24x20x30x10xf32>) -> tensor<1x120x12x10x16x5xf32> +// CHECK: return [[VAR_1_]] : tensor<1x120x12x10x16x5xf32> +func.func @canonicalize_tile_slice(%arg0 : 
tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x120x12x10x16x5xf32> + return %1 : tensor<1x120x12x10x16x5xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32> +// CHECK: return [[VAR_1_]] : tensor<1x0x12x10x16xf32> +func.func @canonicalize_tile_slice_zero_dim(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> - %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x120x12x10x16xf32> - return %1 : tensor<1x120x12x10x16xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32> + return %1 : tensor<1x0x12x10x16xf32> } // ----- From b87be4ffe402e38b4c5ee533907d6722d445f16f Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Wed, 12 Mar 2025 13:11:37 +0000 Subject: [PATCH 19/31] Add check for canonicalization of slice(concat %a, %a)) = %a --- mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 2bdbd44cb31be..3aae58de87518 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -908,6 +908,17 @@ func.func @canonicalize_tile_slice(%arg0 : 
tensor<1x12x12x10x10x10xf32>) -> tens // ----- +// CHECK-LABEL: func.func @canonicalize_self_concat_slice +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: return [[PARAM_0_]] : tensor<1x2x3x4xf32> +func.func @canonicalize_self_concat_slice(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %0 = tosa.concat %arg0, %arg0 {axis = 3 : i32} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x8xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x2x3x8xf32>) -> tensor<1x2x3x4xf32> + return %1 : tensor<1x2x3x4xf32> +} + +// ----- + // CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { // CHECK: [[VAR_0_:%.+]] = tosa.tile [[PARAM_0_]] {multiples = array} : (tensor<1x12x12x10x10xf32>) -> tensor<10x120x120x100x100xf32> From 2afdeee3e971a1e91714013f1a326c99737c0459 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Thu, 13 Mar 2025 08:14:28 +0000 Subject: [PATCH 20/31] Add test for complete folding of tile + slice --- mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 3aae58de87518..64143f477c85d 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -908,6 +908,17 @@ func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tens // ----- +// CHECK-LABEL: func.func @canonicalize_tile_slice_fold +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> { +// CHECK: return [[PARAM_0_]] : tensor<1x12x12x10x10x10xf32> +func.func @canonicalize_tile_slice_fold(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> { + %0 = tosa.tile %arg0 {multiples = array} : (tensor<1x12x12x10x10x10xf32>) -> tensor<10x120x120x100x100x100xf32> + %1 = tosa.slice %0 {size 
= array, start = array} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x12x12x10x10x10xf32> + return %1 : tensor<1x12x12x10x10x10xf32> +} + +// ----- + // CHECK-LABEL: func.func @canonicalize_self_concat_slice // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { // CHECK: return [[PARAM_0_]] : tensor<1x2x3x4xf32> From bb99503626b6efd2bd87a216ff279181cc6ec48f Mon Sep 17 00:00:00 2001 From: Jonas Rickert Date: Fri, 21 Mar 2025 02:37:54 -0600 Subject: [PATCH 21/31] Pick getTosaConstShape helper from 571a987 --- .../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 9 ++++- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 33 ++++++------------- .../Dialect/Tosa/Utils/ConversionUtils.cpp | 22 +++++++++++++ 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index 90fea1f68beb5..0d9c76f31d78f 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -84,6 +84,12 @@ LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1, Value &input2); +Value getTosaConstShape(ImplicitLocOpBuilder &builder, + llvm::ArrayRef shape); + +Value getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape); + namespace { // Creates a TOSA operation and performs shape inference on the individual @@ -217,7 +223,8 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc, } // Apply an int32_t permutation to some input, that should be of the same -// size as perms. Perms should contain some permutation of 0 - perms.size() - 1. +// size as perms. Perms should contain some permutation of 0 - perms.size() +// - 1. 
template SmallVector applyTOSAPermutation(ArrayRef input, ArrayRef perms) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 5c83393cb2d8e..fb9dca815124f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -93,15 +93,11 @@ struct SelfConcatToTile : public OpRewritePattern { } SmallVector multiplies(concatType.getRank(), 1); multiplies[concatOp.getAxis()] = concatOp->getNumOperands(); - const int64_t rank = multiplies.size(); - auto constantShapeOp = rewriter.create( - concatOp->getLoc(), shapeType::get(concatOp->getContext(), rank), - DenseIntElementsAttr::get( - RankedTensorType::get({rank}, rewriter.getIndexType()), - multiplies)); + auto constantShapeValue = + getTosaConstShape(rewriter, concatOp->getLoc(), multiplies); auto tileOp = rewriter.createOrFold( concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0), - constantShapeOp); + constantShapeValue); rewriter.replaceOp(concatOp, {tileOp}); return success(); } @@ -140,19 +136,15 @@ struct FuseChainedTile : public OpRewritePattern { for (auto [idx, multiplier] : llvm::enumerate(inputTileMultiples)) { multiplies[idx] *= multiplier; } - auto constantShapeOp = rewriter.create( + auto constantShapeValue = getTosaConstShape( + rewriter, rewriter.getFusedLoc( {op.getMultiples().getLoc(), inputTile.getMultiples().getLoc()}), - op.getMultiples().getType(), - DenseIntElementsAttr::get( - RankedTensorType::get( - {cast(op.getMultiples().getType()).getRank()}, - rewriter.getIndexType()), - multiplies)); + multiplies); rewriter.modifyOpInPlace(op, [&]() { op.setOperand(0, inputTile->getOperand(0)); - op.setOperand(1, constantShapeOp); + op.setOperand(1, constantShapeValue); op.getOperation()->setLoc( FusedLoc::get(getContext(), {inputTile->getLoc(), op.getLoc()})); }); @@ -828,16 +820,11 @@ struct TileSliceOptimization : public OpRewritePattern { 
llvm::zip_equal(newTileShape, requiredMultipliers)) { newShape *= multiplier; } - auto constantShapeOp = rewriter.create( - tileOp.getMultiples().getLoc(), tileOp.getMultiples().getType(), - DenseIntElementsAttr::get( - RankedTensorType::get( - {cast(tileOp.getMultiples().getType()).getRank()}, - rewriter.getIndexType()), - requiredMultipliers)); + auto constantShapeValue = getTosaConstShape( + rewriter, tileOp.getMultiples().getLoc(), requiredMultipliers); auto newTile = rewriter.create( tileOp->getLoc(), tileOpInputType.clone(newTileShape), - tileOp->getOperand(0), constantShapeOp); + tileOp->getOperand(0), constantShapeValue); auto newSlice = rewriter.create( sliceOp->getLoc(), sliceOp.getType(), newTile, rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr()); diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index 1f6e3b2ab8391..db68895efa21a 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -160,3 +160,25 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder, return success(); } + +namespace { +SmallVector convertFromMlirShape(ArrayRef shape) { + return to_vector(llvm::map_range(shape, [](int64_t dim) { + return ShapedType::isDynamic(dim) ? 
-1 : dim; + })); +} +} // namespace + +Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder, + llvm::ArrayRef shape) { + auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); + auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size()); + mlir::Operation *mlir_op = builder.create(type, attr); + return mlir_op->getResult(0); +} + +Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape) { + ImplicitLocOpBuilder builder(loc, rewriter); + return getTosaConstShape(builder, shape); +} From 51065a3ad806b376bdb9e979fec5f3528c85eddb Mon Sep 17 00:00:00 2001 From: Jonas Rickert Date: Mon, 24 Mar 2025 02:01:45 -0600 Subject: [PATCH 22/31] Do not link internal mlir-libs shared, even if MLIR_LINK_MLIR_DYLIB is set, as it causes options be registered more than once --- mlir/cmake/modules/AddMLIR.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake index 9c7b00b660ba7..4933cafa41ed6 100644 --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -732,7 +732,8 @@ function(mlir_target_link_libraries target type) endif() if (MLIR_LINK_MLIR_DYLIB) - target_link_libraries(${target} ${type} MLIR) + # AMD: Do not link shared, as this causes linking errors + target_link_libraries(${target} ${type} ${ARGN}) else() target_link_libraries(${target} ${type} ${ARGN}) endif() From de23f583f0d1ceea1a2bdf756f80b09512a37ebb Mon Sep 17 00:00:00 2001 From: Jonas Rickert Date: Mon, 24 Mar 2025 02:13:45 -0600 Subject: [PATCH 23/31] Move getConvOpsAccType from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18 to LLVM --- .../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 6 +++ .../Dialect/Tosa/Utils/ConversionUtils.cpp | 49 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index
0d9c76f31d78f..f49332eb54290 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -90,6 +90,12 @@ Value getTosaConstShape(ImplicitLocOpBuilder &builder, Value getTosaConstShape(PatternRewriter &rewriter, Location loc, llvm::ArrayRef shape); +// Get accumulator type for TOSA convolution ops +LogicalResult getConvOpsAccType(PatternRewriter &rewriter, + RankedTensorType inputTy, + RankedTensorType weightTy, + RankedTensorType outputTy, TypeAttr &accType); + namespace { // Creates a TOSA operation and performs shape inference on the individual diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index db68895efa21a..ad2363d5c4140 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" using namespace mlir; @@ -182,3 +183,51 @@ Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc, ImplicitLocOpBuilder builder(loc, rewriter); return getTosaConstShape(builder, shape); } + +// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18 +// Get accumulator type for TOSA convolution ops +LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter, + RankedTensorType inputTy, + RankedTensorType weightTy, + RankedTensorType outputTy, + TypeAttr &accType) { + auto inputElemTy = inputTy.getElementType(); + auto weightElemTy = weightTy.getElementType(); + auto outputElemTy = outputTy.getElementType(); + + auto quantTy = dyn_cast(inputElemTy); + if (quantTy) + inputElemTy = quantTy.getStorageType(); + + // Get TOSA conv ops acc type based on input, weight, and output types + // according to the spec: + // 
https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d + // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d + // + // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the + // output type but does not offer any guarantee on the numerical precision + // since such cases will fail TOSA validation. + if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) || + (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) || + (inputElemTy.isBF16() && weightElemTy.isBF16() && + outputElemTy.isBF16())) { + accType = mlir::TypeAttr::get(rewriter.getF32Type()); + } else if (inputElemTy.isInteger(8) && + (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) && + outputElemTy.isInteger(32)) { + accType = mlir::TypeAttr::get(rewriter.getIntegerType(32)); + } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && + outputElemTy.isInteger(48)) { + accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); + } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() && + outputElemTy.isF16()) || + (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() && + outputElemTy.isF16())) { + accType = mlir::TypeAttr::get(rewriter.getF16Type()); + } else { + accType = mlir::TypeAttr::get(outputElemTy); + } + + return success(); +} From dd65947e2cfbf3975211f0f6670d40e06fe7fce1 Mon Sep 17 00:00:00 2001 From: mathmer-amd Date: Tue, 18 Mar 2025 09:05:33 -0600 Subject: [PATCH 24/31] feat: allow arrayAttr parsing in constraint --- mlir/include/mlir/Dialect/PDL/IR/Builtins.h | 5 ++- mlir/lib/Dialect/PDL/IR/Builtins.cpp | 50 ++++++++++----------- mlir/lib/Tools/PDLL/Parser/Parser.cpp | 32 +++++++++---- mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 32 +++++++++++-- mlir/test/mlir-pdll/Parser/expr.pdll | 37 ++++++++++++++- mlir/unittests/Dialect/PDL/BuiltinTest.cpp | 11 +++-- 6 files changed, 124 insertions(+), 43 deletions(-) diff --git 
a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h index 0c6cceb54b68a..5e60391e278eb 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -43,8 +43,9 @@ enum class UnaryOpKind { LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, PDLResultList &results, ArrayRef args); -Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, - Attribute element); +LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args); LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args); LogicalResult div(PatternRewriter &rewriter, PDLResultList &results, diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 9e4efbf7e71c0..3e561461e389f 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -38,13 +38,17 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, return success(); } -mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter, - mlir::Attribute attr, - mlir::Attribute element) { - assert(isa(attr)); - auto values = cast(attr).getValue().vec(); - values.push_back(element); - return rewriter.getArrayAttr(values); +LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + + auto arrayAttr = cast(args[0].cast()); + auto attrElement = args[1].cast(); + std::vector values = arrayAttr.getValue().vec(); + values.push_back(attrElement); + + results.push_back(rewriter.getArrayAttr(values)); + return success(); } template @@ -344,11 +348,15 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { // See Parser::defineBuiltins() pdlPattern.registerRewriteFunction( "__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr); - pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr", - addElemToArrayAttr); 
pdlPattern.registerConstraintFunction( "__builtin_addEntryToDictionaryAttr_constraint", addEntryToDictionaryAttr); + + pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttrRewriter", + addElemToArrayAttr); + pdlPattern.registerConstraintFunction( + "__builtin_addElemToArrayAttrConstraint", addElemToArrayAttr); + pdlPattern.registerRewriteFunction("__builtin_mulRewrite", mul); pdlPattern.registerRewriteFunction("__builtin_divRewrite", div); pdlPattern.registerRewriteFunction("__builtin_modRewrite", mod); @@ -357,22 +365,14 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2); pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2); pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs); - pdlPattern.registerConstraintFunction("__builtin_mulConstraint", - mul); - pdlPattern.registerConstraintFunction("__builtin_divConstraint", - div); - pdlPattern.registerConstraintFunction("__builtin_modConstraint", - mod); - pdlPattern.registerConstraintFunction("__builtin_addConstraint", - add); - pdlPattern.registerConstraintFunction("__builtin_subConstraint", - sub); - pdlPattern.registerConstraintFunction("__builtin_log2Constraint", - log2); - pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", - exp2); - pdlPattern.registerConstraintFunction("__builtin_absConstraint", - abs); + pdlPattern.registerConstraintFunction("__builtin_mulConstraint", mul); + pdlPattern.registerConstraintFunction("__builtin_divConstraint", div); + pdlPattern.registerConstraintFunction("__builtin_modConstraint", mod); + pdlPattern.registerConstraintFunction("__builtin_addConstraint", add); + pdlPattern.registerConstraintFunction("__builtin_subConstraint", sub); + pdlPattern.registerConstraintFunction("__builtin_log2Constraint", log2); + pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", exp2); + pdlPattern.registerConstraintFunction("__builtin_absConstraint", abs); 
pdlPattern.registerConstraintFunction("__builtin_equals", equals); } } // namespace mlir::pdl diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 0250ecb0f7f28..aacb049f32b09 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -625,7 +625,8 @@ class Parser { struct { ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite; ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint; - ast::UserRewriteDecl *addElemToArrayAttr; + ast::UserRewriteDecl *addElemToArrayAttrRewrite; + ast::UserConstraintDecl *addElemToArrayAttrConstraint; ast::UserRewriteDecl *mulRewrite; ast::UserRewriteDecl *divRewrite; ast::UserRewriteDecl *modRewrite; @@ -691,9 +692,13 @@ void Parser::declareBuiltins() { "__builtin_addEntryToDictionaryAttr_constraint", {"attr", "attrName", "attrEntry"}, /*returnsAttr=*/true); - builtins.addElemToArrayAttr = declareBuiltin( - "__builtin_addElemToArrayAttr", {"attr", "element"}, + builtins.addElemToArrayAttrRewrite = declareBuiltin( + "__builtin_addElemToArrayAttrRewriter", {"attr", "element"}, /*returnsAttr=*/true); + builtins.addElemToArrayAttrConstraint = + declareBuiltin( + "__builtin_addElemToArrayAttrConstraint", {"attr", "element"}, + /*returnsAttr=*/true); builtins.mulRewrite = declareBuiltin( "__builtin_mulRewrite", {"lhs", "rhs"}, true); builtins.divRewrite = declareBuiltin( @@ -2323,27 +2328,35 @@ FailureOr Parser::parseArrayAttrExpr() { consumeToken(Token::l_square); + ast::Decl *builtinFunction = builtins.addElemToArrayAttrRewrite; if (parserContext != ParserContext::Rewrite) - return emitError( - "Parsing of array attributes as constraint not supported!"); + builtinFunction = builtins.addElemToArrayAttrConstraint; - FailureOr arrayAttr = ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]"); + FailureOr arrayAttr = + ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]"); if (failed(arrayAttr)) return failure(); + // No values inside the array + 
if (consumeIf(Token::r_square)) { + return arrayAttr; + } + do { FailureOr attr = parseExpr(); if (failed(attr)) return failure(); SmallVector arrayAttrArgs{*arrayAttr, *attr}; - auto elemToArrayCall = createBuiltinCall( - curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs); + + auto elemToArrayCall = + createBuiltinCall(curToken.getLoc(), builtinFunction, arrayAttrArgs); if (failed(elemToArrayCall)) return failure(); // Uses the new array for the next element. arrayAttr = elemToArrayCall; + } while (consumeIf(Token::comma)); if (failed( @@ -2415,7 +2428,8 @@ FailureOr Parser::parseDictAttrExpr() { consumeToken(Token::l_brace); SMRange loc = curToken.getLoc(); - FailureOr dictAttrCall = ast::AttributeExpr::create(ctx, loc, "{}"); + FailureOr dictAttrCall = + ast::AttributeExpr::create(ctx, loc, "{}"); if (failed(dictAttrCall)) return failure(); diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 84cba9035123f..c2509db6b42ce 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -218,7 +218,7 @@ Pattern RewriteMultipleEntriesDictionary { // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_5:.*]] = attribute = "test1" // CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_2]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]} // CHECK: replace %[[VAL_1]] with %[[VAL_8]] Pattern RewriteOneDictionaryArrayAttr { @@ -229,6 +229,32 @@ Pattern RewriteOneDictionaryArrayAttr { }; } +// ----- + +// CHECK-LABEL: pdl.pattern @ConstraintWithArrayAttr +// CHECK: %[[VAL_0:.*]] = attribute = "test1" +// CHECK: %[[VAL_1:.*]] = attribute = "test2" 
+// CHECK: %[[VAL_2:.*]] = attribute = [] +// CHECK: %[[VAL_3:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_2]], %[[VAL_0]] +// CHECK: %[[VAL_4:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_3]], %[[VAL_1]] +// CHECK: %[[VAL_5:.*]] = operation "test.op" +// CHECK: rewrite %[[VAL_5]] { +// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_array" = %[[VAL_4]]} +// CHECK: replace %[[VAL_5]] with %[[VAL_6]] + +Pattern ConstraintWithArrayAttr { + let attr1 = attr<"\"test1\"">; + let attr2 = attr<"\"test2\"">; + let array = [attr1, attr2]; + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = array} -> (); + replace root with newRoot; + }; +} + + + // ----- // CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr @@ -240,8 +266,8 @@ Pattern RewriteOneDictionaryArrayAttr { // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" // CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] -// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] -// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_3]], %[[VAL_7]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]} // CHECK: replace %[[VAL_1]] with %[[VAL_10]] Pattern RewriteMultiplyElementsArrayAttr { diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index fe3d8956b3dd7..0eeeaddd2afd5 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -34,7 +34,7 @@ Pattern { // CHECK-LABEL: Module // CHECK: |-NamedAttributeDecl 
{{.*}} Name -// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttr> ResultType +// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttrRewriter> ResultType // CHECK: `Arguments` // CHECK: CallExpr {{.*}} Type // CHECK: AttributeExpr {{.*}} Value<"[]"> @@ -87,6 +87,41 @@ Constraint getPopulatedDict() -> Attr { return dictionary; } + + +// ----- + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getEmtpyArray() -> Attr { + let array = []; + return array; +} + +// ----- + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK: `Arguments` +//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<""attr1""> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getPopulateArray() -> Attr { + let array = ["attr1"]; + return array; +} + + + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index 21a620e3b6675..113fc1ff8640f 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -66,13 +66,17 @@ TEST_F(BuiltinTest, addEntryToDictionaryAttr) { } TEST_F(BuiltinTest, addElemToArrayAttr) { + TestPDLResultList results(1); + auto dict = rewriter.getDictionaryAttr( rewriter.getNamedAttr("key", rewriter.getStringAttr("value"))); rewriter.getArrayAttr({}); auto arrAttr = rewriter.getArrayAttr({}); + EXPECT_TRUE(succeeded( + builtin::addElemToArrayAttr(rewriter, results, {arrAttr, dict}))); mlir::Attribute updatedArrAttr = - builtin::addElemToArrayAttr(rewriter, arrAttr, 
dict); + results.getResults().front().cast(); auto dictInsideArrAttr = cast(*cast(updatedArrAttr).begin()); @@ -617,7 +621,7 @@ TEST_F(BuiltinTest, log2) { cast(result.cast()).getValue().convertToFloat(), 2.0); } - + auto threeF16 = rewriter.getF16FloatAttr(3.0); // check correctness @@ -626,7 +630,8 @@ TEST_F(BuiltinTest, log2) { EXPECT_TRUE(builtin::log2(rewriter, results, {threeF16}).succeeded()); PDLValue result = results.getResults()[0]; - float resultVal = cast(result.cast()).getValue().convertToFloat(); + float resultVal = + cast(result.cast()).getValue().convertToFloat(); EXPECT_TRUE(resultVal > 1.58 && resultVal < 1.59); } } From 4c14f6677200f3ccc09c9fcb877bb5e0b00ee742 Mon Sep 17 00:00:00 2001 From: mathmer-amd Date: Tue, 18 Mar 2025 09:52:01 -0600 Subject: [PATCH 25/31] test: check some edge cases --- mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll | 11 +++++++++++ mlir/test/mlir-pdll/Parser/expr-failure.pdll | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index c2509db6b42ce..b876eecadfbce 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -253,6 +253,17 @@ Pattern ConstraintWithArrayAttr { }; } +// ----- + +// CHECK-LABEL: pdl.pattern @ConstraintNotMatchingArrayAttrInAttrType +// CHECK-NOT: apply_native_constraint "__builtin_addElemToArrayAttrConstraint" + + +Constraint I64Value(value: Value); +Pattern ConstraintNotMatchingArrayAttrInAttrType { + let root = op(arg: Value, arg2: Value, arg3: [Value, I64Value], arg); + replace root with arg; +} // ----- diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 34cf54fb7c23d..9d1218f124009 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -134,6 +134,22 @@ Pattern { // ----- +Pattern ConstraintArrayAttrWithAttrAndValue { + let root = 
op(arg: Value) -> (); + let attr1 = attr<"\"test1\"">; + let array = [attr1, arg]; + // CHECK: unable to convert expression of type `Value` to the expected type of `Attr` + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = array} -> (); + replace root with newRoot; + }; +} + +// ----- + + + //===----------------------------------------------------------------------===// // Range Expr //===----------------------------------------------------------------------===// From cba66f41545bae9244f8e2bb43b1adc187ba77fc Mon Sep 17 00:00:00 2001 From: mathmer-amd Date: Tue, 18 Mar 2025 10:05:20 -0600 Subject: [PATCH 26/31] chore: add assertion --- mlir/lib/Dialect/PDL/IR/Builtins.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 3e561461e389f..c928b5135a448 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -42,6 +42,8 @@ LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, PDLResultList &results, ArrayRef args) { + assert(args.size() == 2 && + "Expected two arguments, one ArrayAttr and one Attr"); auto arrayAttr = cast(args[0].cast()); auto attrElement = args[1].cast(); std::vector values = arrayAttr.getValue().vec(); From a474457ff47dd54bf3d698616f5dd5c26bd27078 Mon Sep 17 00:00:00 2001 From: mathmer-amd Date: Wed, 19 Mar 2025 03:19:07 -0600 Subject: [PATCH 27/31] chore: use llvm::SmallVector --- mlir/lib/Dialect/PDL/IR/Builtins.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index c928b5135a448..bc238092c141d 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -46,7 +46,7 @@ LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, "Expected two arguments, one ArrayAttr and one Attr"); auto arrayAttr = cast(args[0].cast()); auto attrElement = args[1].cast(); - 
std::vector values = arrayAttr.getValue().vec(); + llvm::SmallVector values(arrayAttr.getValue()); values.push_back(attrElement); results.push_back(rewriter.getArrayAttr(values)); From a008a59694bd3b26c22746aa454c93414e522bae Mon Sep 17 00:00:00 2001 From: mathmer-amd Date: Thu, 20 Mar 2025 07:45:22 -0600 Subject: [PATCH 28/31] test: array with contraints results --- mlir/test/mlir-pdll/Parser/expr.pdll | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 0eeeaddd2afd5..8acdb7e9bba77 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -121,6 +121,42 @@ Constraint getPopulateArray() -> Attr { } +// ----- + + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK-DAG: `Arguments` +//CHECK-NEXT: |-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: | `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK-DAG: `Arguments` +//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK-DAG: -ReturnStmt {{.*}} + +Constraint getA() -> Attr { + return "A"; +} + +Constraint getB() -> Attr { + return "B"; +} + +Constraint getPopulateArrayFromOtherConstraints() -> Attr { + let array = [getA(), getB()]; + return array; +} + // ----- From 4cfc438872636b88a0a650967329159a0d30a786 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Thu, 27 Mar 2025 16:59:23 +0100 Subject: [PATCH 29/31] 
Check valid emitc float/opaque types, not float (#525) Use the emitc-provided function to check the types instead of checking for float types, as the other arith lowering do. --- mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index d347ae916784b..3f39de4200359 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -84,7 +84,7 @@ class CmpFOpConversion : public OpConversionPattern { matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isa(adaptor.getRhs().getType())) { + if (!emitc::isFloatOrOpaqueType(adaptor.getRhs().getType())) { return rewriter.notifyMatchFailure(op.getLoc(), "cmpf currently only supported on " "floats, not tensors/vectors thereof"); From 61a20b8b98ccdb84582b9947e1c4a2a10e338205 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 1 Apr 2025 09:44:26 +0200 Subject: [PATCH 30/31] cmake: Use old CMP0175 policy Otherwise we get hundreds of warnings during cmake. --- mlir/cmake/modules/AddMLIRPython.cmake | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 815f65b106d94..404002e03b51b 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -23,6 +23,11 @@ # grouping. Source groupings form a DAG. # SOURCES: List of specific source files relative to ROOT_DIR to include. # SOURCES_GLOB: List of glob patterns relative to ROOT_DIR to include. 
+ +if (POLICY CMP0175) + cmake_policy(SET CMP0175 OLD) +endif() + function(declare_mlir_python_sources name) cmake_parse_arguments(ARG "" From 11193c5cf2e84bccf7f0a95e1b66afd748091858 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Mon, 7 Apr 2025 16:40:31 +0100 Subject: [PATCH 31/31] Add support for folding tosa.slice with tosa.slice --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 24 ++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 75 +++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index fb9dca815124f..532237f083e89 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1463,6 +1463,30 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { } OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { + const auto tryFoldWithPrecedingSlice = [this](FoldAdaptor adaptor) { + auto precedingSliceOp = getInput1().getDefiningOp(); + if (!precedingSliceOp) + return failure(); + const auto precedingSliceStart = precedingSliceOp.getStart(); + const auto thisSliceStart = getStart(); + SmallVector newSliceStart; + newSliceStart.reserve(precedingSliceStart.size()); + for (auto [startPreceding, startThis] : + llvm::zip_equal(precedingSliceStart, thisSliceStart)) { + newSliceStart.push_back(startPreceding + startThis); + } + setOperand(precedingSliceOp->getOperand(0)); + setStart(newSliceStart); + getOperation()->setLoc( + FusedLoc::get(getContext(), {precedingSliceOp->getLoc(), getLoc()})); + return success(); + }; + + // First try folding the preceding slice, this also works if the shapes are + // dynamic + if (succeeded(tryFoldWithPrecedingSlice(adaptor))) + return getResult(); + auto inputTy = llvm::dyn_cast(getInput1().getType()); auto outputTy = llvm::dyn_cast(getType()); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 
447b7a300ad3f..826811b0f2344 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -690,6 +690,81 @@ func.func @slice_nofold(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @slice_fuse +func.func @slice_fuse(%arg0: tensor<3x4xf32>) -> tensor<1x2xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x2xf32> +// CHECK: return [[VAR_0_]] : tensor<1x2xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<2x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<2x3xf32>) -> tensor<1x2xf32> + return %1 : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_step +func.func @slice_fuse_different_step(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<1x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<1x3xf32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start +func.func @slice_fuse_different_start(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<1x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<1x3xf32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: 
@slice_fuse_different_start_2 +func.func @slice_fuse_different_start_2(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<10x10xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<10x10xf32>) -> tensor<5x5xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<5x5xf32>) -> tensor<3x3xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<3x3xf32>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start_3 +func.func @slice_fuse_different_start_3(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<10x10xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<10x10xf32>) -> tensor<5x5xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<5x5xf32>) -> tensor<3x3xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<3x3xf32>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: func.func @slice_fuse_different_start_dynamic +func.func @slice_fuse_different_start_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + %2 = tosa.slice %1 { size = array, start = array}: 
(tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: @tile_fold func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0