Skip to content

Commit a32846b

Browse files
committed
Refactor the architecture of mlir-reduce
Add iterator for ReductionNode traversal and use range to indicate the region we would like to keep. Refactor the interaction between Pass/Tester/ReductionNode. Now it'll be easier to add new traversal type and OpReducer Reviewed By: jpienaar, rriddle Differential Revision: https://reviews.llvm.org/D99713
1 parent 72142b9 commit a32846b

File tree

13 files changed

+408
-583
lines changed

13 files changed

+408
-583
lines changed

mlir/include/mlir/Reducer/Passes/OpReducer.h

Lines changed: 30 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,65 +15,52 @@
1515
#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
1616
#define MLIR_REDUCER_PASSES_OPREDUCER_H
1717

18-
#include "mlir/IR/Region.h"
18+
#include <limits>
19+
1920
#include "mlir/Reducer/ReductionNode.h"
20-
#include "mlir/Reducer/ReductionTreeUtils.h"
2121
#include "mlir/Reducer/Tester.h"
2222

2323
namespace mlir {
2424

25-
class OpReducerImpl {
25+
class OpReducer {
2626
public:
27-
OpReducerImpl(
28-
llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps);
29-
30-
/// Return the name of this reducer class.
31-
StringRef getName();
32-
33-
/// Return the initial transformSpace containing the transformable indices.
34-
std::vector<bool> initTransformSpace(ModuleOp module);
35-
36-
/// Generate variants by removing OpType operations from the module in the
37-
/// parent and link the variants as childs in the Reduction Tree Pass.
38-
void generateVariants(ReductionNode *parent, const Tester &test,
39-
int numVariants);
40-
41-
/// Generate variants by removing OpType operations from the module in the
42-
/// parent and link the variants as childs in the Reduction Tree Pass. The
43-
/// transform argument defines the function used to remove the OpTpye
44-
/// operations in range of indexed OpType operations.
45-
void generateVariants(ReductionNode *parent, const Tester &test,
46-
int numVariants,
47-
llvm::function_ref<void(ModuleOp, int, int)> transform);
48-
49-
private:
50-
llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps;
27+
virtual ~OpReducer() = default;
28+
/// According to rangeToKeep, try to reduce the given module. We implicitly
29+
/// number each interesting operation and rangeToKeep indicates that if an
30+
/// operation's number falls into certain range, then we will not try to
31+
/// reduce that operation.
32+
virtual void reduce(ModuleOp module,
33+
ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
34+
/// Return the number of certain kind of operations that we would like to
35+
/// reduce. This can be used to build a range map to exclude uninterested
36+
/// operations.
37+
virtual int getNumTargetOps(ModuleOp module) const = 0;
5138
};
5239

53-
/// The OpReducer class defines a variant generator method that produces
54-
/// multiple variants by eliminating different OpType operations from the
55-
/// parent module.
40+
/// Reducer is a helper class to remove potential uninteresting operations from
41+
/// module.
5642
template <typename OpType>
57-
class OpReducer {
43+
class Reducer : public OpReducer {
5844
public:
59-
OpReducer() : impl(new OpReducerImpl(getSpecificOps)) {}
45+
~Reducer() override = default;
6046

61-
/// Returns the vector of pointer to the OpType operations in the module.
62-
static std::vector<Operation *> getSpecificOps(ModuleOp module) {
63-
std::vector<Operation *> ops;
64-
for (auto op : module.getOps<OpType>()) {
65-
ops.push_back(op);
66-
}
67-
return ops;
47+
int getNumTargetOps(ModuleOp module) const override {
48+
return std::distance(module.getOps<OpType>().begin(),
49+
module.getOps<OpType>().end());
6850
}
6951

70-
/// Deletes the OpType operations in the module in the specified index.
71-
static void deleteOps(ModuleOp module, int start, int end) {
52+
void reduce(ModuleOp module,
53+
ArrayRef<ReductionNode::Range> rangeToKeep) override {
7254
std::vector<Operation *> opsToRemove;
55+
size_t keepIndex = 0;
7356

74-
for (auto op : enumerate(getSpecificOps(module))) {
57+
for (auto op : enumerate(module.getOps<OpType>())) {
7558
int index = op.index();
76-
if (index >= start && index < end)
59+
if (keepIndex < rangeToKeep.size() &&
60+
index == rangeToKeep[keepIndex].second)
61+
++keepIndex;
62+
if (keepIndex == rangeToKeep.size() ||
63+
index < rangeToKeep[keepIndex].first)
7764
opsToRemove.push_back(op.value());
7865
}
7966

@@ -82,24 +69,6 @@ class OpReducer {
8269
o->erase();
8370
}
8471
}
85-
86-
/// Return the name of this reducer class.
87-
StringRef getName() { return impl->getName(); }
88-
89-
/// Return the initial transformSpace containing the transformable indices.
90-
std::vector<bool> initTransformSpace(ModuleOp module) {
91-
return impl->initTransformSpace(module);
92-
}
93-
94-
/// Generate variants by removing OpType operations from the module in the
95-
/// parent and link the variants as childs in the Reduction Tree Pass.
96-
void generateVariants(ReductionNode *parent, const Tester &test,
97-
int numVariants) {
98-
impl->generateVariants(parent, test, numVariants, deleteOps);
99-
}
100-
101-
private:
102-
std::unique_ptr<OpReducerImpl> impl;
10372
};
10473

10574
} // end namespace mlir

mlir/include/mlir/Reducer/ReductionNode.h

Lines changed: 101 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,82 +17,129 @@
1717
#ifndef MLIR_REDUCER_REDUCTIONNODE_H
1818
#define MLIR_REDUCER_REDUCTIONNODE_H
1919

20+
#include <queue>
2021
#include <vector>
2122

2223
#include "mlir/Reducer/Tester.h"
24+
#include "llvm/Support/Allocator.h"
2325
#include "llvm/Support/ToolOutputFile.h"
2426

2527
namespace mlir {
2628

27-
/// This class defines the ReductionNode which is used to wrap the module of
28-
/// a generated variant and keep track of the necessary metadata for the
29-
/// reduction pass. The nodes are linked together in a reduction tree structure
30-
/// which defines the relationship between all the different generated variants.
29+
/// Defines the traversal method options to be used in the reduction tree
30+
/// traversal.
31+
enum TraversalMode { SinglePath, Backtrack, MultiPath };
32+
33+
/// This class defines the ReductionNode which is used to generate variant and
34+
/// keep track of the necessary metadata for the reduction pass. The nodes are
35+
/// linked together in a reduction tree structure which defines the relationship
36+
/// between all the different generated variants.
3137
class ReductionNode {
3238
public:
33-
ReductionNode(ModuleOp module, ReductionNode *parent);
34-
35-
ReductionNode(ModuleOp module, ReductionNode *parent,
36-
std::vector<bool> transformSpace);
39+
template <TraversalMode mode>
40+
class iterator;
3741

38-
/// Calculates and initializes the size and interesting values of the node.
39-
void measureAndTest(const Tester &test);
42+
using Range = std::pair<int, int>;
4043

41-
/// Returns the module.
42-
ModuleOp getModule() const { return module; }
44+
ReductionNode(ReductionNode *parent, std::vector<Range> range,
45+
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
4346

44-
/// Returns true if the size and interestingness have been calculated.
45-
bool isEvaluated() const;
47+
ReductionNode *getParent() const;
4648

47-
/// Returns the size in bytes of the module.
48-
int getSize() const;
49+
size_t getSize() const;
4950

5051
/// Returns true if the module exhibits the interesting behavior.
51-
bool isInteresting() const;
52-
53-
/// Returns the pointer to a child variant by index.
54-
ReductionNode *getVariant(unsigned long index) const;
52+
Tester::Interestingness isInteresting() const;
5553

56-
/// Returns the number of child variants.
57-
int variantsSize() const;
54+
std::vector<Range> getRanges() const;
5855

59-
/// Returns true if the vector containing the child variants is empty.
60-
bool variantsEmpty() const;
56+
std::vector<ReductionNode *> &getVariants();
6157

62-
/// Sort the child variants and remove the uninteresting ones.
63-
void organizeVariants(const Tester &test);
58+
/// Split the ranges and generate new variants.
59+
std::vector<ReductionNode *> generateNewVariants();
6460

65-
/// Returns the number of child variants.
66-
int transformSpaceSize();
67-
68-
/// Returns a vector indicating the transformed indices as true.
69-
const std::vector<bool> getTransformSpace();
61+
/// Update the interestingness result from tester.
62+
void update(std::pair<Tester::Interestingness, size_t> result);
7063

7164
private:
72-
/// Link a child variant node.
73-
void linkVariant(ReductionNode *newVariant);
74-
75-
// This is the MLIR module of this variant.
76-
ModuleOp module;
77-
78-
// This is true if the module has been evaluated and it exhibits the
79-
// interesting behavior.
80-
bool interesting;
81-
82-
// This indicates the number of characters in the printed module if the module
83-
// has been evaluated.
84-
int size;
85-
86-
// This indicates if the module has been evaluated (measured and tested).
87-
bool evaluated;
88-
89-
// Indicates the indices in the node that have been transformed in previous
90-
// levels of the reduction tree.
91-
std::vector<bool> transformSpace;
65+
/// A custom BFS iterator. The difference between
66+
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
67+
/// We may explore more neighbors at certain node if we didn't find interested
68+
/// event. As a result, we defer pushing adjacent nodes until poping the last
69+
/// visited node. The graph exploration strategy will be put in
70+
/// getNeighbors().
71+
///
72+
/// Subclass BaseIterator and implement traversal strategy in getNeighbors().
73+
template <typename T>
74+
class BaseIterator {
75+
public:
76+
BaseIterator(ReductionNode *node) { visitQueue.push(node); }
77+
BaseIterator(const BaseIterator &) = default;
78+
BaseIterator() = default;
79+
80+
static BaseIterator end() { return BaseIterator(); }
81+
82+
bool operator==(const BaseIterator &i) {
83+
return visitQueue == i.visitQueue;
84+
}
85+
bool operator!=(const BaseIterator &i) { return !(*this == i); }
86+
87+
BaseIterator &operator++() {
88+
ReductionNode *top = visitQueue.front();
89+
visitQueue.pop();
90+
std::vector<ReductionNode *> neighbors = getNeighbors(top);
91+
for (ReductionNode *node : neighbors)
92+
visitQueue.push(node);
93+
return *this;
94+
}
95+
96+
BaseIterator operator++(int) {
97+
BaseIterator tmp = *this;
98+
++*this;
99+
return tmp;
100+
}
101+
102+
ReductionNode &operator*() const { return *(visitQueue.front()); }
103+
ReductionNode *operator->() const { return visitQueue.front(); }
104+
105+
protected:
106+
std::vector<ReductionNode *> getNeighbors(ReductionNode *node) {
107+
return static_cast<T *>(this)->getNeighbors(node);
108+
}
109+
110+
private:
111+
std::queue<ReductionNode *> visitQueue;
112+
};
113+
114+
/// The size of module after applying the range constraints.
115+
size_t size;
116+
117+
/// This is true if the module has been evaluated and it exhibits the
118+
/// interesting behavior.
119+
Tester::Interestingness interesting;
120+
121+
ReductionNode *parent;
122+
123+
/// We will only keep the operation with index falls into the ranges.
124+
/// For example, number each function in a certain module and then we will
125+
/// remove the functions with index outside the ranges and see if the
126+
/// resulting module is still interesting.
127+
std::vector<Range> ranges;
128+
129+
/// This points to the child variants that were created using this node as a
130+
/// starting point.
131+
std::vector<ReductionNode *> variants;
132+
133+
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator;
134+
};
92135

93-
// This points to the child variants that were created using this node as a
94-
// starting point.
95-
std::vector<std::unique_ptr<ReductionNode>> variants;
136+
// Specialized iterator for SinglePath traversal
137+
template <>
138+
class ReductionNode::iterator<SinglePath>
139+
: public BaseIterator<iterator<SinglePath>> {
140+
friend BaseIterator<iterator<SinglePath>>;
141+
using BaseIterator::BaseIterator;
142+
std::vector<ReductionNode *> getNeighbors(ReductionNode *node);
96143
};
97144

98145
} // end namespace mlir

0 commit comments

Comments
 (0)