Skip to content

Commit ccbbb17

Browse files
authored
[mlir] [dataflow] : Improve the time and space footprint of data flow. (llvm#135325)
MLIR's data flow analysis (especially dense data flow analysis) constructs a lattice at every lattice anchor (which, for dense data flow, means every program point). As the program grows larger, the time and space complexity can become unmanageable. However, in many programs, the lattice values at numerous lattice anchors are actually identical. We can leverage this observation to improve the complexity of data flow analysis. This patch introduces equivalence classes of lattice anchors, which group lattice anchors that must contain identical lattices under a given state type, to improve the time and space footprint of data flow analysis.
1 parent 06da00a commit ccbbb17

File tree

6 files changed

+211
-29
lines changed

6 files changed

+211
-29
lines changed

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
7373
/// may modify the program state; that is, every operation and block.
7474
LogicalResult initialize(Operation *top) override;
7575

76+
/// Initialize lattice anchor equivalence class from the provided top-level
77+
/// operation.
78+
///
79+
/// This function will union lattice anchor to same equivalent class if the
80+
/// analysis can determine the lattice content of lattice anchor is
81+
/// necessarily identical under the corresponding lattice type.
82+
virtual void initializeEquivalentLatticeAnchor(Operation *top) override;
83+
7684
/// Visit a program point that modifies the state of the program. If the
7785
/// program point is at the beginning of a block, then the state is propagated
7886
/// from control-flow predecessors or callsites. If the operation before
@@ -96,8 +104,8 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
96104
/// dependency. That is, every time the lattice after anchor is updated, the
97105
/// dependent program point must be visited, and the newly triggered visit
98106
/// might update the lattice on dependent.
99-
const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
100-
LatticeAnchor anchor);
107+
virtual const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
108+
LatticeAnchor anchor) = 0;
101109

102110
/// Set the dense lattice at control flow entry point and propagate an update
103111
/// if it changed.
@@ -114,6 +122,11 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
114122
/// operation transfer function.
115123
virtual LogicalResult processOperation(Operation *op);
116124

125+
/// Visit an operation. If this analysis can confirm that lattice content
126+
/// of lattice anchors around operation are necessarily identical, join
127+
/// them into the same equivalent class.
128+
virtual void buildOperationEquivalentLatticeAnchor(Operation *op) { return; }
129+
117130
/// Propagate the dense lattice forward along the control flow edge from
118131
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
119132
/// values correspond to control flow branches originating at or targeting the
@@ -252,6 +265,15 @@ class DenseForwardDataFlowAnalysis
252265
return getOrCreate<LatticeT>(anchor);
253266
}
254267

268+
/// Get the dense lattice on the given lattice anchor and add dependent as its
269+
/// dependency. That is, every time the lattice after anchor is updated, the
270+
/// dependent program point must be visited, and the newly triggered visit
271+
/// might update the lattice on dependent.
272+
const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
273+
LatticeAnchor anchor) override {
274+
return getOrCreateFor<LatticeT>(dependent, anchor);
275+
}
276+
255277
/// Set the dense lattice at control flow entry point and propagate an update
256278
/// if it changed.
257279
virtual void setToEntryState(LatticeT *lattice) = 0;
@@ -310,6 +332,14 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
310332
/// may modify the program state; that is, every operation and block.
311333
LogicalResult initialize(Operation *top) override;
312334

335+
/// Initialize lattice anchor equivalence class from the provided top-level
336+
/// operation.
337+
///
338+
/// This function will union lattice anchor to same equivalent class if the
339+
/// analysis can determine the lattice content of lattice anchor is
340+
/// necessarily identical under the corresponding lattice type.
341+
virtual void initializeEquivalentLatticeAnchor(Operation *top) override;
342+
313343
/// Visit a program point that modifies the state of the program. The state is
314344
/// propagated along control flow directions for branch-, region- and
315345
/// call-based control flow using the respective interfaces. For other
@@ -336,8 +366,8 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
336366
/// dependency. That is, every time the lattice after anchor is updated, the
337367
/// dependent program point must be visited, and the newly triggered visit
338368
/// might update the lattice before dependent.
339-
const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
340-
LatticeAnchor anchor);
369+
virtual const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
370+
LatticeAnchor anchor) = 0;
341371

