Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/tvm/runtime/disco/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def result_func(worker_id: int):
if worker_id != 0:
read_fd, write_fd = pool[worker_id - 1].start()
return ShapeTuple([read_fd, write_fd])
print("Shutting down the process pool")
del pool
return None

Expand Down
198 changes: 90 additions & 108 deletions src/runtime/disco/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "../file_utils.h"
Expand All @@ -36,41 +37,39 @@ namespace runtime {
using relax_vm::NDArrayCacheMetadata;
using FileRecord = NDArrayCacheMetadata::FileRecord;
using ParamRecord = NDArrayCacheMetadata::FileRecord::ParamRecord;
using relax_vm::LoadShardInfoFromStr;
using relax_vm::ShardInfo;

/*! \brief An object that helps to load parameters in shards. */
class ShardLoaderObj : public Object {
public:
/*! \brief Create a shard loader. */
static ObjectRef Create(const std::string& path_to_metadata, const std::string& metadata,
const std::string& shard_info, Module mod);
std::string shard_info, Module mod);
/*! \brief Load the i-th parameter */
NDArray Load(int weight_index) const;
/*! \brief Load all the parameters */
Array<NDArray> LoadAll() const;
/*! \brief Slice the given tensor at a specific dimension */
NDArray Shard(NDArray source, int dim, int num_slices) const;

NDArray ApplyShardFunc(const ShardInfo::ShardFunc& shard_func, const NDArray& param) const;

static constexpr const char* _type_key = "runtime.disco.ShardLoader";
TVM_DECLARE_FINAL_OBJECT_INFO(ShardLoaderObj, Object);

public:
/*! \brief Information of how each weight is stored and sharded */
struct ShardInfo {
struct ParamInfo {
const FileRecord* file;
const ParamRecord* param;
int shard_dim;
ShardInfo shard_info;
};
/*! \brief The PackedFuncs being used during sharding */
std::unordered_map<std::string, PackedFunc> shard_funcs_;
/*! \brief The metadata loaded from `ndarray-cache.json` */
NDArrayCacheMetadata metadata_;
/*! \brief Sharding information for each weight */
std::vector<ShardInfo> shard_info_;
std::vector<ParamInfo> param_info_;
/*! \brief Maps the name of a shard to its index */
std::unordered_map<std::string, int> param_name_to_index_;
/*! \brief A method to slice a 3-D tensor */
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp16_;
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_fp32_;
TypedPackedFunc<void(DLTensor*, int, DLTensor*)> f_shard3d_uint32_;
/*! \brief The current file opened to load weights in it */
mutable const FileRecord* current_file_;
/*! \brief The context of the current file to be loaded from */
Expand All @@ -79,50 +78,61 @@ class ShardLoaderObj : public Object {

TVM_REGISTER_OBJECT_TYPE(ShardLoaderObj);

/*!
* \brief Get the shape of a result tensor if it is scattered along a given axis.
* \param shape The shape of the input tensor.
* \param dim The axis along which the tensor is scattered.
* \param num_shards The number of shards.
* \return The shape of the result tensor.
*/
inline std::vector<ShapeTuple::index_type> ShardShape(const ShapeTuple& shape, int dim,
int num_shards) {
CHECK(0 <= dim && dim < static_cast<int>(shape.size()))
<< "ValueError: Cannot scatter at dim " << dim << ", because "
<< "shape is " << shape << ".";
CHECK_EQ(shape[dim] % num_shards, 0)
<< "ValueError: The shape " << shape << " cannot be scattered at dim " << dim << " into "
<< num_shards << " shards.";
std::vector<ShapeTupleObj::index_type> result{shape.begin(), shape.end()};
result[dim] /= num_shards;
return result;
}

ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std::string& metadata,
const std::string& shard_info, Module mod) {
std::string shard_info, Module mod) {
if (shard_info.empty() && mod.defined()) {
if (PackedFunc get_shard_info = mod->GetFunction("get_shard_info"); get_shard_info != nullptr) {
shard_info = get_shard_info().operator String();
}
}
ObjectPtr<ShardLoaderObj> n = make_object<ShardLoaderObj>();
n->f_shard3d_fp16_ = mod->GetFunction("shard3d_fp16", true);
n->f_shard3d_fp32_ = mod->GetFunction("shard3d_fp32", true);
n->f_shard3d_uint32_ = mod->GetFunction("shard3d_uint32", true);
CHECK(n->f_shard3d_fp16_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp16";
CHECK(n->f_shard3d_fp32_ != nullptr) << "ValueError: Cannot find the function: shard3d_fp32";
CHECK(n->f_shard3d_uint32_ != nullptr) << "ValueError: Cannot find the function: shard3d_uint32";
n->metadata_ = NDArrayCacheMetadata::LoadFromStr(metadata, path_to_metadata);
n->current_file_ = nullptr;
n->shard_info_.clear();
std::unordered_map<std::string, int> shards = LoadShardInfoFromStr(shard_info);
n->param_info_.clear();
std::unordered_map<std::string, ShardInfo> shards = relax_vm::LoadShardInfoFromStr(shard_info);
for (const FileRecord& file_record : n->metadata_.records) {
for (const ParamRecord& param_record : file_record.records) {
const std::string& name = param_record.name;
int shard_id = shards.count(name) ? shards[name] : -1;
n->param_name_to_index_[name] = n->shard_info_.size();
n->shard_info_.push_back(ShardInfo{&file_record, &param_record, shard_id});
int index = n->param_info_.size();
n->param_name_to_index_[name] = index;
ShardInfo& shard_info = shards[name];
for (const ShardInfo::ShardFunc& shard_func : shard_info.funcs) {
const std::string& name = shard_func.name;
if (PackedFunc f = mod.defined() ? mod->GetFunction(name, true) : nullptr; f != nullptr) {
n->shard_funcs_[name] = f;
} else if (const PackedFunc* f = runtime::Registry::Get(name)) {
n->shard_funcs_[name] = *f;
} else {
LOG(FATAL) << "ValueError: Undefined function: " << name;
}
}
n->param_info_.emplace_back(ParamInfo{&file_record, &param_record, shard_info});
}
}
return ObjectRef(std::move(n));
}

NDArray ShardLoaderObj::ApplyShardFunc(const ShardInfo::ShardFunc& shard_func,
const NDArray& param) const {
Device device = param->device;
NDArray o = NDArray::Empty(shard_func.output_info.shape, shard_func.output_info.dtype, device);
PackedFunc f = this->shard_funcs_.at(shard_func.name);
int n = static_cast<int>(shard_func.params.size());
std::vector<TVMValue> tvm_args(n + 2);
std::vector<int> type_codes(n + 2);
TVMArgsSetter setter(tvm_args.data(), type_codes.data());
const DLTensor* w_in = param.operator->();
const DLTensor* w_out = o.operator->();
setter(0, const_cast<DLTensor*>(w_in));
for (int i = 0; i < n; ++i) {
setter(i + 1, shard_func.params[i]);
}
setter(n + 1, const_cast<DLTensor*>(w_out));
TVMRetValue rv;
f.CallPacked(TVMArgs(tvm_args.data(), type_codes.data(), n + 2), &rv);
return o;
}

std::string GetSiblingPath(const std::string& path, const std::string& filename) {
size_t found = path.find_last_of("/\\");
if (found != std::string::npos) {
Expand All @@ -133,97 +143,69 @@ std::string GetSiblingPath(const std::string& path, const std::string& filename)

NDArray ShardLoaderObj::Load(int weight_index) const {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
int shard_idx = worker->worker_id;
Device device = worker->default_device;
const auto& shard_info = shard_info_.at(weight_index);
const ParamRecord* param = shard_info.param;
const FileRecord* file = shard_info.file;
int shard_dim = shard_info.shard_dim;
int worker_id = worker->worker_id;
int num_shards = worker->num_workers;
Optional<NDArray> send = NullOpt;
if (shard_idx == 0) {
Device device = worker->default_device;
const ParamInfo& param_info = param_info_.at(weight_index);
const ParamRecord* param = param_info.param;
const FileRecord* file = param_info.file;

auto load = [this, param, device, file]() {
if (file != current_file_) {
current_file_ = file;
std::string file_name = GetSiblingPath(this->metadata_.path, file->data_path);
LoadBinaryFromFile(file_name, &this->current_file_stream_);
}
auto f_load = [](NDArray param, const void* data, size_t nbytes) {
param.CopyFromBytes(data, nbytes);
};
if (shard_dim != -1) {
send = this->Shard(param->Load(device, &this->current_file_stream_, f_load), shard_dim,
num_shards);
return param->Load(
device, &this->current_file_stream_,
[](NDArray param, const void* data, size_t nbytes) { param.CopyFromBytes(data, nbytes); });
};

bool needs_sharding = !param_info.shard_info.funcs.empty();
if (needs_sharding) {
ShapeTuple shape = param_info.shard_info.funcs.back().output_info.shape;
DataType dtype = param_info.shard_info.funcs.back().output_info.dtype;
ICHECK(shape.size() >= 1 && shape[0] == num_shards)
<< "ValueError: The first dimension of the "
<< "output shape must be equal to the "
<< "number of shards, but got: " << shape << " and num_shards = " << num_shards;
NDArray recv = NDArray::Empty(ShapeTuple(shape.begin() + 1, shape.end()), dtype, device);
if (worker_id == 0) {
NDArray w = load();
for (const ShardInfo::ShardFunc& shard_func : param_info.shard_info.funcs) {
w = this->ApplyShardFunc(shard_func, w);
}
ScatterFromWorker0(w, recv);
} else {
send = param->Load(device, &this->current_file_stream_, f_load);
ScatterFromWorker0(NullOpt, recv);
}
}
if (shard_dim != -1) {
NDArray recv =
NDArray::Empty(ShardShape(param->shape, shard_dim, num_shards), param->dtype, device);
ScatterFromWorker0(send, recv);
return recv;
} else {
NDArray recv;
if (send.defined()) {
recv = NDArray(send.value());
if (worker_id == 0) {
NDArray w = load();
BroadcastFromWorker0(w, w);
return w;
} else {
recv = NDArray::Empty(param->shape, param->dtype, device);
NDArray w = NDArray::Empty(param->shape, param->dtype, device);
BroadcastFromWorker0(w, w);
return w;
}
BroadcastFromWorker0(recv, recv);
return recv;
}
}

Array<NDArray> ShardLoaderObj::LoadAll() const {
int n = static_cast<int>(shard_info_.size());
int n = static_cast<int>(param_info_.size());
Array<NDArray> shards;
shards.reserve(n);
for (int i = 0; i < n; ++i) {
std::string param_name = "param_" + std::to_string(i);
ICHECK(this->param_name_to_index_.count(param_name));
int shard_id = this->param_name_to_index_.at(param_name);
shards.push_back(this->Load(shard_id));
}
return shards;
}

NDArray ShardLoaderObj::Shard(NDArray source, int dim, int num_slices) const {
ICHECK(dim >= 0 && dim < source->ndim);
// Assemble a flattened 3d src tensor
int64_t src_flat[3] = {1, 1, 1};
{
const int64_t* s = source.Shape().data();
int ndim = source->ndim;
src_flat[0] = std::accumulate(&s[0], &s[dim], 1, std::multiplies<int64_t>());
src_flat[1] = s[dim];
src_flat[2] = std::accumulate(&s[dim + 1], &s[ndim], 1, std::multiplies<int64_t>());
}
DLTensor src_tensor = *source.operator->();
src_tensor.ndim = 3;
src_tensor.shape = src_flat;
// Assmeble a flattened 4d dst tensor
int64_t dst_flat[4] = {num_slices, src_flat[0], src_flat[1] / num_slices, src_flat[2]};
NDArray destination{nullptr};
{
std::vector<ShapeTuple::index_type> dst_shape = ShardShape(source.Shape(), dim, num_slices);
dst_shape.insert(dst_shape.begin(), static_cast<ShapeTuple::index_type>(num_slices));
destination = NDArray::Empty(dst_shape, source->dtype, source->device);
}
DLTensor dst_tensor = *destination.operator->();
dst_tensor.ndim = 4;
dst_tensor.shape = dst_flat;
// Copy slices using the API
if (source.DataType() == DataType::Float(32)) {
this->f_shard3d_fp32_(&src_tensor, num_slices, &dst_tensor);
} else if (source.DataType() == DataType::Float(16)) {
this->f_shard3d_fp16_(&src_tensor, num_slices, &dst_tensor);
} else if (source.DataType() == DataType::UInt(32)) {
this->f_shard3d_uint32_(&src_tensor, num_slices, &dst_tensor);
} else {
LOG(FATAL) << "ValueError: Unsupported data type: " << source.DataType();
}
return destination;
}

TVM_REGISTER_GLOBAL("runtime.disco.ShardLoader").set_body_typed(ShardLoaderObj::Create);
TVM_REGISTER_GLOBAL("runtime.disco.ShardLoaderLoad")
.set_body_typed([](ObjectRef loader_obj, ShapeTuple weight_index) {
Expand Down
Loading