Skip to content

[SimplifyCFG] Avoid branch threading of divergent conditionals #141867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion llvm/include/llvm/Transforms/Utils/Local.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,11 @@ class StoreInst;
class TargetLibraryInfo;
class TargetTransformInfo;

template <typename T> class GenericSSAContext;
using SSAContext = GenericSSAContext<Function>;
template <typename T> class GenericUniformityInfo;
using UniformityInfo = GenericUniformityInfo<SSAContext>;

//===----------------------------------------------------------------------===//
// Local constant propagation.
//
@@ -183,7 +188,7 @@ bool EliminateDuplicatePHINodes(BasicBlock *BB,
/// providing the set of loop headers that SimplifyCFG should not eliminate.
extern cl::opt<bool> RequireAndPreserveDomTree;
bool simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI,
DomTreeUpdater *DTU = nullptr,
DomTreeUpdater *DTU = nullptr, UniformityInfo *UI = nullptr,
const SimplifyCFGOptions &Options = {},
ArrayRef<WeakVH> LoopHeaders = {});

25 changes: 16 additions & 9 deletions llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DebugInfoMetadata.h"
@@ -229,7 +230,7 @@ static bool tailMergeBlocksWithSimilarFunctionTerminators(Function &F,
/// Call SimplifyCFG on all the blocks in the function,
/// iterating until no more changes are made.
static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
DomTreeUpdater *DTU,
DomTreeUpdater *DTU, UniformityInfo *UI,
const SimplifyCFGOptions &Options) {
bool Changed = false;
bool LocalChange = true;
@@ -261,7 +262,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
while (BBIt != F.end() && DTU->isBBPendingDeletion(&*BBIt))
++BBIt;
}
if (simplifyCFG(&BB, TTI, DTU, Options, LoopHeaders)) {
if (simplifyCFG(&BB, TTI, DTU, UI, Options, LoopHeaders)) {
LocalChange = true;
++NumSimpl;
}
@@ -272,14 +273,15 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
}

static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI,
DominatorTree *DT,
DominatorTree *DT, UniformityInfo *UI,
const SimplifyCFGOptions &Options) {
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);

bool EverChanged = removeUnreachableBlocks(F, DT ? &DTU : nullptr);
EverChanged |=
tailMergeBlocksWithSimilarFunctionTerminators(F, DT ? &DTU : nullptr);
EverChanged |= iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, Options);
EverChanged |=
iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, UI, Options);

// If neither pass changed anything, we're done.
if (!EverChanged) return false;
@@ -293,21 +295,22 @@ static bool simplifyFunctionCFGImpl(Function &F, const TargetTransformInfo &TTI,
return true;

do {
EverChanged = iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, Options);
EverChanged =
iterativelySimplifyCFG(F, TTI, DT ? &DTU : nullptr, UI, Options);
EverChanged |= removeUnreachableBlocks(F, DT ? &DTU : nullptr);
} while (EverChanged);

return true;
}

static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,
DominatorTree *DT,
DominatorTree *DT, UniformityInfo *UI,
const SimplifyCFGOptions &Options) {
assert((!RequireAndPreserveDomTree ||
(DT && DT->verify(DominatorTree::VerificationLevel::Full))) &&
"Original domtree is invalid?");

bool Changed = simplifyFunctionCFGImpl(F, TTI, DT, Options);
bool Changed = simplifyFunctionCFGImpl(F, TTI, DT, UI, Options);

assert((!RequireAndPreserveDomTree ||
(DT && DT->verify(DominatorTree::VerificationLevel::Full))) &&
@@ -378,7 +381,8 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F,
DominatorTree *DT = nullptr;
if (RequireAndPreserveDomTree)
DT = &AM.getResult<DominatorTreeAnalysis>(F);
if (!simplifyFunctionCFG(F, TTI, DT, Options))
auto *UA = &AM.getResult<UniformityInfoAnalysis>(F);
if (!simplifyFunctionCFG(F, TTI, DT, UA, Options))
Comment on lines +384 to +385
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Naming inconsistency. This variable is named UA, but it's UI everywhere else.

return PreservedAnalyses::all();
PreservedAnalyses PA;
if (RequireAndPreserveDomTree)
@@ -412,7 +416,8 @@ struct CFGSimplifyPass : public FunctionPass {
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
return simplifyFunctionCFG(F, TTI, DT, Options);
auto &UI = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
return simplifyFunctionCFG(F, TTI, DT, &UI, Options);
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
@@ -422,6 +427,7 @@ struct CFGSimplifyPass : public FunctionPass {
if (RequireAndPreserveDomTree)
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<GlobalsAAWrapperPass>();
AU.addRequired<UniformityInfoWrapperPass>();
}
};
}
@@ -432,6 +438,7 @@ INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
false)

