Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ class SessionObj : public Object {
*/
template <typename... Args>
DRef TVM_ALWAYS_INLINE CallPacked(const DRef& func, Args&&... args);
/*!
* \brief Call packed function on each worker using a packed sequence. The calling convention:
* The first element must be DiscoAction::kCallPacked,
* The second element must be 0, which will later be updated by the session to return reg_id
* The thirtd element is the function to be called.
*/
virtual DRef CallWithPacked(const TVMArgs& args) = 0;
/*! \brief Get a global functions on workers. */
virtual DRef GetGlobalFunc(const std::string& name) = 0;
/*!
Expand Down Expand Up @@ -224,8 +231,6 @@ class SessionObj : public Object {
protected:
/*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */
virtual void DeallocReg(int reg_id) = 0;
/*! \brief Call packed function on each worker using a packed sequence */
virtual DRef CallWithPacked(const TVMArgs& args) = 0;
};

/*!
Expand Down
15 changes: 14 additions & 1 deletion src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,14 @@ std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var&
class TIRFuseMutator : public ExprMutator {
public:
static IRModule Transform(const IRModule& mod) {
Map<String, BaseFunc> tir_funcs;

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes allows TIRs with global symbols to be preserved if they are not called by relax - this is useful for the shard loader.

for (const auto& [gv, func] : mod->functions) {
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
if (prim_func->GetAttr<String>("global_symbol").defined()) {
tir_funcs.Set(gv->name_hint, func);
}
}
}
// Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
TIRFuseMutator mutator(mod);
// Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
Expand All @@ -868,7 +876,12 @@ class TIRFuseMutator : public ExprMutator {
}
}

// Step 3. Copy over module attributes and return.
// Step 3. Recover all primitive TIR functions if they have global symbol
for (const auto& [name, func] : tir_funcs) {
mutator.builder_->AddFunction(func, name);
}

// Step 4. Copy over module attributes and return.
auto modified_mod = mutator.builder_->GetContextIRModule();
if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict);
return modified_mod;
Expand Down
20 changes: 20 additions & 0 deletions src/runtime/disco/bcast_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <sstream>

namespace tvm {
namespace runtime {

Expand Down Expand Up @@ -88,6 +90,24 @@ DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) {
setter(1, reg_id);
setter(2, func->reg_id);
}
{
std::ostringstream os;
int cnt = 0;
for (int i = 3; i < num_args; ++i) {
int type_code = type_codes[i];
if (type_code != kDLInt && type_code != kDLUInt && type_code != kDLFloat &&
type_code != kTVMDataType && type_code != kDLDevice && type_code != kTVMOpaqueHandle &&
type_code != kTVMStr && type_code != kTVMNullptr && type_code != kTVMBytes &&
type_code != kTVMObjectHandle) {
os << "\n Argument #" << i << " has unsupported type code: " << type_code << " ("
<< ArgTypeCode2Str(type_code) << ")";
cnt += 1;
}
}
if (cnt > 0) {
LOG(FATAL) << "CallWithPacked() does not support " << cnt << " argument(s):" << os.str();
}
}
this->BroadcastPacked(TVMArgs(values, type_codes, num_args));
return BcastSessionObj::Internal::MakeDRef(reg_id, GetRef<Session>(this));
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/disco/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ void RecvFromWorker0(NDArray buffer) { GetCCLFunc("recv_from_worker0")(buffer);

int WorkerId() { return DiscoWorker::ThreadLocal()->worker_id; }

void SyncWorker() { GetCCLFunc("sync_worker")(); }

TVM_REGISTER_GLOBAL("runtime.disco.load_vm_module").set_body_typed(LoadVMModule);
TVM_REGISTER_GLOBAL("runtime.disco.empty").set_body_typed(DiscoEmptyNDArray);
TVM_REGISTER_GLOBAL("runtime.disco.allreduce")
Expand Down
7 changes: 6 additions & 1 deletion src/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,14 @@ void GatherToWorker0(NDArray send, Optional<NDArray> recv);
* \param buffer The buffer to be received
*/
void RecvFromWorker0(NDArray buffer);

