Skip to content

Commit 5904168

Browse files
committed
[DA] DivergenceAnalysis for unstructured, reducible CFGs
Summary: This is patch 2 of the new DivergenceAnalysis (https://reviews.llvm.org/D50433). This patch contains a generic divergence analysis implementation for unstructured, reducible Control-Flow Graphs. It contains two new classes. The `SyncDependenceAnalysis` class lazily computes sync dependences, which relate divergent branches to points of joining divergent control. The `DivergenceAnalysis` class contains the generic divergence analysis implementation. Reviewers: nhaehnle Reviewed By: nhaehnle Subscribers: sameerds, kristina, nhaehnle, xbolva00, tschuett, mgorny, llvm-commits Differential Revision: https://reviews.llvm.org/D51491 llvm-svn: 344734
1 parent 547f89d commit 5904168

File tree

8 files changed

+1508
-0
lines changed

8 files changed

+1508
-0
lines changed

llvm/include/llvm/ADT/PostOrderIterator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,15 @@ class ReversePostOrderTraversal {
296296

297297
public:
298298
using rpo_iterator = typename std::vector<NodeRef>::reverse_iterator;
299+
using const_rpo_iterator = typename std::vector<NodeRef>::const_reverse_iterator;
299300

300301
ReversePostOrderTraversal(GraphT G) { Initialize(GT::getEntryNode(G)); }
301302

302303
// Because we want a reverse post order, use reverse iterators from the vector
303304
rpo_iterator begin() { return Blocks.rbegin(); }
305+
const_rpo_iterator begin() const { return Blocks.crbegin(); }
304306
rpo_iterator end() { return Blocks.rend(); }
307+
const_rpo_iterator end() const { return Blocks.crend(); }
305308
};
306309

307310
} // end namespace llvm
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
//===- llvm/Analysis/DivergenceAnalysis.h - Divergence Analysis -*- C++ -*-===//
2+
//
3+
// The LLVM Compiler Infrastructure
4+
//
5+
// This file is distributed under the University of Illinois Open Source
6+
// License. See LICENSE.TXT for details.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// \file
11+
// The divergence analysis determines which instructions and branches are
12+
// divergent given a set of divergent source instructions.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
17+
#define LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
18+
19+
#include "llvm/ADT/DenseSet.h"
20+
#include "llvm/Analysis/SyncDependenceAnalysis.h"
21+
#include "llvm/IR/Function.h"
22+
#include "llvm/Pass.h"
23+
#include <vector>
24+
25+
namespace llvm {
26+
class Module;
27+
class Value;
28+
class Instruction;
29+
class Loop;
30+
class raw_ostream;
31+
class TargetTransformInfo;
32+
33+
/// \brief Generic divergence analysis for reducible CFGs.
34+
///
35+
/// This analysis propagates divergence in a data-parallel context from sources
36+
/// of divergence to all users. It requires reducible CFGs. All assignments
37+
/// should be in SSA form.
38+
class DivergenceAnalysis {
39+
public:
40+
/// \brief This instance will analyze the whole function \p F or the loop \p
41+
/// RegionLoop.
42+
///
43+
/// \param RegionLoop if non-null the analysis is restricted to \p RegionLoop.
44+
/// Otherwise the whole function is analyzed.
45+
/// \param IsLCSSAForm whether the analysis may assume that the IR in the
46+
/// region in in LCSSA form.
47+
DivergenceAnalysis(const Function &F, const Loop *RegionLoop,
48+
const DominatorTree &DT, const LoopInfo &LI,
49+
SyncDependenceAnalysis &SDA, bool IsLCSSAForm);
50+
51+
/// \brief The loop that defines the analyzed region (if any).
52+
const Loop *getRegionLoop() const { return RegionLoop; }
53+
const Function &getFunction() const { return F; }
54+
55+
/// \brief Whether \p BB is part of the region.
56+
bool inRegion(const BasicBlock &BB) const;
57+
/// \brief Whether \p I is part of the region.
58+
bool inRegion(const Instruction &I) const;
59+
60+
/// \brief Mark \p UniVal as a value that is always uniform.
61+
void addUniformOverride(const Value &UniVal);
62+
63+
/// \brief Mark \p DivVal as a value that is always divergent.
64+
void markDivergent(const Value &DivVal);
65+
66+
/// \brief Propagate divergence to all instructions in the region.
67+
/// Divergence is seeded by calls to \p markDivergent.
68+
void compute();
69+
70+
/// \brief Whether any value was marked or analyzed to be divergent.
71+
bool hasDetectedDivergence() const { return !DivergentValues.empty(); }
72+
73+
/// \brief Whether \p Val will always return a uniform value regardless of its
74+
/// operands
75+
bool isAlwaysUniform(const Value &Val) const;
76+
77+
/// \brief Whether \p Val is a divergent value
78+
bool isDivergent(const Value &Val) const;
79+
80+
void print(raw_ostream &OS, const Module *) const;
81+
82+
private:
83+
bool updateTerminator(const TerminatorInst &Term) const;
84+
bool updatePHINode(const PHINode &Phi) const;
85+
86+
/// \brief Computes whether \p Inst is divergent based on the
87+
/// divergence of its operands.
88+
///
89+
/// \returns Whether \p Inst is divergent.
90+
///
91+
/// This should only be called for non-phi, non-terminator instructions.
92+
bool updateNormalInstruction(const Instruction &Inst) const;
93+
94+
/// \brief Mark users of live-out users as divergent.
95+
///
96+
/// \param LoopHeader the header of the divergent loop.
97+
///
98+
/// Marks all users of live-out values of the loop headed by \p LoopHeader
99+
/// as divergent and puts them on the worklist.
100+
void taintLoopLiveOuts(const BasicBlock &LoopHeader);
101+
102+
/// \brief Push all users of \p Val (in the region) to the worklist
103+
void pushUsers(const Value &I);
104+
105+
/// \brief Push all phi nodes in @block to the worklist
106+
void pushPHINodes(const BasicBlock &Block);
107+
108+
/// \brief Mark \p Block as join divergent
109+
///
110+
/// A block is join divergent if two threads may reach it from different
111+
/// incoming blocks at the same time.
112+
void markBlockJoinDivergent(const BasicBlock &Block) {
113+
DivergentJoinBlocks.insert(&Block);
114+
}
115+
116+
/// \brief Whether \p Val is divergent when read in \p ObservingBlock.
117+
bool isTemporalDivergent(const BasicBlock &ObservingBlock,
118+
const Value &Val) const;
119+
120+
/// \brief Whether \p Block is join divergent
121+
///
122+
/// (see markBlockJoinDivergent).
123+
bool isJoinDivergent(const BasicBlock &Block) const {
124+
return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end();
125+
}
126+
127+
/// \brief Propagate control-induced divergence to users (phi nodes and
128+
/// instructions).
129+
//
130+
// \param JoinBlock is a divergent loop exit or join point of two disjoint
131+
// paths.
132+
// \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop.
133+
bool propagateJoinDivergence(const BasicBlock &JoinBlock,
134+
const Loop *TermLoop);
135+
136+
/// \brief Propagate induced value divergence due to control divergence in \p
137+
/// Term.
138+
void propagateBranchDivergence(const TerminatorInst &Term);
139+
140+
/// \brief Propagate divergent caused by a divergent loop exit.
141+
///
142+
/// \param ExitingLoop is a divergent loop.
143+
void propagateLoopDivergence(const Loop &ExitingLoop);
144+
145+
private:
146+
const Function &F;
147+
// If regionLoop != nullptr, analysis is only performed within \p RegionLoop.
148+
// Otw, analyze the whole function
149+
const Loop *RegionLoop;
150+
151+
const DominatorTree &DT;
152+
const LoopInfo &LI;
153+
154+
// Recognized divergent loops
155+
DenseSet<const Loop *> DivergentLoops;
156+
157+
// The SDA links divergent branches to divergent control-flow joins.
158+
SyncDependenceAnalysis &SDA;
159+
160+
// Use simplified code path for LCSSA form.
161+
bool IsLCSSAForm;
162+
163+
// Set of known-uniform values.
164+
DenseSet<const Value *> UniformOverrides;
165+
166+
// Blocks with joining divergent control from different predecessors.
167+
DenseSet<const BasicBlock *> DivergentJoinBlocks;
168+
169+
// Detected/marked divergent values.
170+
DenseSet<const Value *> DivergentValues;
171+
172+
// Internal worklist for divergence propagation.
173+
std::vector<const Instruction *> Worklist;
174+
};
175+
176+
} // namespace llvm
177+
178+
#endif // LLVM_ANALYSIS_DIVERGENCE_ANALYSIS_H
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//===- SyncDependenceAnalysis.h - Divergent Branch Dependence -*- C++ -*-===//
2+
//
3+
// The LLVM Compiler Infrastructure
4+
//
5+
// This file is distributed under the University of Illinois Open Source
6+
// License. See LICENSE.TXT for details.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// \file
11+
// This file defines the SyncDependenceAnalysis class, which computes for
12+
// every divergent branch the set of phi nodes that the branch will make
13+
// divergent.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef LLVM_ANALYSIS_SYNC_DEPENDENCE_ANALYSIS_H
18+
#define LLVM_ANALYSIS_SYNC_DEPENDENCE_ANALYSIS_H
19+
20+
#include "llvm/ADT/DenseMap.h"
21+
#include "llvm/ADT/PostOrderIterator.h"
22+
#include "llvm/ADT/SmallPtrSet.h"
23+
#include "llvm/Analysis/LoopInfo.h"
24+
#include <memory>
25+
26+
namespace llvm {
27+
28+
class BasicBlock;
29+
class DominatorTree;
30+
class Loop;
31+
class PostDominatorTree;
32+
class TerminatorInst;
33+
class TerminatorInst;
34+
35+
using ConstBlockSet = SmallPtrSet<const BasicBlock *, 4>;
36+
37+
/// \brief Relates points of divergent control to join points in
38+
/// reducible CFGs.
39+
///
40+
/// This analysis relates points of divergent control to points of converging
41+
/// divergent control. The analysis requires all loops to be reducible.
42+
class SyncDependenceAnalysis {
43+
void visitSuccessor(const BasicBlock &succBlock, const Loop *termLoop,
44+
const BasicBlock *defBlock);
45+
46+
public:
47+
bool inRegion(const BasicBlock &BB) const;
48+
49+
~SyncDependenceAnalysis();
50+
SyncDependenceAnalysis(const DominatorTree &DT, const PostDominatorTree &PDT,
51+
const LoopInfo &LI);
52+
53+
/// \brief Computes divergent join points and loop exits caused by branch
54+
/// divergence in \p Term.
55+
///
56+
/// The set of blocks which are reachable by disjoint paths from \p Term.
57+
/// The set also contains loop exits if there two disjoint paths:
58+
/// one from \p Term to the loop exit and another from \p Term to the loop
59+
/// header. Those exit blocks are added to the returned set.
60+
/// If L is the parent loop of \p Term and an exit of L is in the returned
61+
/// set then L is a divergent loop.
62+
const ConstBlockSet &join_blocks(const TerminatorInst &Term);
63+
64+
/// \brief Computes divergent join points and loop exits (in the surrounding
65+
/// loop) caused by the divergent loop exits of\p Loop.
66+
///
67+
/// The set of blocks which are reachable by disjoint paths from the
68+
/// loop exits of \p Loop.
69+
/// This treats the loop as a single node in \p Loop's parent loop.
70+
/// The returned set has the same properties as for join_blocks(TermInst&).
71+
const ConstBlockSet &join_blocks(const Loop &Loop);
72+
73+
private:
74+
static ConstBlockSet EmptyBlockSet;
75+
76+
ReversePostOrderTraversal<const Function *> FuncRPOT;
77+
const DominatorTree &DT;
78+
const PostDominatorTree &PDT;
79+
const LoopInfo &LI;
80+
81+
std::map<const Loop *, std::unique_ptr<ConstBlockSet>> CachedLoopExitJoins;
82+
std::map<const TerminatorInst *, std::unique_ptr<ConstBlockSet>>
83+
CachedBranchJoins;
84+
};
85+
86+
} // namespace llvm
87+
88+
#endif // LLVM_ANALYSIS_SYNC_DEPENDENCE_ANALYSIS_H

llvm/lib/Analysis/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_llvm_library(LLVMAnalysis
2525
Delinearization.cpp
2626
DemandedBits.cpp
2727
DependenceAnalysis.cpp
28+
DivergenceAnalysis.cpp
2829
DomPrinter.cpp
2930
DominanceFrontier.cpp
3031
EHPersonalities.cpp
@@ -80,6 +81,7 @@ add_llvm_library(LLVMAnalysis
8081
ScalarEvolutionAliasAnalysis.cpp
8182
ScalarEvolutionExpander.cpp
8283
ScalarEvolutionNormalization.cpp
84+
SyncDependenceAnalysis.cpp
8385
SyntheticCountsUtils.cpp
8486
TargetLibraryInfo.cpp
8587
TargetTransformInfo.cpp

0 commit comments

Comments
 (0)