55 changes: 39 additions & 16 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
@@ -258,6 +259,7 @@ struct ValueEqualityComparisonCase {
class SimplifyCFGOpt {
const TargetTransformInfo &TTI;
DomTreeUpdater *DTU;
UniformityInfo *UI;
const DataLayout &DL;
ArrayRef<WeakVH> LoopHeaders;
const SimplifyCFGOptions &Options;
@@ -306,9 +308,10 @@ class SimplifyCFGOpt {

public:
SimplifyCFGOpt(const TargetTransformInfo &TTI, DomTreeUpdater *DTU,
const DataLayout &DL, ArrayRef<WeakVH> LoopHeaders,
const SimplifyCFGOptions &Opts)
: TTI(TTI), DTU(DTU), DL(DL), LoopHeaders(LoopHeaders), Options(Opts) {
const DataLayout &DL, UniformityInfo *UI,
ArrayRef<WeakVH> LoopHeaders, const SimplifyCFGOptions &Opts)
: TTI(TTI), DTU(DTU), UI(UI), DL(DL), LoopHeaders(LoopHeaders),
Options(Opts) {
assert((!DTU || !DTU->hasPostDomTree()) &&
"SimplifyCFG is not yet capable of maintaining validity of a "
"PostDomTree, so don't ask for it.");
@@ -3490,6 +3493,17 @@ static bool blockIsSimpleEnoughToThreadThrough(BasicBlock *BB) {
return true;
}

/// Return true if \p BB holds at most one instruction other than \p PN.
///
/// The block is walked via instructionsWithoutDebug(), so debug instructions
/// are presumably not counted. \p PN itself is skipped; as soon as a second
/// non-\p PN instruction is seen the function returns false.
/// NOTE(review): the single permitted instruction is presumably the block's
/// terminator, which would make threading through \p BB duplicate no real
/// work -- confirm against the caller.
static bool blockIsFreeToThreadThrough(BasicBlock *BB, PHINode *PN) {
unsigned Size = 0;
for (Instruction &I : BB->instructionsWithoutDebug(false)) {
if (&I == PN)
continue;
if (++Size > 1)
return false;
Comment on lines +3499 to +3502
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd prefer "count what we want" instead of "filter unwanted instructions first, count the rest later".

Size += (&I != PN);
if (Size > 1 )
  return false;

}
return true;
}

static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From,
BasicBlock *To) {
// Don't look past the block defining the value, we might get the value from
@@ -3511,10 +3525,9 @@ static ConstantInt *getKnownValueOnEdge(Value *V, BasicBlock *From,
/// If we have a conditional branch on something for which we know the constant
/// value in predecessors (e.g. a phi node in the current block), thread edges
/// from the predecessor to their ultimate destination.
static std::optional<bool>
foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
const DataLayout &DL,
AssumptionCache *AC) {
static std::optional<bool> foldCondBranchOnValueKnownInPredecessorImpl(
BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL,
const TargetTransformInfo &TTI, UniformityInfo *UI, AssumptionCache *AC) {
SmallMapVector<ConstantInt *, SmallSetVector<BasicBlock *, 2>, 2> KnownValues;
BasicBlock *BB = BI->getParent();
Value *Cond = BI->getCondition();
@@ -3555,6 +3568,16 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
if (RealDest == BB)
continue; // Skip self loops.

// Check to see that we're not duplicating instructions into divergent
// branches. Doing so would essentially double the execution time, since
// the instructions will be executed by divergent threads serially.
if (TTI.hasBranchDivergence() && UI &&
!blockIsFreeToThreadThrough(BB, PN) &&
any_of(PredBBs, [&](BasicBlock *PredBB) {
return UI->hasDivergentTerminator(*PredBB);
}))
continue;

// Skip if the predecessor's terminator is an indirect branch.
if (any_of(PredBBs, [](BasicBlock *PredBB) {
return isa<IndirectBrInst>(PredBB->getTerminator());
@@ -3669,15 +3692,15 @@ foldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU,
return false;
}

static bool foldCondBranchOnValueKnownInPredecessor(BranchInst *BI,
DomTreeUpdater *DTU,
const DataLayout &DL,
AssumptionCache *AC) {
static bool foldCondBranchOnValueKnownInPredecessor(
BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL,
const TargetTransformInfo &TTI, UniformityInfo *UI, AssumptionCache *AC) {
std::optional<bool> Result;
bool EverChanged = false;
do {
// Note that None means "we changed things, but recurse further."
Result = foldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, AC);
Result =
foldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, TTI, UI, AC);
EverChanged |= Result == std::nullopt || *Result;
} while (Result == std::nullopt);
return EverChanged;
@@ -8082,7 +8105,7 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
// If this is a branch on something for which we know the constant value in
// predecessors (e.g. a phi node in the current block), thread control
// through this block.
if (foldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, Options.AC))
if (foldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, TTI, UI, Options.AC))
return requestResimplify();

// Scan predecessor blocks for conditional branches.
@@ -8402,9 +8425,9 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) {
}

bool llvm::simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI,
DomTreeUpdater *DTU, const SimplifyCFGOptions &Options,
DomTreeUpdater *DTU, UniformityInfo *UI,
const SimplifyCFGOptions &Options,
ArrayRef<WeakVH> LoopHeaders) {
return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), LoopHeaders,
Options)
return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), UI, LoopHeaders, Options)
.run(BB);
}
2 changes: 2 additions & 0 deletions llvm/test/Transforms/SimplifyCFG/NVPTX/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
if not "NVPTX" in config.root.targets:
config.unsupported = True
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.