Skip to content

Commit f942d34

Browse files
committed
Improve MLIR-Query by adding matcher combinators
Limit backward-slice with nested matching Add variadic operators Add test cases Add test cases for variadic matchers Relocate variadic matchers use signed for arithemtic & avoid copy Add verifier check for extract function Add slicing function extraction test; improve documentation; use lowercase for errors
1 parent 01f9dff commit f942d34

15 files changed

+471
-17
lines changed

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class MatcherDescriptor {
108108
const llvm::ArrayRef<ParserValue> args,
109109
Diagnostics *error) const = 0;
110110

111+
// If the matcher is variadic, it can take any number of arguments.
112+
virtual bool isVariadic() const = 0;
113+
111114
// Returns the number of arguments accepted by the matcher.
112115
virtual unsigned getNumArgs() const = 0;
113116

@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
140143
return marshaller(matcherFunc, matcherName, nameRange, args, error);
141144
}
142145

146+
bool isVariadic() const override { return false; }
147+
143148
unsigned getNumArgs() const override { return argKinds.size(); }
144149

145150
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
153158
const std::vector<ArgKind> argKinds;
154159
};
155160

161+
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
162+
public:
163+
using VarOp = DynMatcher::VariadicOperator;
164+
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165+
VarOp varOp, StringRef matcherName)
166+
: minCount(minCount), maxCount(maxCount), varOp(varOp),
167+
matcherName(matcherName) {}
168+
169+
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
170+
Diagnostics *error) const override {
171+
if (args.size() < minCount || maxCount < args.size()) {
172+
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
173+
{llvm::Twine("requires between "), llvm::Twine(minCount),
174+
llvm::Twine(" and "), llvm::Twine(maxCount),
175+
llvm::Twine(" args, got "), llvm::Twine(args.size())});
176+
return VariantMatcher();
177+
}
178+
179+
std::vector<VariantMatcher> innerArgs;
180+
for (int64_t i = 0, e = args.size(); i != e; ++i) {
181+
const ParserValue &arg = args[i];
182+
const VariantValue &value = arg.value;
183+
if (!value.isMatcher()) {
184+
addError(error, arg.range, ErrorType::RegistryWrongArgType,
185+
{llvm::Twine(i + 1), llvm::Twine("matcher: "),
186+
llvm::Twine(value.getTypeAsString())});
187+
return VariantMatcher();
188+
}
189+
innerArgs.push_back(value.getMatcher());
190+
}
191+
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192+
}
193+
194+
bool isVariadic() const override { return true; }
195+
196+
unsigned getNumArgs() const override { return 0; }
197+
198+
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199+
kinds.push_back(ArgKind(ArgKind::Matcher));
200+
}
201+
202+
private:
203+
const unsigned minCount;
204+
const unsigned maxCount;
205+
const VarOp varOp;
206+
const StringRef matcherName;
207+
};
208+
156209
// Helper function to check if argument count matches expected count
157210
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
158211
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
224277
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
225278
}
226279

280+
// Variadic operator overload.
281+
template <unsigned MinCount, unsigned MaxCount>
282+
std::unique_ptr<MatcherDescriptor>
283+
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
284+
StringRef matcherName) {
285+
return std::make_unique<VariadicOperatorMatcherDescriptor>(
286+
MinCount, MaxCount, func.varOp, matcherName);
287+
}
227288
} // namespace mlir::query::matcher::internal
228289