342372
/// Set the dense lattice before at the control flow exit point and propagate
343373
/// the update if it changed.
@@ -353,6 +383,11 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
353383
/// transfer function.
354384
virtual LogicalResult processOperation(Operation *op);
355385

386+
/// Visit an operation. If this analysis can confirm that lattice content
387+
/// of lattice anchors around operation are necessarily identical, join
388+
/// them into the same equivalent class.
389+
virtual void buildOperationEquivalentLatticeAnchor(Operation *op) { return; }
390+
356391
/// Propagate the dense lattice backwards along the control flow edge from
357392
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
358393
/// values correspond to control flow branches originating at or targeting the
@@ -502,6 +537,15 @@ class DenseBackwardDataFlowAnalysis
502537
return getOrCreate<LatticeT>(anchor);
503538
}
504539

540+
/// Get the dense lattice on the given lattice anchor and add dependent as its
541+
/// dependency. That is, every time the lattice after anchor is updated, the
542+
/// dependent program point must be visited, and the newly triggered visit
543+
/// might update the lattice before dependent.
544+
virtual const AbstractDenseLattice *
545+
getLatticeFor(ProgramPoint *dependent, LatticeAnchor anchor) override {
546+
return getOrCreateFor<LatticeT>(dependent, anchor);
547+
}
548+
505549
/// Set the dense lattice at control flow exit point (after the terminator)
506550
/// and propagate an update if it changed.
507551
virtual void setToExitState(LatticeT *lattice) = 0;

mlir/include/mlir/Analysis/DataFlowFramework.h

Lines changed: 115 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "mlir/IR/Operation.h"
2020
#include "mlir/Support/StorageUniquer.h"
21+
#include "llvm/ADT/EquivalenceClasses.h"
2122
#include "llvm/ADT/Hashing.h"
2223
#include "llvm/ADT/SetVector.h"
2324
#include "llvm/Support/Compiler.h"
@@ -265,6 +266,14 @@ struct LatticeAnchor
265266
/// Forward declaration of the data-flow analysis class.
266267
class DataFlowAnalysis;
267268

