-
Notifications
You must be signed in to change notification settings - Fork 14k
[mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused. #141853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
ff1012e
c6eb53f
3bdb596
6d47e3f
fe3ab99
cc621a1
e1a1bc1
7d25981
e883ba4
3ae7706
5224464
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -1554,22 +1554,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern { | |||
llvm::SmallSetVector<Value, 32> escapingValues; | ||||
SmallVector<Type> inputTypes; | ||||
SmallVector<Type> distTypes; | ||||
auto collectEscapingValues = [&](Value value) { | ||||
if (!escapingValues.insert(value)) | ||||
return; | ||||
Type distType = value.getType(); | ||||
if (auto vecType = dyn_cast<VectorType>(distType)) { | ||||
AffineMap map = distributionMapFn(value); | ||||
distType = getDistributedType(vecType, map, warpOp.getWarpSize()); | ||||
} | ||||
inputTypes.push_back(value.getType()); | ||||
distTypes.push_back(distType); | ||||
}; | ||||
|
||||
mlir::visitUsedValuesDefinedAbove( | ||||
forOp.getBodyRegion(), [&](OpOperand *operand) { | ||||
Operation *parent = operand->get().getParentRegion()->getParentOp(); | ||||
if (warpOp->isAncestor(parent)) { | ||||
if (!escapingValues.insert(operand->get())) | ||||
return; | ||||
Type distType = operand->get().getType(); | ||||
if (auto vecType = dyn_cast<VectorType>(distType)) { | ||||
AffineMap map = distributionMapFn(operand->get()); | ||||
distType = getDistributedType(vecType, map, warpOp.getWarpSize()); | ||||
} | ||||
inputTypes.push_back(operand->get().getType()); | ||||
distTypes.push_back(distType); | ||||
collectEscapingValues(operand->get()); | ||||
} | ||||
}); | ||||
|
||||
// Any forOp result that is not already yielded by the warpOp | ||||
// region is also considered escaping and must be returned by the | ||||
// original warpOp. | ||||
for (OpResult forResult : forOp.getResults()) { | ||||
// Check if this forResult is already yielded by the yield op. | ||||
if (llvm::is_contained(yield->getOperands(), forResult)) { | ||||
continue; | ||||
} | ||||
collectEscapingValues(forResult); | ||||
} | ||||
|
||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't get this part, could you please clarify? This should collect the arguments of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great question. This is how the original implementation work. Initially original warpOp will return the used forOp results + any escaping values. Later the for op results are replaced with the corresponding initArgs, Here:
I did not change the original logic flow. Hope this clarifies your question. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I had to play with it for a bit to understand how it works. I think adding results of the func.func @warp_scf_for_unused_yield(%arg0: index) {
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0:2 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>) {
%ini = "some_def"() : () -> (vector<128xf32>)
%ini1 = "some_def"() : () -> (vector<128xf32>)
%other = "other_def"() : () -> (vector<128xf32>)
%3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1, %other_2 = %other) -> (vector<128xf32>, vector<128xf32>, vector<128xf32>) {
%add = arith.addi %arg3, %c1 : index
%1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
%acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
%other2 = "other2_def"(%arg4) : (vector<128xf32>) -> (vector<128xf32>)
scf.yield %acc, %1, %other2 : vector<128xf32>, vector<128xf32>, vector<128xf32>
}
gpu.yield %3#0, %other : vector<128xf32>, vector<128xf32>
}
"some_use"(%0#0) : (vector<4xf32>) -> ()
"other_use"(%0#1) : (vector<4xf32>) -> ()
return
} Right after the %0:4 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%1 = "some_def"() : () -> vector<128xf32>
%2 = "some_def"() : () -> vector<128xf32>
%3 = "other_def"() : () -> vector<128xf32>
%4:3 = scf.for %arg1 = %c0 to %c128 step %c1 iter_args(%arg2 = %1, %arg3 = %2, %arg4 = %3) -> (vector<128xf32>, vector<128xf32>, vector<128xf32>) {
%5 = arith.addi %arg1, %c1 : index
%6 = "some_def"(%arg3, %5) : (vector<128xf32>, index) -> vector<128xf32>
%7 = "some_def"(%5, %arg2, %6) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
%8 = "other2_def"(%arg2) : (vector<128xf32>) -> vector<128xf32>
scf.yield %7, %6, %8 : vector<128xf32>, vector<128xf32>, vector<128xf32>
}
gpu.yield %4#0, %3, %4#1, %4#2 : vector<128xf32>, vector<128xf32>, vector<128xf32>, vector<128xf32>
} The results of the I think this is somewhat brittle. If it is done to preserve some existing code I'd reconsider in favor of simplicity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the review :-) I agree the existing logic is bit confusing. I am also not sure why the forOp results are yielded instead of the iterArgs directly. Maybe code owners have some insights on this? But overall the logic will be similar apart from this part (even for a refactored version). This pattern is anyway bit complex since it has a lot of moving parts. :-) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the code is to add the initial value of the loop-carried variables to the new warp op's result - Instead collect the loop result and fixup latter, can we directly collect "forOp.getInitArgs()" as escaping value - since these are escaping value of the for statement, but not captured by the original code which only scan the loop body. |
||||
if (llvm::is_contained(distTypes, Type{})) | ||||
return failure(); | ||||
|
||||
|
@@ -1609,7 +1624,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern { | |||
forOp.getResultTypes().end()); | ||||
llvm::SmallDenseMap<Value, int64_t> argIndexMapping; | ||||
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { | ||||
warpInput.push_back(newWarpOp.getResult(retIdx)); | ||||
auto newWarpResult = newWarpOp.getResult(retIdx); | ||||
// Unused forOp results yielded by the warpOp region are already included | ||||
// in the new ForOp. | ||||
if (llvm::is_contained(newOperands, newWarpResult)) | ||||
continue; | ||||
warpInput.push_back(newWarpResult); | ||||
argIndexMapping[escapingValues[i]] = warpInputType.size(); | ||||
warpInputType.push_back(inputTypes[i]); | ||||
} | ||||
|
Uh oh!
There was an error while loading. Please reload this page.