/*! \brief Get the local worker id */
int WorkerId();
/*!
* \brief Called by the worker thread. Waiting until the worker completes all its tasks.
* As a specific example, on a CUDA worker, it blocks until all kernels are launched and
* cudaStreamSynchronize is complete.
*/
void SyncWorker();

} // namespace runtime
} // namespace tvm
Expand Down
28 changes: 21 additions & 7 deletions src/runtime/disco/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -42,8 +43,7 @@ class ShardLoaderObj : public Object {
public:
/*! \brief Create a shard loader. */
static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata,
const std::string& shard_info,
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard);
const std::string& shard_info, Module mod);
/*! \brief Load the i-th parameter */
NDArray Load(int weight_index) const;
/*! \brief Load all the parameters */
Expand All @@ -68,7 +68,9 @@ class ShardLoaderObj : public Object {
/*! \brief Maps the name of a shard to its index */
std::unordered_map<std::string, int> param_name_to_index_;
/*! \brief A method to slice a 3-D tensor */
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard_;
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp16_;
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp32_;
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_uint32_;
/*! \brief The current file opened to load weights in it */
mutable const FileRecord* current_file_;
/*! \brief The context of the current file to be loaded from */
Expand Down Expand Up @@ -98,10 +100,14 @@ inline std::vector<ShapeTuple::index_type> ShardShape(const ShapeTuple& shape, i
}

ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata,
const std::string& shard_info,
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard) {
const std::string& shard_info, Module mod) {
ObjectPtr<ShardLoaderObj> n = make_object<ShardLoaderObj>();
n->f_shard_ = f_shard;
n->f_shard3d_fp16_ = mod->GetFunction("shard3d_fp16", true);
n->f_shard3d_fp32_ = mod->GetFunction("shard3d_fp32", true);
n->f_shard3d_uint32_ = mod->GetFunction("shard3d_uint32", true);
CHECK(n->f_shard3d_fp16_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp16";
CHECK(n->f_shard3d_fp32_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp32";
CHECK(n->f_shard3d_uint32_ != nullptr) << "ValueError: Cannot find the function: shard3d_uint32";
n->metadata_ = NDArrayCacheMetadata::LoadFromStr(metadata, path_to_metadata);
n->current_file_ = nullptr;
n->shard_info_.clear();
Expand Down Expand Up @@ -205,7 +211,15 @@ NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
dst_tensor.ndim = 4;
dst_tensor.shape = dst_flat;
// Copy slices using the API
this->f_shard_(&src_tensor, num_slices, &dst_tensor);
if (source.DataType() == DataType::Float(32)) {
this->f_shard3d_fp32_(&src_tensor, num_slices, &dst_tensor);
} else if (source.DataType() == DataType::Float(16)) {
this->f_shard3d_fp16_(&src_tensor, num_slices, &dst_tensor);
} else if (source.DataType() == DataType::UInt(32)) {
this->f_shard3d_uint32_(&src_tensor, num_slices, &dst_tensor);
} else {
LOG(FATAL) << "ValueError: Unsupported data type: " << source.DataType();
}
return destination;
}

Expand Down
20 changes: 20 additions & 0 deletions src/runtime/disco/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/runtime/registry.h>

#include <mutex>
#include <sstream>
#include <vector>

#include "../../cuda/cuda_common.h"
Expand All @@ -41,6 +42,19 @@ struct NCCLGlobalContext {
}

void Initialize(const std::vector<int>& device_ids) {
{
std::ostringstream os;
bool is_first = true;
for (int device_id : device_ids) {
if (!is_first) {
os << ",";
} else {
is_first = false;
}
os << device_id;
}
LOG(INFO) << "Initializing NCCL with devices: " << os.str() << ".";
}
// TODO(@junrushao): support more flexible communicator pattern for generic SPMD usecases
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int num_workers = worker->num_workers;
Expand Down Expand Up @@ -208,6 +222,11 @@ void RecvFromWorker0(NDArray buffer) {
NCCL_CALL(ncclGroupEnd());
}

void SyncWorker() {
NCCLThreadLocalContext* ctx = NCCLThreadLocalContext::Get();
CUDA_CALL(cudaStreamSynchronize(ctx->stream));
}

TVM_REGISTER_GLOBAL("runtime.disco.nccl.init_ccl")
.set_body([](TVMArgs args, TVMRetValue* rv) -> void {
std::vector<int> device_ids;
Expand All @@ -225,6 +244,7 @@ TVM_REGISTER_GLOBAL("runtime.disco.nccl.broadcast_from_worker0")
TVM_REGISTER_GLOBAL("runtime.disco.nccl.scatter_from_worker0").set_body_typed(ScatterFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.gather_to_worker0").set_body_typed(GatherToWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.recv_from_worker0").set_body_typed(RecvFromWorker0);
TVM_REGISTER_GLOBAL("runtime.disco.nccl.sync_worker").set_body_typed(SyncWorker);

} // namespace nccl
} // namespace runtime
Expand Down
10 changes: 9 additions & 1 deletion src/runtime/disco/worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
*/
#include "./worker.h"

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>

#include <thread>

#include "./builtin.h"

namespace tvm {
namespace runtime {

Expand All @@ -34,7 +37,11 @@ struct ThreadLocalDiscoWorker {
}
};

DiscoWorker* DiscoWorker::ThreadLocal() { return ThreadLocalDiscoWorker::Get()->worker; }
DiscoWorker* DiscoWorker::ThreadLocal() {
DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
return ret;
}

struct DiscoWorker::Impl {
static void MainLoop(DiscoWorker* self) {
Expand Down Expand Up @@ -116,6 +123,7 @@ struct DiscoWorker::Impl {

static void SyncWorker(DiscoWorker* self, int worker_id) {
if (worker_id == self->worker_id) {
::tvm::runtime::SyncWorker();
TVMValue values[2];
int type_codes[2];
PackArgs(values, type_codes, static_cast<int>(DiscoAction::kSyncWorker), worker_id);
Expand Down