Add Dead Block Elimination to NVVMReflect #144171
Conversation
@llvm/pr-subscribers-backend-nvptx
Author: Yonah Goldberg (YonahGoldberg)
Changes:
Currently, NVVMReflect replaces calls to __nvvm_reflect with a constant and then constant-propagates/folds the result, but it doesn't handle dead block elimination. The most common use of reflect calls is to query the arch number and select valid code depending on the arch. Therefore, the blocks that become dead after reflect replacement need to be deleted as a matter of correctness. The way this gets cleaned up now in llc is with UnreachableBlockElim followed by CodeGenPrepare, which I've observed work together to delete the dead blocks. It's better to just have this pass delete the dead blocks right away.
This PR introduces some additional code to handle the dead block deletion. I think what I've written is actually pretty general; it's kind of a lightweight version of SCCP. I wonder if I missed somewhere this is already implemented in LLVM, so I don't duplicate code. If I didn't, would it ever be useful to put this somewhere more general where others can use it, instead of in NVVMReflect?
Note that I also removed running simplifycfg in two test cases, which shows that this pass is now able to handle the dead block elimination without simplifycfg.
Full diff: https://github.com/llvm/llvm-project/pull/144171.diff 3 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 208bab52284a3..2585ff45bde4c 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -19,6 +19,7 @@
//===----------------------------------------------------------------------===//
#include "NVPTX.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ConstantFolding.h"
@@ -59,7 +60,10 @@ class NVVMReflect {
StringMap<unsigned> ReflectMap;
bool handleReflectFunction(Module &M, StringRef ReflectName);
void populateReflectMap(Module &M);
- void foldReflectCall(CallInst *Call, Constant *NewValue);
+ void replaceReflectCalls(
+ SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+ const DataLayout &DL);
+ SetVector<BasicBlock *> findTransitivelyDeadBlocks(BasicBlock *DeadBB);
public:
// __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
@@ -138,6 +142,8 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
assert(F->getReturnType()->isIntegerTy() &&
"_reflect's return type should be integer");
+ SmallVector<std::pair<CallInst *, Constant *>, 8> ReflectReplacements;
+
const bool Changed = !F->use_empty();
for (User *U : make_early_inc_range(F->users())) {
// Reflect function calls look like:
@@ -178,38 +184,111 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
<< "(" << ReflectArg << ") with value " << ReflectVal
<< "\n");
auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
- foldReflectCall(Call, NewValue);
- Call->eraseFromParent();
+ ReflectReplacements.push_back({Call, NewValue});
}
- // Remove the __nvvm_reflect function from the module
+ replaceReflectCalls(ReflectReplacements, M.getDataLayout());
F->eraseFromParent();
return Changed;
}
-void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+/// Find all blocks that become dead transitively from an initial dead block.
+/// Returns the complete set including the original dead block and any blocks
+/// that lose all their predecessors due to the deletion cascade.
+SetVector<BasicBlock *>
+NVVMReflect::findTransitivelyDeadBlocks(BasicBlock *DeadBB) {
+ SmallVector<BasicBlock *, 8> Worklist({DeadBB});
+ SetVector<BasicBlock *> DeadBlocks;
+ while (!Worklist.empty()) {
+ auto *BB = Worklist.pop_back_val();
+ DeadBlocks.insert(BB);
+
+ for (BasicBlock *Succ : successors(BB))
+ if (pred_size(Succ) == 1 && DeadBlocks.insert(Succ))
+ Worklist.push_back(Succ);
+ }
+ return DeadBlocks;
+}
+
+/// Replace calls to __nvvm_reflect with corresponding constant values. Then
+/// clean up through constant folding and propagation and dead block
+/// elimination.
+///
+/// The purpose of this cleanup is not optimization because that could be
+/// handled by later passes
+/// (i.e. SCCP, SimplifyCFG, etc.), but for correctness. Reflect calls are most
+/// commonly used to query the arch number and select a valid instruction for
+/// the arch. Therefore, you need to eliminate blocks that become dead because
+/// they may contain invalid instructions for the arch. The purpose of the
+/// cleanup is to do the minimal amount of work to leave the code in a valid
+/// state.
+void NVVMReflect::replaceReflectCalls(
+ SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+ const DataLayout &DL) {
SmallVector<Instruction *, 8> Worklist;
- // Replace an instruction with a constant and add all users of the instruction
- // to the worklist
+ SetVector<BasicBlock *> DeadBlocks;
+
+ // Replace an instruction with a constant and add all users to the worklist,
+ // then delete the instruction
auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
for (auto *U : I->users())
if (auto *UI = dyn_cast<Instruction>(U))
Worklist.push_back(UI);
I->replaceAllUsesWith(C);
+ if (isInstructionTriviallyDead(I))
+ I->eraseFromParent();
};
- ReplaceInstructionWithConst(Call, NewValue);
+ for (auto &[Call, NewValue] : ReflectReplacements)
+ ReplaceInstructionWithConst(Call, NewValue);
- auto &DL = Call->getModule()->getDataLayout();
- while (!Worklist.empty()) {
- auto *I = Worklist.pop_back_val();
- if (auto *C = ConstantFoldInstruction(I, DL)) {
- ReplaceInstructionWithConst(I, C);
- if (isInstructionTriviallyDead(I))
- I->eraseFromParent();
- } else if (I->isTerminator()) {
- ConstantFoldTerminator(I->getParent());
+ // Alternate between constant folding/propagation and dead block elimination.
+ // Terminator folding may create new dead blocks. When those dead blocks are
+ // deleted, their live successors may have PHIs that can be simplified, which
+ // may yield more work for folding/propagation.
+ while (true) {
+ // Iterate folding and propagating constants until the worklist is empty.
+ while (!Worklist.empty()) {
+ auto *I = Worklist.pop_back_val();
+ if (auto *C = ConstantFoldInstruction(I, DL)) {
+ ReplaceInstructionWithConst(I, C);
+ } else if (I->isTerminator()) {
+ BasicBlock *BB = I->getParent();
+ SmallVector<BasicBlock *, 8> Succs(successors(BB));
+ // Some blocks may become dead if the terminator is folded because
+ // a conditional branch is turned into a direct branch.
+ if (ConstantFoldTerminator(BB)) {
+ for (BasicBlock *Succ : Succs) {
+ if (pred_empty(Succ) &&
+ Succ != &Succ->getParent()->getEntryBlock()) {
+ SetVector<BasicBlock *> TransitivelyDead =
+ findTransitivelyDeadBlocks(Succ);
+ DeadBlocks.insert(TransitivelyDead.begin(),
+ TransitivelyDead.end());
+ }
+ }
+ }
+ }
}
+ // No more constants to fold and no more dead blocks
+ // to create more work. We're done.
+ if (DeadBlocks.empty())
+ break;
+ // PHI nodes of live successors of dead blocks get eliminated when the dead
+ // blocks are eliminated. Their users can now be simplified further, so add
+ // them to the worklist.
+ for (BasicBlock *DeadBB : DeadBlocks)
+ for (BasicBlock *Succ : successors(DeadBB))
+ if (!DeadBlocks.contains(Succ))
+ for (PHINode &PHI : Succ->phis())
+ for (auto *U : PHI.users())
+ if (auto *UI = dyn_cast<Instruction>(U))
+ Worklist.push_back(UI);
+ // Delete all dead blocks
+ for (BasicBlock *DeadBB : DeadBlocks)
+ DeleteDeadBlock(DeadBB);
+
+ DeadBlocks.clear();
}
}
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
index 19c74df303702..7bb1af707001a 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
@@ -3,12 +3,12 @@
; RUN: cat %s > %t.noftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
; RUN: cat %s > %t.ftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
@str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
index 244b44fea9b83..581dbf353c1ff 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
@@ -3,12 +3,12 @@
; RUN: cat %s > %t.noftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
; RUN: cat %s > %t.ftz
; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
; RUN: | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
@str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
I'm not quite sure what problem the patch is intended to solve.
How is that better than the optimization passes we already have to do exactly that job? NVVMReflect is normally added very early in the pipeline, and subsequent passes do a pretty good job of eliminating the dead code after the call is replaced with a constant. See llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp, lines 245 to 250 at d7e64d9.
@Artem-B The problem is that CodeGenPrepare, the pass in llc that ultimately cleans up the dead code introduced by reflect elimination, doesn't get run in llc -O0. So in many cases with llc -O0 you will leave behind a PTX instruction that is illegal for the subtarget. To ensure that this cleanup always happens, regardless of the optimization pipeline, it's better to do it in the NVVMReflect pass itself.
OK. That's a good point. But how is incorporating DCE into the NVVMReflect pass better than running the existing DCE pass right after reflect, regardless of the optimization level? You cannot avoid the fact that, to solve the problem, you need to run DCE one way or another. We may as well just run the existing DCE pass and avoid duplicating the functionality. The second-order problem is that unconditionally running DCE may not be universally desirable. After all, if the user explicitly specifies -O0, they presumably expect minimal transformation of their code. In any case, it may be prudent to implement this as an optional feature, controlled by a command line option, if somebody really needs to enable or avoid DCE after the reflect pass.
You said it yourself: running the existing DCE pass at O0 isn't optimal, because that's not what the user wants. The code I wrote only eliminates blocks that become dead as a result of folding reflect calls, because that is necessary for valid code. All other dead blocks are left alone.
It's hard to say why you haven't run into a bug with this. It might be that running llc with -O0 is uncommon for NVPTX. I think we had a bug internally that was the result of this a couple of years ago.
You're arguing about the degree of optimization, but the fact that you do the optimization remains.
The only code NVVMReflect operates on in practice is CUDA's libdevice. Libdevice is linked into the CUDA module on an as-needed basis, so we only pull in the functions that are actually needed by the module. So it's quite possible that we've simply never compiled, with -O0, a function that pulls in uncompileable IR from libdevice. As for the compilability of IR that relies on NVVMReflect, this is the same issue as an attempt to compile code that looks like this:
Should this code be expected to compile to something sensible with -O0? With any other optimization option? I would argue that the answer is "no", and that the same argument applies to libdevice and the use of nvvm_reflect() if one of the if branches contains uncompileable code. We can make the example above work with C++17 `if constexpr`. Considering that libdevice and nvvm_reflect() do exist, I'm OK with giving the user an option to make it work, with reasonable trade-offs, but I do not want to create an illusion that it is something we want to support. NVVMReflect's attempt to be an IR-level preprocessor is broken by design, IMO. It happens to work most of the time, but it implicitly relies on things IR does not guarantee, and I do not think we want to add such guarantees. If for some reason we do want to provide a library which contains IR for different targets, it should be via separate per-target IR blobs, IMO. These days Clang supports LTO and knows how to package multiple IR variants into an object file and link the right blobs together at the end, so we have a better mechanism for providing GPU-side libraries. @jhuber6 WDYT?
We ideally want functionality similar to what #134016 is doing: basically, late-resolve features by eliminating potentially invalid code. The previous version did simple DCE, which was good enough to eliminate trivial uses, but we definitely need something more complicated for it to be completely robust. We do this at O0 because otherwise it's useless as a means to guard potentially invalid code for that target. In either case this is a bit of a hack, because a lot of target-specific features require
Fascinating. Looks like the
Do you have an example?
Yes, the invalid code should either never show up in IR (AST-level resolution in #134016) or be eliminated by something at the IR level. In this case we're arguing about how to do that, and the options are:
Now, it looks like we have another question: how robust is the DCE, or the folding code in this patch?
Yup. That's why I think the right way out of this corner we've painted ourselves into is to provide per-target bitcode that is intended to be valid for the given target. I do not think we can guarantee "generic" IR functionality for everyone at once. There will always be corner cases when we mix it with other IR modules, unless the IR in question is itself very generic, or we have some way to enforce that the wrong parts were eliminated before we mix it with any other IR. E.g., in this case, the const-folding and DCE would work fine if applied to libdevice only. However, nothing stops users from using __nvvm_reflect in their own code. I don't want to keep adding more magic to NVVMReflect. If we want a corner-case workaround, I think it should be an explicitly enabled thing, with a clear understanding of the trade-offs involved. If we want a robust solution, this PR is not it.
This is fair. I think there's some argument to be made that we should always at least clean up the most common use case, which is code that looks like: |