[mlir][scf] Fix scf.forall inlining: add shared outputs #132197
base: main
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir

Author: Prakhar Dixit (Prakhar-Dixit)

Changes

Fixes #108164

This patch fixes a crash in the scf-forall-to-for conversion pass by ensuring that the replacement vector used during inlining contains both the induction variables and the shared outputs. Previously only the induction variables were passed, causing a mismatch with the number of block arguments expected by the forall op's body. The fix appends the shared outputs (retrieved via `getOutputs()`) to the induction variables and then replaces the forall op with its shared outputs, preserving the intended semantics without introducing regressions.
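For context, the body region of an `scf.forall` with shared outputs has one block argument per induction variable plus one per shared output. A minimal sketch mirroring the regression test added below (`%c100` and `%arg0` assumed to be in scope):

```mlir
// The body block has two arguments: the induction variable %i and the
// shared-output argument %s. The old code supplied only %i as a
// replacement when inlining, leaving %s unmapped, hence the crash.
%res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
  %t = "test.foo"() : () -> tensor<100xf32>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %t into %s[%i] [100] [1]
        : tensor<100xf32> into tensor<100xf32>
  }
}
```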
Full diff: https://github.com/llvm/llvm-project/pull/132197.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..a1df366cef132 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -40,12 +40,15 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
   SmallVector<Value> ivs = llvm::map_to_vector(
       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
+  SmallVector<Value> replacementVals = ivs;
+  for (Value shared : forallOp.getOutputs())
+    replacementVals.push_back(shared);
   Block *innermostBlock = loopNest.loops.back().getBody();
   rewriter.eraseOp(forallOp.getBody()->getTerminator());
   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
                              innermostBlock->getTerminator()->getIterator(),
-                             ivs);
-  rewriter.eraseOp(forallOp);
+                             replacementVals);
+  rewriter.replaceOp(forallOp, forallOp.getOutputs());
   if (results) {
     llvm::move(loopNest.loops, std::back_inserter(*results));
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..17598a154fefd 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -55,3 +55,26 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
}
return
}
+
+// -----
+
+ func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
+   %c100 = arith.constant 100 : index
+   %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
+     %t = "test.foo"() : () -> tensor<100xf32>
+     scf.forall.in_parallel {
+       tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32>
+     }
+   }
+   return %res : tensor<100xf32>
+ }
+// CHECK-LABEL: func.func @parallel_insert_slice(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<100xf32>) -> tensor<100xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK: %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32>
+// CHECK: }
+// CHECK: return %[[VAL_0]] : tensor<100xf32>
+// CHECK: }
\ No newline at end of file
Could you please review this?
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
// CHECK: %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32>
Where did the `parallel_insert_slice` go? I think this pass is incorrect. It should have replaced the `parallel_insert_slice` with `insert_slice`.
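For reference, a sketch of the sequential counterpart this comment asks for, reusing the slice parameters from the test above (`%acc` is a hypothetical name for the loop's iter_arg carrying the shared output):

```mlir
// Inside the generated scf.for body, the parallel insert would become a
// plain tensor.insert_slice whose result is then yielded by the loop.
%updated = tensor.insert_slice %t into %acc[%i] [100] [1]
    : tensor<100xf32> into tensor<100xf32>
```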
// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
// CHECK: %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32>
// CHECK: }
// CHECK: return %[[VAL_0]] : tensor<100xf32>
The result of the `scf.for` should have been used here. It looks like the generated loop nest does not even have a result/iter_args. The issue that you are fixing here was probably an undocumented limitation of this pass, not necessarily a bug: `shared_outs` are generally not supported, which made the implementation a bit easier. But it would be nice to support `shared_outs`.
Basically, instead of dropping the terminator of the `scf.forall` loop, you have to replace it with `tensor.insert_slice` and yield the result. Also, the loop nest that this pass is generating must have an iter_arg (and result); one per shared_out.
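Concretely, a sketch of the lowering these comments describe for the test case above; this is hypothetical expected output, not what the patch currently produces (`%c0` and `%c1` are assumed index constants):

```mlir
// One iter_arg per shared_out threads the tensor through the loop; the
// terminator becomes a sequential insert_slice followed by scf.yield.
%res = scf.for %i = %c0 to %c100 step %c1
    iter_args(%acc = %arg0) -> (tensor<100xf32>) {
  %t = "test.foo"() : () -> tensor<100xf32>
  %inserted = tensor.insert_slice %t into %acc[%i] [100] [1]
      : tensor<100xf32> into tensor<100xf32>
  scf.yield %inserted : tensor<100xf32>
}
```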