Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support match pvar with dtype constraint #9016

Merged
merged 2 commits into from
Sep 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,73 @@ class PVar : public Pattern<PVar<T>> {
mutable bool filled_{false};
};

/*!
* \brief Wrapper for pattern variable container with extra match logic.
*
* \tparam Derived the type of derived class.
* \tparam T the type of the hole.
*/
template <typename Derived, typename T>
class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> {
public:
// Store by reference in the expression.
using Nested = const PVarWithCheck<Derived, T>&;

void InitMatch_() const { pvar_.InitMatch_(); }

bool Match_(const T& value) const {
if (!static_cast<const Derived*>(this)->Match_(value)) return false;
return pvar_.Match_(value);
}

template <typename NodeRefType,
typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
bool Match_(const NodeRefType& value) const {
if (const auto* ptr = value.template as<typename T::ContainerType>()) {
return Match_(GetRef<T>(ptr));
} else {
return false;
}
}

T Eval() const { return pvar_.Eval(); }

protected:
arith::PVar<T> pvar_;
};

/*!
* \brief Pattern variable container with expr type check.
*
* \tparam T the type of the hole.
* \tparam DType the Pattern type of dtype.
*/
template <typename T, typename DType,
typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>>
class PVarWithDataType : public PVarWithCheck<PVarWithDataType<T, DType>, T> {
public:
explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {}

bool Match_(const T& value) const { return dtype_.Match_(value->dtype); }

protected:
typename DType::Nested dtype_;
};

/*!
* \brief Pattern variable container for data type with lanes.
*/
class PVecDataType : public PVarWithCheck<PVecDataType, DataType> {
public:
/*! \brief construct vector dtype placeholder with element type check */
explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {}

bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); }

protected:
DataType elem_dtype_;
};

/*!
* \brief Constant Pattern variable container.
*
Expand Down Expand Up @@ -467,7 +534,7 @@ class PCastExpr : public Pattern<PCastExpr<DType, TA>> {
/*!
* \brief Construct a cast pattern.
*
* \param dtype The target data type, can be PVar<Type> or PConst<Type>.
* \param dtype The target data type, can be PVar<DataType> or PConst<DataType>.
* \param value The input type.
*
* \return The result pattern.
Expand Down
22 changes: 22 additions & 0 deletions tests/cpp/pattern_match_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,25 @@ TEST(Pattern, IntImm) {
// cannot match tx + 1 to v
ICHECK(!(v * c).Match((tx + 1) * 3));
}

TEST(Pattern, MatchWithType) {
using namespace tvm;
// match expr with specified dtype
arith::PVarWithDataType<PrimExpr, arith::PConst<DataType>> pat(DataType::Float(32));
tir::Var x("x", DataType::Float(32));
tir::Var y("y", DataType::Float(32));
tir::Var x_int("x", DataType::Int(32));
tir::Var y_int("y", DataType::Int(32));
ICHECK(pat.Match(x + y * 2.0f));
ICHECK(!pat.Match(x_int + y_int * 2));

// match vectorized expr with specified element dtype
arith::PVecDataType vec_ty(DataType::Float(32));
arith::PVarWithDataType<PrimExpr, arith::PVecDataType> vpat(vec_ty);
tir::Var vx = tir::Var("x", DataType::Float(32, 8));
tir::Var vy("y", DataType::Float(32, 8));
tir::Var vx_int("x", DataType::Int(32, 8));
tir::Var vy_int("y", DataType::Int(32, 8));
ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8)));
ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8)));
}