Skip to content

[mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern #141613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

chrsmcgrr
Copy link
Contributor

Given the following example:

module {
  func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
    %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
    return %pack : tensor<1x1x1x4x1xf32>
  }
}

We would generate an invalid transpose operation because the calculated permutation would be [0, 2, 0] which is semantically incorrect. As the permutation must contain unique integers corresponding to the source tensor dimensions.

The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array is unique.

The above example would then translate to a transpose having a permutation of [1, 2, 0]. Following the rule, that the inner_dim_pos is appended to the permutation array and the preceding indices are filled with the remaining dimensions.

@llvmbot
Copy link
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Christopher McGirr (chrsmcgrr)

Changes

Given the following example:

module {
  func.func @<!-- -->main(%arg0: tensor&lt;1x1x1x4x1xf32&gt;, %arg1: tensor&lt;1x1x4xf32&gt;) -&gt; tensor&lt;1x1x1x4x1xf32&gt; {
    %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor&lt;1x1x4xf32&gt; -&gt; tensor&lt;1x1x1x4x1xf32&gt;
    return %pack : tensor&lt;1x1x1x4x1xf32&gt;
  }
}

We would generate an invalid transpose operation because the calculated permutation would be [0, 2, 0] which is semantically incorrect. As the permutation must contain unique integers corresponding to the source tensor dimensions.

The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array is unique.

The above example would then translate to a transpose having a permutation of [1, 2, 0]. Following the rule, that the inner_dim_pos is appended to the permutation array and the preceding indices are filled with the remaining dimensions.


Full diff: https://github.com/llvm/llvm-project/pull/141613.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+15-8)
  • (modified) mlir/test/Dialect/Linalg/decompose-pack.mlir (+19)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8718c57b9e86c..7b6c8243d1040 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1205,16 +1205,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
   //    %init = tensor.empty()
   //    %transposed_tile = linalg.transpose ins(%source_or_padded_source),
   //                                        outs(%init)
-  // Two assumptions are made:
-  //  1. All outer dims are 1 - the corresponding transposition doesn't matter.
-  //  2. Inner dims position correspond to the trailing `numTiles` dims.
-  SmallVector<int64_t> tilesPermNormalized =
-      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+  // Assumptions made:
+  //  1. Inner dims position correspond to the trailing `numTiles` dims.
   SmallVector<int64_t> srcPermForTranspose;
-  for (int64_t i = 0; i < (srcRank - numTiles); i++)
+  ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+  for (int64_t i = 0; i < srcRank; i++) {
+    // As we assume the trailing dimensions of the inner dim position correspond
+    // to the trailing indices of the transpose permutation, we need to
+    // calculate the remaining indicies of the transpose permutation. This is
+    // done by adding the indices not contained in the inner dimension position.
+    //   For example if we have a source tensor of dimensions [0, 1, 2, 3]
+    //   and inner dim position of [3, 0], the remaining indices are [1, 2].
+    //   and the transpose will be [1, 2, 3, 0].
+    if (llvm::is_contained(innerDimPos, i))
+      continue;
     srcPermForTranspose.push_back(i);
-
-  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
+  }
+  srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
                     << "perm: " << llvm::interleaved(srcPermForTranspose)
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 911b453f919c3..6d091406a639c 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+  %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+  return %pack : tensor<1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK:         %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME:      ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME:      permutation = [1, 2, 0]
+// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME:      [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK:         return %[[INSERT]]
\ No newline at end of file

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, you extend the pattern to handle the case that there are non ones in unpacked outer dimensions? I.e., should we relax the check in line 1165 - 1169? Then you are not fixing a corner case. Instead, you extend the support in general?

@hanhanW
Copy link
Contributor

hanhanW commented May 27, 2025

module {
func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
%pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
return %pack : tensor<1x1x1x4x1xf32>
}
}

E.g., I'd expect the below test case working with your support, if I read your intention correctly.

func.func @main(%arg0: tensor<2x1x1x4x1xf32>, %arg1: tensor<1x2x4xf32>) -> tensor<2x1x1x4x1xf32> {
  %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x2x4xf32> -> tensor<2x1x1x4x1xf32>
  return %pack : tensor<2x1x1x4x1xf32>
}

SmallVector<int64_t> tilesPermNormalized =
getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
// Assumptions made:
// 1. Inner dims position correspond to the trailing `numTiles` dims.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that you're relaxing this assumption, not the other one.

Looking at your test, this part is key: inner_dims_pos = [2, 0]:

%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>

What’s happening here is that:

  • There’s a non-trailing dim (0) in inner_dims_pos.
  • The dims [2, 0] are non-adjacent in the original shape.

The original comment (which I think I wrote, so I can moan about it 😅) doesn’t make this very clear. To clarify the terminology:

  • [2] is a trailing dim.
  • [1, 2] and [2, 1] are adjacent trailing dims.
  • [0] and [1] are not trailing dims.
  • [0, 1, 2] is a set of adjacent trailing dims.

So in your case, [2, 0] is a set of non-adjacent dims, and that case isn’t supported at the moment.

Thanks for digging into this - and sorry for the confusing terminology! Really appreciate you improving the situation here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space thank you for the explanation :)

