-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern #141613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern #141613
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Christopher McGirr (chrsmcgrr) ChangesGiven the following example:
We would generate an invalid transpose operation because the calculated permutation would be The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array is unique. The above example would then translate to a transpose having a permutation of Full diff: https://github.com/llvm/llvm-project/pull/141613.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8718c57b9e86c..7b6c8243d1040 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1205,16 +1205,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
- // Two assumptions are made:
- // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
- // 2. Inner dims position correspond to the trailing `numTiles` dims.
- SmallVector<int64_t> tilesPermNormalized =
- getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+ // Assumptions made:
+ // 1. Inner dims position correspond to the trailing `numTiles` dims.
SmallVector<int64_t> srcPermForTranspose;
- for (int64_t i = 0; i < (srcRank - numTiles); i++)
+ ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+ for (int64_t i = 0; i < srcRank; i++) {
+ // As we assume the trailing dimensions of the inner dim position correspond
+ // to the trailing indices of the transpose permutation, we need to
+ // calculate the remaining indicies of the transpose permutation. This is
+ // done by adding the indices not contained in the inner dimension position.
+ // For example if we have a source tensor of dimensions [0, 1, 2, 3]
+ // and inner dim position of [3, 0], the remaining indices are [1, 2].
+ // and the transpose will be [1, 2, 3, 0].
+ if (llvm::is_contained(innerDimPos, i))
+ continue;
srcPermForTranspose.push_back(i);
-
- srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
+ }
+ srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
<< "perm: " << llvm::interleaved(srcPermForTranspose)
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 911b453f919c3..6d091406a639c 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+ return %pack : tensor<1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME: permutation = [1, 2, 0]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK: return %[[INSERT]]
\ No newline at end of file
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, you extend the pattern to handle the case that there are non ones in unpacked outer dimensions? I.e., should we relax the check in line 1165 - 1169? Then you are not fixing a corner case. Instead, you extend the support in general?
E.g., I'd expect the below test case working with your support, if I read your intention correctly. func.func @main(%arg0: tensor<2x1x1x4x1xf32>, %arg1: tensor<1x2x4xf32>) -> tensor<2x1x1x4x1xf32> {
%pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x2x4xf32> -> tensor<2x1x1x4x1xf32>
return %pack : tensor<2x1x1x4x1xf32>
} |
SmallVector<int64_t> tilesPermNormalized = | ||
getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos()); | ||
// Assumptions made: | ||
// 1. Inner dims position correspond to the trailing `numTiles` dims. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that you're relaxing this assumption, not the other one.
Looking at your test, this part is key: inner_dims_pos = [2, 0]:
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
What’s happening here is that:
- There’s a non-trailing dim (
0
) ininner_dims_pos
. - The dims
[2, 0]
are non-adjacent in the original shape.
The original comment (which I think I wrote, so I can moan about it 😅) doesn’t make this very clear. To clarify the terminology:
[2]
is a trailing dim.[1, 2]
and[2, 1]
are adjacent trailing dims.[0]
and[1]
are not trailing dims.[0, 1, 2]
is a set of adjacent trailing dims.
So in your case, [2, 0]
is a set of non-adjacent dims, and that case isn’t supported at the moment.
Thanks for digging into this - and sorry for the confusing terminology! Really appreciate you improving the situation here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@banach-space thank you for the explanation :)
Would then modifying the check for trailing dimensions make more sense? To then block my use-case and handle it in another pattern?
Or would you be ok in supporting my use case in this pattern? I would also then update the comments to reflect this explanation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would then modifying the check for trailing dimensions make more sense?
Looks like there's off-by-one error there 🤦🏻 :( My bad, sorry!
To then block my use-case and handle it in another pattern?
Is there another pattern that could handle this for you? 🙂 I’m not aware of one, but there are so many that I might be missing something myself.
Or would you be ok in supporting my use case in this pattern?
I’m totally fine with that - your fix is very non-intrusive, and the test case isn’t exotic. It feels like we’re hitting an artificial restriction here: no one needed this case before, so it just wasn’t tested or supported.
Thanks for the quick reply @hanhanW Not quite, at least for my use case I am still only concerned with unit outer dimensions in the unpacked case. AFAIK, the outer dimension in my case would be index My change is more about the adjacent trailing dimensions as @banach-space has now explained. I would be happy to extend 1165-1169 if anyone needs it. |
If it's not required, I would refrain from extending it right now. These "decomposition" patterns are already riddled with assumptions that we neither document nor test (like the case with non-adjacent dims that you discovered). Extending them could lead to even more un-verified assumptions. Btw, @chrsmcgrr , could you also the check Thanks! |
@banach-space @hanhanW I've updated the comments and removed the adjacent trailing dimensions check as it is no longer needed. This change will allow for that use-case. I have also added the corresponding test to the unpack version which works fine out-of-the-box. Looking at the unpack pattern I can't see a clean way of making the patterns symmetrical. So I will leave it for now. Let me know what you think. |
5c41b17
to
4f378a5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates! Some minor comments inline.
// 2. Inner dims position can have non-adjacent trailing dimensions. Where, | ||
// For example, a source tensor with indices [0, 1, 2] can have: | ||
// * adjacent trailing dimensions of [1, 2], [2, 1] | ||
// * non-adjacent trailing dimensions of [0, 2] or [2, 0] | ||
// Trailing dimensions are defined in the case above as index [2]. | ||
// And the indices [0] or [1] are not defined to be trailing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, all restrictions with respect to "inner dims" are lifted, so this comment can be deleted? Could you verify (with a test) that it is OK if there are no "trailing dims" inside "inner dims"? For example (note inner_dims_pos
):
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
// We assume the `k` dimensions of the inner dim position correspond | ||
// to the last `k` indices of the transpose permutation. This is | ||
// done by adding the indices not contained in the inner dimension position | ||
// in order from 0 to `n`. Where n is the rank of the source tensor. | ||
// For example if we have a source tensor with indices [0, 1, 2, 3] | ||
// and inner dim position of [3, 0], the remaining indices are [1, 2]. | ||
// and the transpose will be [1, 2, 3, 0]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is k
dimension? I'd expect to see a variable called k
so that I can match the comment with the code.
[nit] The formatting is off.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took that variable from the documentation. I will update the comment to define k
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
// ----- | ||
|
||
func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Note that this is testing unpack
rather than pack
(@pack
-> @unpack
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
// ----- | ||
|
||
func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, what makes this (and the other test added here) unique is the fact that inner_dims_pos
contains non-adjacent positions, right? If yes, could you update the function name to capture that. I would also add a comment, but that's just me :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Given the following example: ``` module { func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> { %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> return %pack : tensor<1x1x1x4x1xf32> } } ``` We would generate an invalid transpose operation because the calculated permutation would be `[0, 2, 0]` which is semantically incorrect. As the permutation must contain unique integers corresponding to the source tensor dimensions. The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array is unique. The above example would then translate to a transpose having a permutation of `[1, 2, 0]`. Following the rule, that the `inner_dim_pos` is appended to the permutation array and the preceding indices are filled with the remaining dimensions.
4f378a5
to
79e0ff5
Compare
@@ -172,11 +172,11 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens | |||
|
|||
// ----- | |||
|
|||
func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { | |||
func.func @unpack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be unpack_with_non_adjacent_inner_dims_pos_and_unit_outer
, which follows the same name in the other file?
Given the following example:
We would generate an invalid transpose operation because the calculated permutation would be
[0, 2, 0]
which is semantically incorrect. As the permutation must contain unique integers corresponding to the source tensor dimensions.The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array is unique.
The above example would then translate to a transpose having a permutation of
[1, 2, 0]
. Following the rule, that theinner_dim_pos
is appended to the permutation array and the preceding indices are filled with the remaining dimensions.