[mlir][hoisting] Support memref.assume_alignment in linalg hoisting #144843

Open
wants to merge 2 commits into main
Conversation

xiangzh1
Contributor

The recent update to AssumeAlignmentOp affects the linalg hoisting optimization: we see a regression in hoisting loads/stores out of loops. The following issue gives more detail:

Related issue: #144825

This patch fixes the problem: since assume_alignment only marks the memref's alignment, the linalg hoisting should check its memref operand rather than the op itself.
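A minimal sketch of the pattern that regressed (hypothetical IR, simplified from the test added in this patch; the function and value names are placeholders): the transfer_read/transfer_write pair is loop-invariant, but its base is the result of memref.assume_alignment, which implements ViewLikeOpInterface, so the hoisting walk used to skip it.

func.func @regressed_pattern() {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c128 = arith.constant 128 : index
  %cst = arith.constant 0.0 : f16
  %alloc = memref.alloc() : memref<128x128xf16>
  %aligned = memref.assume_alignment %alloc, 64 : memref<128x128xf16>
  scf.for %i = %c0 to %c128 step %c16 {
    // Both transfers use %aligned with loop-invariant indices, so the pair is
    // hoistable; before this patch the ViewLike check on %aligned's defining
    // op made the hoisting walk bail out here.
    %v = vector.transfer_read %aligned[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
    %acc = arith.addf %v, %v : vector<16x16xf16>
    vector.transfer_write %acc, %aligned[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16>
  }
  return
}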

xiangzh1 added 2 commits June 18, 2025 17:10
All ViewLike operations are excluded from the hoisting optimization. But
assume_alignment only marks the memref's alignment, so we should check its
memref operand instead of the op itself.
@llvmbot
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: XiangZhang (xiangzh1)

Changes

The recent update to AssumeAlignmentOp affects the linalg hoisting optimization: we see a regression in hoisting loads/stores out of loops. The following issue gives more detail:

Related issue: #144825

This patch fixes the problem: since assume_alignment only marks the memref's alignment, the linalg hoisting should check its memref operand rather than the op itself.


Full diff: https://github.com/llvm/llvm-project/pull/144843.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+21-5)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+52)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..b949b06631484 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
   return true;
 }
 
+static bool skipViewLike(Operation *source0, Operation *source1) {
+  bool viewLikeCheck = true;
+  auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
+  if (assumeAlignOp && source0 == source1) {
+    Value sourceMemRef = assumeAlignOp.getMemref();
+    Operation *sourceOp = sourceMemRef.getDefiningOp();
+    return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
+  }
+
+  if (source0 && isa_and_nonnull<ViewLikeOpInterface>(source0))
+    return true;
+
+  if (source1 && isa_and_nonnull<ViewLikeOpInterface>(source1))
+    return true;
+
+  return false;
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
                                                  bool verifyNonZeroTrip) {
   bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
-      auto *source = transferRead.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
+      auto *source0 = transferRead.getBase().getDefiningOp();
+      auto *source1 = transferWrite.getBase().getDefiningOp();
 
-      source = transferWrite.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+      if (skipViewLike(source0, source1))
         return WalkResult::advance();
 
       // TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..c58074e40c5f4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -802,3 +802,55 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of vector.transfer_read/transfer_write pairs with same location
+// and this location is marked with assume_align.
+
+// CHECK-LABEL:  func.func @hoist_vector_transfer_read_write() {
+// CHECK:          %c0 = arith.constant 0 : index
+// CHECK-NEXT:     %c256 = arith.constant 256 : index
+// CHECK-NEXT:     %c4096 = arith.constant 4096 : index
+// CHECK-NEXT:     %cst = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT:     %alloc = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT:     %alloc_0 = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT:     %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
+// CHECK-NEXT:     %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:     %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
+// CHECK-NEXT:       %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:       %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %3 : vector<16x16xf16>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+
+func.func @hoist_vector_transfer_read_write() {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %c4096 = arith.constant 4096 : index
+  %cst_0 = arith.constant 0.000000e+00 : f16
+  %m0 = memref.alloc() : memref<4096x4096xf16>
+  %m1 = memref.alloc() : memref<4096x4096xf16>
+  %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
+  %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
+  scf.for %arg0 = %c256 to %c4096 step %c256 {
+    %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+    vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+        transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

@xiangzh1 xiangzh1 requested a review from banach-space June 19, 2025 06:08
@xiangzh1 xiangzh1 added the good first issue https://github.com/llvm/llvm-project/contribute label Jun 19, 2025
Labels: good first issue (https://github.com/llvm/llvm-project/contribute), mlir:linalg, mlir