229290
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222
namespace mlir::query::matcher {
2323

24-
/// A class that provides utilities to find operations in the IR.
24+
/// Finds and collects matches from the IR. After construction
25+
/// `collectMatches` can be used to traverse the IR and apply
26+
/// matchers.
2527
class MatchFinder {
2628

2729
public:

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a method one of the
12-
// following methods: match(Operation *op), match(Operation *op,
13-
// SetVector<Operation *> &matchedOps)
11+
// Matchers are methods that return a Matcher which provides a
12+
// `match(...)` method whose parameters define the context of the match.
13+
// Support includes simple (unary) matchers as well as matcher combinators
14+
// (anyOf, allOf, etc.)
1415
//
15-
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1616
// This file contains the wrapper classes needed to construct matchers for
1717
// mlir-query.
1818
//
@@ -25,6 +25,15 @@
2525
#include "llvm/ADT/IntrusiveRefCntPtr.h"
2626

2727
namespace mlir::query::matcher {
28+
class DynMatcher;
29+
namespace internal {
30+
31+
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
32+
ArrayRef<DynMatcher> innerMatchers);
33+
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
34+
ArrayRef<DynMatcher> innerMatchers);
35+
36+
} // namespace internal
2837

2938
// Defaults to false if T has no match() method with the signature:
3039
// match(Operation* op).
@@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface {
8493
MatcherFn matcherFn;
8594
};
8695

96+
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
97+
// match the given operation.
98+
using VariadicOperatorFunction = bool (*)(Operation *op,
99+
SetVector<Operation *> *matchedOps,
100+
ArrayRef<DynMatcher> innerMatchers);
101+
102+
template <VariadicOperatorFunction Func>
103+
class VariadicMatcher : public MatcherInterface {
104+
public:
105+
VariadicMatcher(std::vector<DynMatcher> matchers)
106+
: matchers(std::move(matchers)) {}
107+
108+
bool match(Operation *op) override { return Func(op, nullptr, matchers); }
109+
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
110+
return Func(op, &matchedOps, matchers);
111+
}
112+
113+
private:
114+
std::vector<DynMatcher> matchers;
115+
};
116+
87117
// Matcher wraps a MatcherInterface implementation and provides match()
88118
// methods that redirect calls to the underlying implementation.
89119
class DynMatcher {
@@ -92,6 +122,31 @@ class DynMatcher {
92122
DynMatcher(MatcherInterface *implementation)
93123
: implementation(implementation) {}
94124

125+
// Construct from a variadic function.
126+
enum VariadicOperator {
127+
// Matches operations for which all provided matchers match.
128+
AllOf,
129+
// Matches operations for which at least one of the provided matchers
130+
// matches.
131+
AnyOf
132+
};
133+
134+
static std::unique_ptr<DynMatcher>
135+
constructVariadic(VariadicOperator Op,
136+
std::vector<DynMatcher> innerMatchers) {
137+
switch (Op) {
138+
case AllOf:
139+
return std::make_unique<DynMatcher>(
140+
new VariadicMatcher<internal::allOfVariadicOperator>(
141+
std::move(innerMatchers)));
142+
case AnyOf:
143+
return std::make_unique<DynMatcher>(
144+
new VariadicMatcher<internal::anyOfVariadicOperator>(
145+
std::move(innerMatchers)));
146+
}
147+
llvm_unreachable("Invalid Op value.");
148+
}
149+
95150
template <typename MatcherFn>
96151
static std::unique_ptr<DynMatcher>
97152
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +168,59 @@ class DynMatcher {
113168
std::string functionName;
114169
};
115170

171+
// VariadicOperatorMatcher related types.
172+
template <typename... Ps>
173+
class VariadicOperatorMatcher {
174+
public:
175+
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
176+
: varOp(varOp), params(std::forward<Ps>(params)...) {}
177+
178+
operator std::unique_ptr<DynMatcher>() const & {
179+
return DynMatcher::constructVariadic(
180+
varOp, getMatchers(std::index_sequence_for<Ps...>()));
181+
}
182+
183+
operator std::unique_ptr<DynMatcher>() && {
184+
return DynMatcher::constructVariadic(
185+
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
186+
}
187+
188+
private:
189+
// Helper method to unpack the tuple into a vector.
190+
template <std::size_t... Is>
191+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
192+
return {DynMatcher(std::get<Is>(params))...};
193+
}
194+
195+
template <std::size_t... Is>
196+
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
197+
return {DynMatcher(std::get<Is>(std::move(params)))...};
198+
}
199+
200+
const DynMatcher::VariadicOperator varOp;
201+
std::tuple<Ps...> params;
202+
};
203+
204+
// Overloaded function object to generate VariadicOperatorMatcher objects from
205+
// arbitrary matchers.
206+
template <unsigned MinCount, unsigned MaxCount>
207+
struct VariadicOperatorMatcherFunc {
208+
DynMatcher::VariadicOperator varOp;
209+
210+
template <typename... Ms>
211+
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
212+
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
213+
"invalid number of parameters for variadic matcher");
214+
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
215+
}
216+
};
217+
218+
namespace internal {
219+
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
220+
anyOf = {DynMatcher::AnyOf};
221+
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
222+
allOf = {DynMatcher::AllOf};
223+
} // namespace internal
116224
} // namespace mlir::query::matcher
117225

118226
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 99 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file provides matchers for MLIRQuery that peform slicing analysis
9+
// This file defines slicing-analysis matchers that extend and abstract the
10+
// core implementations from `SliceAnalysis.h`.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

@@ -16,9 +17,9 @@
1617
#include "mlir/Analysis/SliceAnalysis.h"
1718
#include "mlir/IR/Operation.h"
1819

19-
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
20-
/// Additionally, it limits the slice computation to a certain depth level using
21-
/// a custom filter.
20+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
21+
/// if `innerMatcher` matches. The traversal stops once the desired depth level
22+
/// is reached.
2223
///
2324
/// Example: starting from node 9, assuming the matcher
2425
/// computes the slice for the first two depth levels:
@@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
119120
: backwardSlice.size() >= 1;
120121
}
121122

123+
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
124+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
125+
template <typename BaseMatcher, typename Filter>
126+
class PredicateBackwardSliceMatcher {
127+
public:
128+
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
129+
bool inclusive, bool omitBlockArguments,
130+
bool omitUsesFromAbove)
131+
: innerMatcher(std::move(innerMatcher)),
132+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
133+
omitBlockArguments(omitBlockArguments),
134+
omitUsesFromAbove(omitUsesFromAbove) {}
135+
136+
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
137+
backwardSlice.clear();
138+
BackwardSliceOptions options;
139+
options.inclusive = inclusive;
140+
options.omitUsesFromAbove = omitUsesFromAbove;
141+
options.omitBlockArguments = omitBlockArguments;
142+
if (innerMatcher.match(rootOp)) {
143+
options.filter = [&](Operation *subOp) {
144+
return !filterMatcher.match(subOp);
145+
};
146+
LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
147+
assert(result.succeeded() && "expected backward slice to succeed");
148+
(void)result;
149+
return options.inclusive ? backwardSlice.size() > 1
150+
: backwardSlice.size() >= 1;
151+
}
152+
return false;
153+
}
154+
155+
private:
156+
BaseMatcher innerMatcher;
157+
Filter filterMatcher;
158+
bool inclusive;
159+
bool omitBlockArguments;
160+
bool omitUsesFromAbove;
161+
};
162+
163+
/// Computes the forward-slice of all users reachable from `rootOp`,
164+
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
165+
template <typename BaseMatcher, typename Filter>
166+
class PredicateForwardSliceMatcher {
167+
public:
168+
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
169+
bool inclusive)
170+
: innerMatcher(std::move(innerMatcher)),
171+
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
172+
173+
bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
174+
forwardSlice.clear();
175+
ForwardSliceOptions options;
176+
options.inclusive = inclusive;
177+
if (innerMatcher.match(rootOp)) {
178+
options.filter = [&](Operation *subOp) {
179+
return !filterMatcher.match(subOp);
180+
};
181+
getForwardSlice(rootOp, &forwardSlice, options);
182+
return options.inclusive ? forwardSlice.size() > 1
183+
: forwardSlice.size() >= 1;
184+
}
185+
return false;
186+
}
187+
188+
private:
189+
BaseMatcher innerMatcher;
190+
Filter filterMatcher;
191+
bool inclusive;
192+
};
193+
122194
/// Matches transitive defs of a top-level operation up to N levels.
123195
template <typename Matcher>
124196
inline BackwardSliceMatcher<Matcher>
@@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130202
omitUsesFromAbove);
131203
}
132204

133-
/// Matches all transitive defs of a top-level operation up to N levels
205+
/// Matches all transitive defs of a top-level operation up to N levels.
134206
template <typename Matcher>
135207
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
136208
int64_t maxDepth) {
@@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
139211
false, false);
140212
}
141213

214+
/// Matches all transitive defs of a top-level operation and stops where
215+
/// `filterMatcher` rejects.
216+
template <typename BaseMatcher, typename Filter>
217+
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
218+
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
219+
bool inclusive, bool omitBlockArguments,
220+
bool omitUsesFromAbove) {
221+
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
222+
std::move(innerMatcher), std::move(filterMatcher), inclusive,
223+
omitBlockArguments, omitUsesFromAbove);
224+
}
225+
226+
/// Matches all users of a top-level operation and stops where
227+
/// `filterMatcher` rejects.
228+
template <typename BaseMatcher, typename Filter>
229+
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
230+
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
231+
bool inclusive) {
232+
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
233+
std::move(innerMatcher), std::move(filterMatcher), inclusive);
234+
}
235+
142236
} // namespace mlir::query::matcher
143237

144238
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

0 commit comments

Comments
 (0)