Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ class PackedFunc {
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
* \param name the name of this packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
explicit PackedFunc(FType body, String name = "<anonymous>") : body_(body), name_(name) {}
/*!
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
Expand All @@ -126,6 +127,10 @@ class PackedFunc {
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return the internal body function */
inline FType body() const;
/*! \return the name of this function */
inline String name() const;
/*! \brief Set the name of the packed function. */
void set_name(String name) { name_ = name; }
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
Expand All @@ -134,6 +139,9 @@ class PackedFunc {
private:
/*! \brief internal container of packed function */
FType body_;

/*! \brief the name of this packed function */
String name_;
};

/*!
Expand Down Expand Up @@ -223,19 +231,19 @@ class TypedPackedFunc<R(Args...)> {
* \code
* auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(typed_lambda);
* TypedPackedFunc<int(int)> ftyped(typed_lambda, "add_one");
* // call the typed version.
* ICHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \param name the name of this function.
* \tparam FLambda the type of the lambda function.
*/
template <typename FLambda, typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>>::value>::type>
TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
template <typename FLambda, typename = typename std::enable_if<std::is_convertible<
FLambda, std::function<R(Args...)>>::value>::type>
TypedPackedFunc(const FLambda& typed_lambda, String name = "<anonymous>") { // NOLINT(*)
this->AssignTypedLambda(typed_lambda, name);
}
/*!
* \brief copy assignment operator from typed lambda
Expand All @@ -257,7 +265,7 @@ class TypedPackedFunc<R(Args...)> {
std::is_convertible<FLambda,
std::function<R(Args...)>>::value>::type>
TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
this->AssignTypedLambda(typed_lambda, "<anonymous>");
return *this;
}
/*!
Expand All @@ -284,6 +292,9 @@ class TypedPackedFunc<R(Args...)> {
* \return reference the internal PackedFunc
*/
const PackedFunc& packed() const { return packed_; }
String name() const { return packed_.name(); }
/*! \brief Set the name associated with the typed packed function. */
void set_name(String name) { packed_.set_name(name); }
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return packed_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
Expand All @@ -297,11 +308,12 @@ class TypedPackedFunc<R(Args...)> {
* \brief Assign the packed field using a typed lambda function.
*
* \param flambda The lambda function.
* \param name The name to associate with the lambda function.
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
template <typename FLambda>
inline void AssignTypedLambda(FLambda flambda);
inline void AssignTypedLambda(FLambda flambda, String name);
};

/*! \brief Arguments into TVM functions. */
Expand Down Expand Up @@ -991,6 +1003,8 @@ inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(

inline PackedFunc::FType PackedFunc::body() const { return body_; }

inline String PackedFunc::name() const { return name_; }

// internal namespace
inline const char* ArgTypeCode2Str(int type_code) {
switch (type_code) {
Expand Down Expand Up @@ -1205,7 +1219,10 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
TVMArgs tvm_args(values, type_codes, kNumArgs);
CHECK_EQ(tvm_args.size(), sizeof...(Args))
<< name_ << " expects " << sizeof...(Args) << ", but " << tvm_args.size() << " were provided";
body_(tvm_args, &rv);
return rv;
}

Expand Down Expand Up @@ -1302,10 +1319,15 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValue_&& value)

template <typename R, typename... Args>
template <typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda, String name) {
packed_ = PackedFunc([flambda, name](const TVMArgs& args, TVMRetValue* rv) {
if (args.size() != sizeof...(Args)) {
LOG(FATAL) << name << " expects " << sizeof...(Args) << " arguments, but " << args.size()
<< " were provided.";
}
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
packed_.set_name(name);
}

template <typename R, typename... Args>
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Registry {
* \param f The body of the function.
*/
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
return set_body(PackedFunc(f, name_));
}
/*!
* \brief set the body of the function to the given function.
Expand Down Expand Up @@ -93,7 +93,7 @@ class Registry {
template <typename FLambda>
Registry& set_body_typed(FLambda f) {
using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f)).packed());
return set_body(TypedPackedFunc<FType>(std::move(f), name_).packed());
}
/*!
* \brief set the body of the function to be the passed method pointer.
Expand Down Expand Up @@ -122,7 +122,7 @@ class Registry {
// call method pointer
return (target.*f)(params...);
};
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap, name_));
}

/*!
Expand Down Expand Up @@ -152,7 +152,7 @@ class Registry {
// call method pointer
return (target.*f)(params...);
};
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap, name_));
}

/*!
Expand Down Expand Up @@ -194,7 +194,7 @@ class Registry {
// call method pointer
return (target->*f)(params...);
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
}

/*!
Expand Down Expand Up @@ -236,7 +236,7 @@ class Registry {
// call method pointer
return (target->*f)(params...);
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap, name_));
}

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
}
return PackedFunc();
};
*ret = TypedPackedFunc<PackedFunc(std::string)>(f);
*ret = TypedPackedFunc<PackedFunc(std::string)>(f, "arith.CreateAnalyzer");
});

} // namespace arith
Expand Down
1 change: 1 addition & 0 deletions src/runtime/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct Registry::Manager {

Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
func_ = f;
func_.set_name(name_);
return *this;
}

Expand Down