[mlir][hoisting] Support memref.assume_alignment in linalg hoisting #144843
Conversation
All ViewLike operations are excluded from the hoisting optimization. But assume_alignment only marks the memref's alignment, so we should check its memref operand instead of the op itself.
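For illustration, here is a minimal sketch (simplified from the test added in this PR; names and shapes are illustrative) of the pattern that should now be hoisted: the transfer base is the result of memref.assume_alignment whose memref operand is a plain memref.alloc, i.e. not a view-like op, so the read/write pair is loop-invariant and can be moved out of the scf.for.

func.func @sketch() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %cst = arith.constant 0.000000e+00 : f16
  // Plain allocation; its alignment is then asserted with assume_alignment.
  %buf = memref.alloc() : memref<64x64xf16>
  %aligned = memref.assume_alignment %buf, 64 : memref<64x64xf16>
  scf.for %i = %c0 to %c8 step %c1 {
    // The base %aligned is defined by assume_alignment over a non-view-like
    // alloc, so this read/write pair should be hoisted out of the loop.
    %v = vector.transfer_read %aligned[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x64xf16>, vector<16x16xf16>
    %acc = arith.addf %v, %v : vector<16x16xf16>
    vector.transfer_write %acc, %aligned[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<64x64xf16>
  }
  return
}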
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir
Author: XiangZhang (xiangzh1)
Changes
The recent update of AssumeAlignmentOp affects the linalg hoisting optimization. Related issue: 144825. This patch fixes the problem: since assume_alignment only marks the memref's alignment, the linalg hoisting should check its memref operand instead of the op itself.
Full diff: https://github.com/llvm/llvm-project/pull/144843.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..b949b06631484 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
return true;
}
+static bool skipViewLike(Operation *source0, Operation *source1) {
+ bool viewLikeCheck = true;
+ auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
+ if (assumeAlignOp && source0 == source1) {
+ Value sourceMemRef = assumeAlignOp.getMemref();
+ Operation *sourceOp = sourceMemRef.getDefiningOp();
+ return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
+ }
+
+ if (source0 && isa_and_nonnull<ViewLikeOpInterface>(source0))
+ return true;
+
+ if (source1 && isa_and_nonnull<ViewLikeOpInterface>(source1))
+ return true;
+
+ return false;
+}
+
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
bool verifyNonZeroTrip) {
bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
- auto *source = transferRead.getBase().getDefiningOp();
- if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
- return WalkResult::advance();
+ auto *source0 = transferRead.getBase().getDefiningOp();
+ auto *source1 = transferWrite.getBase().getDefiningOp();
- source = transferWrite.getBase().getDefiningOp();
- if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+ if (skipViewLike(source0, source1))
return WalkResult::advance();
// TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..c58074e40c5f4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -802,3 +802,55 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Test hoisting of vector.transfer_read/transfer_write pairs that access the same
+// location, where that location is the result of memref.assume_alignment.
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_read_write() {
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK-NEXT: %c256 = arith.constant 256 : index
+// CHECK-NEXT: %c4096 = arith.constant 4096 : index
+// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
+// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
+// CHECK-NEXT: %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT: %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT: scf.yield %3 : vector<16x16xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+func.func @hoist_vector_transfer_read_write() {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c256 = arith.constant 256 : index
+ %c4096 = arith.constant 4096 : index
+ %cst_0 = arith.constant 0.000000e+00 : f16
+ %m0 = memref.alloc() : memref<4096x4096xf16>
+ %m1 = memref.alloc() : memref<4096x4096xf16>
+ %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
+ %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
+ scf.for %arg0 = %c256 to %c4096 step %c256 {
+ %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+ %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+ %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
The recent update of AssumeAlignmentOp affects the linalg hoisting optimization.
We found a regression in hoisting loads/stores out of loops.
The following issue lists more detail:
Related issue: 144825
This patch fixes the problem: since assume_alignment only marks the memref's alignment,
the linalg hoisting should check its memref operand, not the op itself.