269+
} // namespace mlir
270+
271+
template <>
272+
struct llvm::DenseMapInfo<mlir::LatticeAnchor>
273+
: public llvm::DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
274+
275+
namespace mlir {
276+
268277
//===----------------------------------------------------------------------===//
269278
// DataFlowConfig
270279
//===----------------------------------------------------------------------===//
@@ -332,7 +341,9 @@ class DataFlowSolver {
332341
/// does not exist.
333342
template <typename StateT, typename AnchorT>
334343
const StateT *lookupState(AnchorT anchor) const {
335-
const auto &mapIt = analysisStates.find(LatticeAnchor(anchor));
344+
LatticeAnchor latticeAnchor =
345+
getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
346+
const auto &mapIt = analysisStates.find(latticeAnchor);
336347
if (mapIt == analysisStates.end())
337348
return nullptr;
338349
auto it = mapIt->second.find(TypeID::get<StateT>());
@@ -344,12 +355,34 @@ class DataFlowSolver {
344355
/// Erase any analysis state associated with the given lattice anchor.
345356
template <typename AnchorT>
346357
void eraseState(AnchorT anchor) {
347-
LatticeAnchor la(anchor);
348-
analysisStates.erase(LatticeAnchor(anchor));
358+
LatticeAnchor latticeAnchor(anchor);
359+
360+
// Update equivalentAnchorMap.
361+
for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
362+
if (!eqClass.contains(latticeAnchor)) {
363+
continue;
364+
}
365+
llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
366+
eqClass.findLeader(latticeAnchor);
367+
368+
// Update analysis states with new leader if needed.
369+
if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
370+
analysisStates[*leaderIt][TypeId] =
371+
std::move(analysisStates[latticeAnchor][TypeId]);
372+
}
373+
374+
eqClass.erase(latticeAnchor);
375+
}
376+
377+
// Update analysis states.
378+
analysisStates.erase(latticeAnchor);
349379
}
350380

351-
// Erase all analysis states
352-
void eraseAllStates() { analysisStates.clear(); }
381+
/// Erase all analysis states.
382+
void eraseAllStates() {
383+
analysisStates.clear();
384+
equivalentAnchorMap.clear();
385+
}
353386

354387
/// Get a uniqued lattice anchor instance. If one is not present, it is
355388
/// created with the provided arguments.
@@ -399,6 +432,20 @@ class DataFlowSolver {
399432
template <typename StateT, typename AnchorT>
400433
StateT *getOrCreateState(AnchorT anchor);
401434

435+
/// Get leader lattice anchor in equivalence lattice anchor group, return
436+
/// input lattice anchor if input not found in equivalence lattice anchor
437+
/// group.
438+
template <typename StateT>
439+
LatticeAnchor getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const;
440+
441+
/// Union input anchors under the given state.
442+
template <typename StateT, typename AnchorT>
443+
void unionLatticeAnchors(AnchorT anchor, AnchorT other);
444+
445+
/// Return whether the given anchors are equivalent under the given state.
446+
template <typename StateT>
447+
bool isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const;
448+
402449
/// Propagate an update to an analysis state if it changed by pushing
403450
/// dependent work items to the back of the queue.
404451
/// This should only be used when DataFlowSolver is running.
@@ -429,10 +476,15 @@ class DataFlowSolver {
429476

430477
/// A type-erased map of lattice anchors to associated analysis states for
431478
/// first-class lattice anchors.
432-
DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>,
433-
DenseMapInfo<LatticeAnchor::ParentTy>>
479+
DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>>
434480
analysisStates;
435481

482+
/// A map of analysis state type to the equivalent lattice anchors.
483+
/// Lattice anchors are considered equivalent under a certain analysis state
484+
/// type if and only if, the analysis states pointed to by these lattice
485+
/// anchors necessarily contain identical values.
486+
DenseMap<TypeID, llvm::EquivalenceClasses<LatticeAnchor>> equivalentAnchorMap;
487+
436488
/// Allow the base child analysis class to access the internals of the solver.
437489
friend class DataFlowAnalysis;
438490
};
@@ -564,6 +616,14 @@ class DataFlowAnalysis {
564616
/// will provide a value for then.
565617
virtual LogicalResult visit(ProgramPoint *point) = 0;
566618

619+
/// Initialize lattice anchor equivalence class from the provided top-level
620+
/// operation.
621+
///
622+
/// This function will union lattice anchor to same equivalent class if the
623+
/// analysis can determine the lattice content of lattice anchor is
624+
/// necessarily identical under the corresponding lattice type.
625+
virtual void initializeEquivalentLatticeAnchor(Operation *top) { return; }
626+
567627
protected:
568628
/// Create a dependency between the given analysis state and lattice anchor
569629
/// on this analysis.
@@ -584,6 +644,12 @@ class DataFlowAnalysis {
584644
return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
585645
}
586646

647+
/// Union input anchors under the given state.
648+
template <typename StateT, typename AnchorT>
649+
void unionLatticeAnchors(AnchorT anchor, AnchorT other) {
650+
return solver.unionLatticeAnchors<StateT>(anchor, other);
651+
}
652+
587653
/// Get the analysis state associated with the lattice anchor. The returned
588654
/// state is expected to be "write-only", and any updates need to be
589655
/// propagated by `propagateIfChanged`.
@@ -598,7 +664,9 @@ class DataFlowAnalysis {
598664
template <typename StateT, typename AnchorT>
599665
const StateT *getOrCreateFor(ProgramPoint *dependent, AnchorT anchor) {
600666
StateT *state = getOrCreate<StateT>(anchor);
601-
addDependency(state, dependent);
667+
if (!solver.isEquivalent<StateT>(LatticeAnchor(anchor),
668+
LatticeAnchor(dependent)))
669+
addDependency(state, dependent);
602670
return state;
603671
}
604672

@@ -644,10 +712,29 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
644712
return static_cast<AnalysisT *>(childAnalyses.back().get());
645713
}
646714

715+
template <typename StateT>
716+
LatticeAnchor
717+
DataFlowSolver::getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const {
718+
if (!equivalentAnchorMap.contains(TypeID::get<StateT>())) {
719+
return latticeAnchor;
720+
}
721+
const llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
722+
equivalentAnchorMap.at(TypeID::get<StateT>());
723+
llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
724+
eqClass.findLeader(latticeAnchor);
725+
if (leaderIt != eqClass.member_end()) {
726+
return *leaderIt;
727+
}
728+
return latticeAnchor;
729+
}
730+
647731
template <typename StateT, typename AnchorT>
648732
StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
733+
// Replace with the leader anchor if found.
734+
LatticeAnchor latticeAnchor(anchor);
735+
latticeAnchor = getLeaderAnchorOrSelf<StateT>(latticeAnchor);
649736
std::unique_ptr<AnalysisState> &state =
650-
analysisStates[LatticeAnchor(anchor)][TypeID::get<StateT>()];
737+
analysisStates[latticeAnchor][TypeID::get<StateT>()];
651738
if (!state) {
652739
state = std::unique_ptr<StateT>(new StateT(anchor));
653740
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -657,6 +744,25 @@ StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
657744
return static_cast<StateT *>(state.get());
658745
}
659746

747+
template <typename StateT>
748+
bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
749+
if (!equivalentAnchorMap.contains(TypeID::get<StateT>())) {
750+
return false;
751+
}
752+
const llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
753+
equivalentAnchorMap.at(TypeID::get<StateT>());
754+
if (!eqClass.contains(lhs) || !eqClass.contains(rhs))
755+
return false;
756+
return eqClass.isEquivalent(lhs, rhs);
757+
}
758+
759+
template <typename StateT, typename AnchorT>
760+
void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
761+
llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
762+
equivalentAnchorMap[TypeID::get<StateT>()];
763+
eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
764+
}
765+
660766
inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
661767
state.print(os);
662768
return os;

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ using namespace mlir::dataflow;
2828
// AbstractDenseForwardDataFlowAnalysis
2929
//===----------------------------------------------------------------------===//
3030

31+
void AbstractDenseForwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
32+
Operation *top) {
33+
top->walk([&](Operation *op) {
34+
if (isa<RegionBranchOpInterface, CallOpInterface>(op))
35+
return;
36+
buildOperationEquivalentLatticeAnchor(op);
37+
});
38+
}
39+
3140
LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
3241
// Visit every operation and block.
3342
if (failed(processOperation(top)))
@@ -240,18 +249,19 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
240249
}
241250
}
242251

