Skip to content

Commit

Permalink
[Fix] add TVM_DLL to disco functions (#16258)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeshengJin committed Dec 18, 2023
1 parent 09acbc8 commit 7c35267
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 16 deletions.
4 changes: 2 additions & 2 deletions include/tvm/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ void AllGather(NDArray send, NDArray recv);
* \param send The buffer to be broadcasted
* \param recv The buffer receives the broadcasted array
*/
void BroadcastFromWorker0(NDArray send, NDArray recv);
TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
/*!
* \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
* \param send For worker-0, it must be provided, and otherwise, the buffer must be None.
* The buffer will be divided into equal parts and sent to each worker accordingly.
* \param recv The receiving buffer, which must not be None.
*/
void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
/*!
* \brief Perform a gather operation to worker-0.
* \param send The sending buffer, which must not be None.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/disco/disco_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class DiscoWorker {
/*! \brief Main loop of the worker */
void MainLoop();
/*! \brief Get the worker instance on the current thread */
static DiscoWorker* ThreadLocal();
TVM_DLL static DiscoWorker* ThreadLocal();
/*! \brief Set the specific register to a specific value */
void SetRegister(int reg_id, TVMArgValue value);

Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/relax_vm/ndarray_cache_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ struct NDArrayCacheMetadata {
};

/*! \brief Load a FileRecord into memory */
Array<NDArray> Load(Device device, //
const std::string& path_prefix, //
std::string* raw_data_buffer, //
Optional<NDArray>* staging_buffer = nullptr) const;
TVM_DLL Array<NDArray> Load(Device device, //
const std::string& path_prefix, //
std::string* raw_data_buffer, //
Optional<NDArray>* staging_buffer = nullptr) const;

/*! \brief Relative path to the bin file */
std::string data_path;
Expand All @@ -83,7 +83,7 @@ struct NDArrayCacheMetadata {
std::string path;

/*! \brief Load the metadata from a specific directory */
static NDArrayCacheMetadata Load(const std::string& path);
TVM_DLL static NDArrayCacheMetadata Load(const std::string& path);
/*! \brief Load the metadata from a given JSON string */
static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path);
};
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {

void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }

void BroadcastFromWorker0(NDArray send, NDArray recv) {
TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) {
GetCCLFunc("broadcast_from_worker0")(send, recv);
}

void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
GetCCLFunc("scatter_from_worker0")(send, recv);
}

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/disco/disco_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct ThreadLocalDiscoWorker {
}
};

DiscoWorker* DiscoWorker::ThreadLocal() {
TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
return ret;
Expand Down
11 changes: 6 additions & 5 deletions src/runtime/relax_vm/ndarray_cache_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s
return result;
}

NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
TVM_DLL NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
picojson::value json_info;
{
std::string json_str;
Expand Down Expand Up @@ -183,10 +183,11 @@ NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load(
return arr;
}

Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(Device device,
const std::string& path_prefix, //
std::string* raw_data_buffer, //
Optional<NDArray>* staging_buffer) const {
TVM_DLL Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(
Device device,
const std::string& path_prefix, //
std::string* raw_data_buffer, //
Optional<NDArray>* staging_buffer) const {
LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer);
CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format is supported";
CHECK_EQ(this->nbytes, raw_data_buffer->length())
Expand Down

0 comments on commit 7c35267

Please sign in to comment.