Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API/JIT] Enable registrable global function, introduce StackVM #25

Merged
merged 1 commit into from
Jan 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "./base.h"
#include "./expr.h"
#include "./module.h"
#include "./runtime/runtime.h"
#include "./runtime/packed_func.h"


namespace tvm {
Expand Down
8 changes: 5 additions & 3 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* bool tvm_print(VType value) {
* LOG(INFO) << value;
* int tvm_call_global(name, TVMValue* args) {
* PackedFunc f = PackedFunc::GetGlobal(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_print = "tvm_print";
constexpr const char* tvm_call_global = "tvm_call_global";

/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
Expand Down
3 changes: 0 additions & 3 deletions include/tvm/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ namespace tvm {
// Internal node container of lowered function.
class LoweredFuncNode;

// Internal node container of module.
class ModuleNode;

/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
Expand Down
60 changes: 57 additions & 3 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ TVM_DLL const char *TVMGetLastError(void);
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return Whether the function is successful.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
Expand All @@ -188,7 +188,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx,
* \param dtype The array data type.
* \param ctx The ctx this array sits on.
* \param out The output handle.
* \return Whether the function is successful.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
Expand All @@ -198,6 +198,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);

Expand All @@ -206,6 +207,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
Expand All @@ -214,13 +216,14 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);

/*!
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return whether
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);

Expand All @@ -239,6 +242,57 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* type_codes,
int num_args);

/*!
* \brief C type of packed function.
*
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param resource_handle The handle additional resouce handle from fron-end.
*/
typedef void (*TVMPackedCFunc)(
TVMValue* args, int* type_codes, int num_args, void* resource_handle);

/*!
* \brief C callback to free the resource handle in C packed function.
* \param resource_handle The handle additional resouce handle from fron-end.
*/
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);

/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out);

/*!
* \brief Register the function to runtime's global table.
*
* The registered function then can be pulled by the backend by the name.
*
* \param name The name of the function.
* \param f The function to be registered.
*/
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f);

/*!
* \brief Get a global function.
*
* \param name The name of the function.
* \param out the result function pointer.
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
} // TVM_EXTERN_C

#endif // TVM_RUNTIME_C_RUNTIME_API_H_
52 changes: 40 additions & 12 deletions include/tvm/runtime/runtime.h → include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
@@ -1,35 +1,43 @@
/*!
* Copyright (c) 2016 by Contributors
* \file runtime.h
* \file packed_func.h
* \brief Runtime related c++ class.
*/
#ifndef TVM_RUNTIME_RUNTIME_H_
#define TVM_RUNTIME_RUNTIME_H_
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_

#include <functional>
#include <tuple>
#include <vector>
#include <string>
#include "./c_runtime_api.h"

namespace tvm {
namespace runtime {

/*!
* \brief Packed function is a runtime function
* whose argument type_codes are erased by packed format.
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
*
* This is an useful unified interface to call generated functions.
* This is an useful unified interface to call generated functions,
* It is the unified function function type of TVM.
* It corresponds to TVMFunctionHandle in C runtime API.
*/
class PackedFunc {
public:
/*! \brief The internal std::function */
using FType = std::function<void(const TVMValue* args, const int* type_codes, int num_args)>;
/*! \brief default constructor */
PackedFunc() {}
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief invoke the packed function by directly passing in arguments.
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
* \return The first return value.
*/
template<typename... Args>
inline void operator()(Args&& ...args) const;
Expand All @@ -41,9 +49,25 @@ class PackedFunc {
*/
inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const;
/*! \return the internal body function */
inline FType body() const {
return body_;
}
inline FType body() const;
/*!
* \brief Register f as into global function table
* \param name The name of the function.
* \param f The function to be registered.
* \return Reference to the registered function.
* \note The returned reference is valid until the end of the program
*/
static const PackedFunc& RegisterGlobal(const std::string& name, PackedFunc f);
/*!
* \brief Get the global function by name.
* \param name The name of the function.
* \return reference to the registered function.
*/
static const PackedFunc& GetGlobal(const std::string& name);
/*!
* \brief Get the names of currently registered global function.
*/
static std::vector<std::string> ListGlobalNames();

private:
/*! \brief internal container of packed function */
Expand All @@ -56,6 +80,10 @@ inline void PackedFunc::CallPacked(
body_(args, type_codes, num_args);
}

inline PackedFunc::FType PackedFunc::body() const {
return body_;
}

template<bool stop, std::size_t I, typename F, typename ...Args>
struct for_each_dispatcher_ {
static inline void run(const std::tuple<Args...>& args, F f) {
Expand Down Expand Up @@ -124,4 +152,4 @@ inline void PackedFunc::operator()(Args&& ...args) const {
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_RUNTIME_H_
#endif // TVM_RUNTIME_PACKED_FUNC_H_
Loading