From 5c01095029e21a4a0468116f5bf46968b69d8f3e Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 15 Sep 2021 19:53:00 +0800 Subject: [PATCH 1/2] support match pvar with dtype check --- src/arith/pattern_match.h | 69 ++++++++++++++++++++++++++++++++- tests/cpp/pattern_match_test.cc | 22 +++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 01baaa8d13a2..f17233d9ed64 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -210,6 +210,73 @@ class PVar : public Pattern> { mutable bool filled_{false}; }; +/*! + * \brief Wrapper for pattern variable container with extra match logic. + * + * \tparam Derived the type of derived class. + * \tparam T the type of the hole. + */ +template +class PVarWithCheck : public arith::Pattern> { + public: + // Store by reference in the expression. + using Nested = const PVarWithCheck&; + + void InitMatch_() const { pvar_.InitMatch_(); } + + bool Match_(const T& value) const { + if (!static_cast(this)->Match_(value)) return false; + return pvar_.Match_(value); + } + + template ::value>::type> + bool Match_(const NodeRefType& value) const { + if (const auto* ptr = value.template as()) { + return Match_(GetRef(ptr)); + } else { + return false; + } + } + + T Eval() const { return pvar_.Eval(); } + + protected: + arith::PVar pvar_; +}; + +/*! + * \brief Pattern variable container with expr type check. + * + * \tparam T the type of the hole. + * \tparam DType the Pattern type of dtype. + */ +template ::value>> +class PVarWithType : public PVarWithCheck, T> { + public: + explicit PVarWithType(const DType& dtype) : dtype_(dtype) {} + + bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } + + protected: + typename DType::Nested dtype_; +}; + +/*! + * \brief Pattern variable container for data type with lanes. + */ +class PVecType : public PVarWithCheck { + public: + /*! \brief construct vector dtype placeholder with element type check */ + explicit PVecType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} + + bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } + + protected: + DataType elem_dtype_; +}; + /*! * \brief Constant Pattern variable container. * @@ -467,7 +534,7 @@ class PCastExpr : public Pattern> { /*! * \brief Construct a cast pattern. * - * \param dtype The target data type, can be PVar or PConst. + * \param dtype The target data type, can be PVar or PConst. * \param value The input type. * * \return The result pattern. diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 4194c760628a..9484a4f9cba4 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -138,3 +138,25 @@ TEST(Pattern, IntImm) { // cannot match tx + 1 to v ICHECK(!(v * c).Match((tx + 1) * 3)); } + +TEST(Pattern, MatchWithType) { + using namespace tvm; + // match expr with specified dtype + arith::PVarWithType> pat(DataType::Float(32)); + tir::Var x("x", DataType::Float(32)); + tir::Var y("y", DataType::Float(32)); + tir::Var x_int("x", DataType::Int(32)); + tir::Var y_int("y", DataType::Int(32)); + ICHECK(pat.Match(x + y * 2.0f)); + ICHECK(!pat.Match(x_int + y_int * 2)); + + // match vectorized expr with specified element dtype + arith::PVecType vec_ty(DataType::Float(32)); + arith::PVarWithType vpat(vec_ty); + tir::Var vx = tir::Var("x", DataType::Float(32, 8)); + tir::Var vy("y", DataType::Float(32, 8)); + tir::Var vx_int("x", DataType::Int(32, 8)); + tir::Var vy_int("y", DataType::Int(32, 8)); + ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8))); + ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8))); +} From dbc18120ad9aad5532bedcf4bebd081350a96a41 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 18 Sep 2021 12:13:28 +0800 Subject: [PATCH 2/2] fix rename Type -> DataType in pvarwithXXX classes --- src/arith/pattern_match.h | 8 ++++---- tests/cpp/pattern_match_test.cc | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index f17233d9ed64..7d1f315b3cb3 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -253,9 +253,9 @@ class PVarWithCheck : public arith::Pattern> { */ template ::value>> -class PVarWithType : public PVarWithCheck, T> { +class PVarWithDataType : public PVarWithCheck, T> { public: - explicit PVarWithType(const DType& dtype) : dtype_(dtype) {} + explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {} bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } @@ -266,10 +266,10 @@ class PVarWithType : public PVarWithCheck, T> { /*! * \brief Pattern variable container for data type with lanes. */ -class PVecType : public PVarWithCheck { +class PVecDataType : public PVarWithCheck { public: /*! \brief construct vector dtype placeholder with element type check */ - explicit PVecType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} + explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {} bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); } diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 9484a4f9cba4..2e386c48b75c 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -142,7 +142,7 @@ TEST(Pattern, IntImm) { TEST(Pattern, MatchWithType) { using namespace tvm; // match expr with specified dtype - arith::PVarWithType> pat(DataType::Float(32)); + arith::PVarWithDataType> pat(DataType::Float(32)); tir::Var x("x", DataType::Float(32)); tir::Var y("y", DataType::Float(32)); tir::Var x_int("x", DataType::Int(32)); @@ -151,8 +151,8 @@ TEST(Pattern, MatchWithType) { ICHECK(!pat.Match(x_int + y_int * 2)); // match vectorized expr with specified element dtype - arith::PVecType vec_ty(DataType::Float(32)); - arith::PVarWithType vpat(vec_ty); + arith::PVecDataType vec_ty(DataType::Float(32)); + arith::PVarWithDataType vpat(vec_ty); tir::Var vx = tir::Var("x", DataType::Float(32, 8)); tir::Var vy("y", DataType::Float(32, 8)); tir::Var vx_int("x", DataType::Int(32, 8));