Would then modifying the check for trailing dimensions make more sense? To then block my use-case and handle it in another pattern?

Or would you be ok in supporting my use case in this pattern? I would also then update the comments to reflect this explanation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would then modifying the check for trailing dimensions make more sense?

Looks like there's off-by-one error there 🤦🏻 :( My bad, sorry!

To then block my use-case and handle it in another pattern?

Is there another pattern that could handle this for you? 🙂 I’m not aware of one, but there are so many that I might be missing something myself.

Or would you be ok in supporting my use case in this pattern?

I’m totally fine with that - your fix is very non-intrusive, and the test case isn’t exotic. It feels like we’re hitting an artificial restriction here: no one needed this case before, so it just wasn’t tested or supported.

@chrsmcgrr
Copy link
Contributor Author

IIUC, you extend the pattern to handle the case that there are non ones in unpacked outer dimensions? I.e., should we relax the check in line 1165 - 1169? Then you are not fixing a corner case. Instead, you extend the support in general?

Thanks for the quick reply @hanhanW

Not quite, at least for my use case I am still only concerned with unit outer dimensions in the unpacked case. AFAIK, the outer dimension in my case would be index [1] and the inner dimensions would be [2, 0] when looking at the source, unpacked tensor. Correct me if I am wrong.

My change is more about the adjacent trailing dimensions as @banach-space has now explained.

I would be happy to extend 1165-1169 if anyone needs it.

@banach-space
Copy link
Contributor

I would be happy to extend 1165-1169 if anyone needs it.

If it's not required, I would refrain from extending it right now. These "decomposition" patterns are already riddled with assumptions that we neither document nor test (like the case with non-adjacent dims that you discovered). Extending them could lead to even more un-verified assumptions.

Btw, @chrsmcgrr , could you also the check DecomposeOuterUnitDimsUnPackOpPattern? We should keep these "decompose" patterns symmetrical 😅

Thanks!

@chrsmcgrr
Copy link
Contributor Author

chrsmcgrr commented Jun 2, 2025

@banach-space @hanhanW I've updated the comments and removed the adjacent trailing dimensions check as it is no longer needed. This change will allow for that use-case.

I have also added the corresponding test to the unpack version which works fine out-of-the-box. Looking at the unpack pattern I can't see a clean way of making the patterns symmetrical. So I will leave it for now.

Let me know what you think.

@chrsmcgrr chrsmcgrr requested review from banach-space and hanhanW June 2, 2025 14:30
@chrsmcgrr chrsmcgrr force-pushed the fix-decomposition-of-unit-outer-dims branch from 5c41b17 to 4f378a5 Compare June 4, 2025 11:23
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! Some minor comments inline.

Comment on lines 1204 to 1209
// 2. Inner dims position can have non-adjacent trailing dimensions. Where,
// For example, a source tensor with indices [0, 1, 2] can have:
// * adjacent trailing dimensions of [1, 2], [2, 1]
// * non-adjacent trailing dimensions of [0, 2] or [2, 0]
// Trailing dimensions are defined in the case above as index [2].
// And the indices [0] or [1] are not defined to be trailing.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, all restrictions with respect to "inner dims" are lifted, so this comment can be deleted? Could you verify (with a test) that it is OK if there are no "trailing dims" inside "inner dims"? For example (note inner_dims_pos):

%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>

Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 1213 to 1219
// We assume the `k` dimensions of the inner dim position correspond
// to the last `k` indices of the transpose permutation. This is
// done by adding the indices not contained in the inner dimension position
// in order from 0 to `n`. Where n is the rank of the source tensor.
// For example if we have a source tensor with indices [0, 1, 2, 3]
// and inner dim position of [3, 0], the remaining indices are [1, 2].
// and the transpose will be [1, 2, 3, 0].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is k dimension? I'd expect to see a variable called k so that I can match the comment with the code.

[nit] The formatting is off.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took that variable from the documentation. I will update the comment to define k

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// -----

func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Note that this is testing unpack rather than pack (@pack -> @unpack)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// -----

func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, what makes this (and the other test added here) unique is the fact that inner_dims_pos contains non-adjacent positions, right? If yes, could you update the function name to capture that. I would also add a comment, but that's just me :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Given the following example:
```
module {
  func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
    %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
    return %pack : tensor<1x1x1x4x1xf32>
  }
}
```

We would generate an invalid transpose operation because the calculated
permutation would be `[0, 2, 0]` which is semantically incorrect. As the
permutation must contain unique integers corresponding to the source
tensor dimensions.

The following change modifies how we calculate the permutation array and
ensures that the dimension indices given in the permutation array is
unique.

The above example would then translate to a transpose having a
permutation of `[1, 2, 0]`. Following the rule, that the `inner_dim_pos`
is appended to the permutation array and the preceding indices are
filled with the remaining dimensions.
@chrsmcgrr chrsmcgrr force-pushed the fix-decomposition-of-unit-outer-dims branch from 4f378a5 to 79e0ff5 Compare June 12, 2025 15:00
@@ -172,11 +172,11 @@ func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tens

// -----

func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
func.func @unpack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be unpack_with_non_adjacent_inner_dims_pos_and_unit_outer, which follows the same name in the other file?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants