Skip to content

[mlir] Improve mlir-query by adding matcher combinators #141423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions mlir/include/mlir/Query/Matcher/Marshallers.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class MatcherDescriptor {
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;

// If the matcher is variadic, it can take any number of arguments.
virtual bool isVariadic() const = 0;

// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;

Expand Down Expand Up @@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}

bool isVariadic() const override { return false; }

unsigned getNumArgs() const override { return argKinds.size(); }

void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
Expand All @@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
const std::vector<ArgKind> argKinds;
};

class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
public:
using VarOp = DynMatcher::VariadicOperator;
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
VarOp varOp, StringRef matcherName)
: minCount(minCount), maxCount(maxCount), varOp(varOp),
matcherName(matcherName) {}

VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error) const override {
if (args.size() < minCount || maxCount < args.size()) {
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
{llvm::Twine("requires between "), llvm::Twine(minCount),
llvm::Twine(" and "), llvm::Twine(maxCount),
llvm::Twine(" args, got "), llvm::Twine(args.size())});
return VariantMatcher();
}

std::vector<VariantMatcher> innerArgs;
for (int64_t i = 0, e = args.size(); i != e; ++i) {
const ParserValue &arg = args[i];
const VariantValue &value = arg.value;
if (!value.isMatcher()) {
addError(error, arg.range, ErrorType::RegistryWrongArgType,
{llvm::Twine(i + 1), llvm::Twine("matcher: "),
llvm::Twine(value.getTypeAsString())});
return VariantMatcher();
}
innerArgs.push_back(value.getMatcher());
}
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
}

bool isVariadic() const override { return true; }

unsigned getNumArgs() const override { return 0; }

void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
kinds.push_back(ArgKind(ArgKind::Matcher));
}

private:
const unsigned minCount;
const unsigned maxCount;
const VarOp varOp;
const StringRef matcherName;
};

// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
Expand Down Expand Up @@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}

// Variadic operator overload.
template <unsigned MinCount, unsigned MaxCount>
std::unique_ptr<MatcherDescriptor>
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
StringRef matcherName) {
return std::make_unique<VariadicOperatorMatcherDescriptor>(
MinCount, MaxCount, func.varOp, matcherName);
}
} // namespace mlir::query::matcher::internal

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
4 changes: 3 additions & 1 deletion mlir/include/mlir/Query/Matcher/MatchFinder.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

namespace mlir::query::matcher {

/// A class that provides utilities to find operations in the IR.
/// Finds and collects matches from the IR. After construction
/// `collectMatches` can be used to traverse the IR and apply
/// matchers.
class MatchFinder {

public:
Expand Down
116 changes: 112 additions & 4 deletions mlir/include/mlir/Query/Matcher/MatchersInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
// Matchers are methods that return a Matcher which provides a method one of the
// following methods: match(Operation *op), match(Operation *op,
// SetVector<Operation *> &matchedOps)
// Matchers are methods that return a Matcher which provides a
// `match(...)` method whose parameters define the context of the match.
// Support includes simple (unary) matchers as well as matcher combinators
// (anyOf, allOf, etc.)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
Expand All @@ -25,6 +25,15 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"

namespace mlir::query::matcher {
class DynMatcher;
namespace internal {

bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers);
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers);

} // namespace internal

// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
Expand Down Expand Up @@ -84,6 +93,27 @@ class MatcherFnImpl : public MatcherInterface {
MatcherFn matcherFn;
};

// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
// match the given operation.
using VariadicOperatorFunction = bool (*)(Operation *op,
SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers);

template <VariadicOperatorFunction Func>
class VariadicMatcher : public MatcherInterface {
public:
VariadicMatcher(std::vector<DynMatcher> matchers)
: matchers(std::move(matchers)) {}

bool match(Operation *op) override { return Func(op, nullptr, matchers); }
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
return Func(op, &matchedOps, matchers);
}

private:
std::vector<DynMatcher> matchers;
};

// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
Expand All @@ -92,6 +122,31 @@ class DynMatcher {
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}

// Construct from a variadic function.
enum VariadicOperator {
// Matches operations for which all provided matchers match.
AllOf,
// Matches operations for which at least one of the provided matchers
// matches.
AnyOf
};

static std::unique_ptr<DynMatcher>
constructVariadic(VariadicOperator Op,
std::vector<DynMatcher> innerMatchers) {
switch (Op) {
case AllOf:
return std::make_unique<DynMatcher>(
new VariadicMatcher<internal::allOfVariadicOperator>(
std::move(innerMatchers)));
case AnyOf:
return std::make_unique<DynMatcher>(
new VariadicMatcher<internal::anyOfVariadicOperator>(
std::move(innerMatchers)));
}
llvm_unreachable("Invalid Op value.");
}

template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
Expand All @@ -113,6 +168,59 @@ class DynMatcher {
std::string functionName;
};

// VariadicOperatorMatcher related types.
template <typename... Ps>
class VariadicOperatorMatcher {
public:
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
: varOp(varOp), params(std::forward<Ps>(params)...) {}

operator std::unique_ptr<DynMatcher>() const & {
return DynMatcher::constructVariadic(
varOp, getMatchers(std::index_sequence_for<Ps...>()));
}

operator std::unique_ptr<DynMatcher>() && {
return DynMatcher::constructVariadic(
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
}

private:
// Helper method to unpack the tuple into a vector.
template <std::size_t... Is>
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
return {DynMatcher(std::get<Is>(params))...};
}

template <std::size_t... Is>
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
return {DynMatcher(std::get<Is>(std::move(params)))...};
}

const DynMatcher::VariadicOperator varOp;
std::tuple<Ps...> params;
};

// Overloaded function object to generate VariadicOperatorMatcher objects from
// arbitrary matchers.
template <unsigned MinCount, unsigned MaxCount>
struct VariadicOperatorMatcherFunc {
DynMatcher::VariadicOperator varOp;

template <typename... Ms>
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
"invalid number of parameters for variadic matcher");
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
}
};

namespace internal {
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
anyOf = {DynMatcher::AnyOf};
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
allOf = {DynMatcher::AllOf};
} // namespace internal
} // namespace mlir::query::matcher

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
104 changes: 99 additions & 5 deletions mlir/include/mlir/Query/Matcher/SliceMatchers.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
// This file provides matchers for MLIRQuery that peform slicing analysis
// This file defines slicing-analysis matchers that extend and abstract the
// core implementations from `SliceAnalysis.h`.
//
//===----------------------------------------------------------------------===//

Expand All @@ -16,9 +17,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"

/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
/// Additionally, it limits the slice computation to a certain depth level using
/// a custom filter.
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
/// if `innerMatcher` matches. The traversal stops once the desired depth level
/// is reached.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
Expand Down Expand Up @@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}

/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
template <typename BaseMatcher, typename Filter>
class PredicateBackwardSliceMatcher {
public:
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive, bool omitBlockArguments,
bool omitUsesFromAbove)
: innerMatcher(std::move(innerMatcher)),
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
omitBlockArguments(omitBlockArguments),
omitUsesFromAbove(omitUsesFromAbove) {}

bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
backwardSlice.clear();
BackwardSliceOptions options;
options.inclusive = inclusive;
options.omitUsesFromAbove = omitUsesFromAbove;
options.omitBlockArguments = omitBlockArguments;
if (innerMatcher.match(rootOp)) {
options.filter = [&](Operation *subOp) {
return !filterMatcher.match(subOp);
};
LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
assert(result.succeeded() && "expected backward slice to succeed");
(void)result;
return options.inclusive ? backwardSlice.size() > 1
: backwardSlice.size() >= 1;
}
return false;
}

private:
BaseMatcher innerMatcher;
Filter filterMatcher;
bool inclusive;
bool omitBlockArguments;
bool omitUsesFromAbove;
};

/// Computes the forward-slice of all users reachable from `rootOp`,
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
template <typename BaseMatcher, typename Filter>
class PredicateForwardSliceMatcher {
public:
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive)
: innerMatcher(std::move(innerMatcher)),
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}

bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
forwardSlice.clear();
ForwardSliceOptions options;
options.inclusive = inclusive;
if (innerMatcher.match(rootOp)) {
options.filter = [&](Operation *subOp) {
return !filterMatcher.match(subOp);
};
getForwardSlice(rootOp, &forwardSlice, options);
return options.inclusive ? forwardSlice.size() > 1
: forwardSlice.size() >= 1;
}
return false;
}

private:
BaseMatcher innerMatcher;
Filter filterMatcher;
bool inclusive;
};

/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
Expand All @@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}

/// Matches all transitive defs of a top-level operation up to N levels
/// Matches all transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
Expand All @@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}

/// Matches all transitive defs of a top-level operation and stops where
/// `filterMatcher` rejects.
template <typename BaseMatcher, typename Filter>
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive, bool omitBlockArguments,
bool omitUsesFromAbove) {
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
std::move(innerMatcher), std::move(filterMatcher), inclusive,
omitBlockArguments, omitUsesFromAbove);
}

/// Matches all users of a top-level operation and stops where
/// `filterMatcher` rejects.
template <typename BaseMatcher, typename Filter>
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive) {
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
std::move(innerMatcher), std::move(filterMatcher), inclusive);
}

} // namespace mlir::query::matcher

#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
Loading
Loading