243-
const AbstractDenseLattice *
244-
AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
245-
LatticeAnchor anchor) {
246-
AbstractDenseLattice *state = getLattice(anchor);
247-
addDependency(state, dependent);
248-
return state;
249-
}
250-
251252
//===----------------------------------------------------------------------===//
252253
// AbstractDenseBackwardDataFlowAnalysis
253254
//===----------------------------------------------------------------------===//
254255

256+
void AbstractDenseBackwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
257+
Operation *top) {
258+
top->walk([&](Operation *op) {
259+
if (isa<RegionBranchOpInterface, CallOpInterface>(op))
260+
return;
261+
buildOperationEquivalentLatticeAnchor(op);
262+
});
263+
}
264+
255265
LogicalResult
256266
AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
257267
// Visit every operation and block.
@@ -455,11 +465,3 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
455465
before);
456466
}
457467
}
458-
459-
const AbstractDenseLattice *
460-
AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
461-
LatticeAnchor anchor) {
462-
AbstractDenseLattice *state = getLattice(anchor);
463-
addDependency(state, dependent);
464-
return state;
465-
}

mlir/lib/Analysis/DataFlowFramework.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
109109
isRunning = true;
110110
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
111111

112+
// Initialize equivalent lattice anchors.
113+
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
114+
analysis.initializeEquivalentLatticeAnchor(top);
115+
}
116+
112117
// Initialize the analyses.
113118
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
114119
DATAFLOW_DEBUG(llvm::dbgs()

0 commit comments

Comments
 (0)