Skip to content

Commit

Permalink
CommNet dynamic register memory (#5281)
Browse files Browse the repository at this point in the history
Co-authored-by: guo ran <360112263@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 30, 2021
1 parent ed82d1d commit c428531
Show file tree
Hide file tree
Showing 22 changed files with 161 additions and 171 deletions.
37 changes: 28 additions & 9 deletions oneflow/core/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,23 @@ bool IsSoleBlobAndDynamicEmpty(Regst* regst) {

ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
Regst* regst_raw_ptr) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = producer;
msg.dst_actor_id_ = consumer;
msg.msg_type_ = ActorMsgType::kRegstMsg;
msg.regst_wrapper_.regst = regst_raw_ptr;
if (Global<IDMgr>::Get()->MachineId4ActorId(consumer) == GlobalProcessCtx::Rank()) {
msg.regst_wrapper_.comm_net_token = nullptr;
} else {
msg.regst_wrapper_.comm_net_token = regst_raw_ptr->comm_net_token();
}
msg.regst_wrapper_.comm_net_token = nullptr;
msg.regst_wrapper_.regst_status = regst_raw_ptr->status();
msg.regst_wrapper_.regst_status.regst_desc_id = regst_raw_ptr->regst_desc_id();
msg.regst_wrapper_.has_sole_empty_blob = IsSoleBlobAndDynamicEmpty(regst_raw_ptr);
msg.regst_wrapper_.is_data_regst_to_consumer =
regst_raw_ptr->regst_desc()->regst_desc_type().has_data_regst_desc();
return msg;
}

ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer,
Regst* regst_raw_ptr) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = consumer;
msg.dst_actor_id_ = producer;
msg.msg_type_ = ActorMsgType::kRegstMsg;
Expand All @@ -64,11 +62,12 @@ ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer,
msg.regst_wrapper_.comm_net_token = nullptr;
// you can NOT access the regst ptr when multi nodes, because the address is in another machine
msg.regst_wrapper_.has_sole_empty_blob = false;
msg.regst_wrapper_.is_data_regst_to_consumer = false;
return msg;
}

ActorMsg ActorMsg::BuildEordMsg(int64_t consumer, int64_t regst_desc_id) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = -1;
msg.dst_actor_id_ = consumer;
msg.msg_type_ = ActorMsgType::kEordMsg;
Expand All @@ -77,7 +76,7 @@ ActorMsg ActorMsg::BuildEordMsg(int64_t consumer, int64_t regst_desc_id) {
}

ActorMsg ActorMsg::BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = -1;
msg.dst_actor_id_ = dst_actor_id;
msg.msg_type_ = ActorMsgType::kCmdMsg;
Expand Down Expand Up @@ -123,6 +122,11 @@ void* ActorMsg::comm_net_token() const {
return regst_wrapper_.comm_net_token;
}

void ActorMsg::set_comm_net_token(void* token) {
CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
regst_wrapper_.comm_net_token = token;
}

bool ActorMsg::has_sole_empty_blob() const {
CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
return regst_wrapper_.has_sole_empty_blob;
Expand All @@ -133,4 +137,19 @@ int64_t ActorMsg::eord_regst_desc_id() const {
return eord_regst_desc_id_;
}

void ActorMsg::AddUserData(uint8_t size, const void* data) {
CHECK_EQ(user_data_size_, 0);
CHECK_LE(size, kActorMsgUserDataMaxSize);
user_data_size_ = size;
std::memcpy(user_data_, data, size);
}

uint8_t ActorMsg::user_data_size() const { return user_data_size_; }

const void* ActorMsg::user_data() const { return user_data_; }

bool ActorMsg::IsDataRegstMsgToConsumer() const {
return msg_type_ == ActorMsgType::kRegstMsg && regst_wrapper_.is_data_regst_to_consumer;
}

} // namespace oneflow
11 changes: 10 additions & 1 deletion oneflow/core/actor/actor_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ enum class ActorCmd {

enum class ActorMsgType { kRegstMsg = 0, kEordMsg, kCmdMsg };

constexpr uint8_t kActorMsgUserDataMaxSize = 32;

class ActorMsg final {
public:
// OF_DISALLOW_COPY_AND_MOVE(ActorMsg);
ActorMsg() = default;
~ActorMsg() = default;

Expand All @@ -54,8 +55,13 @@ class ActorMsg final {
int64_t piece_id() const;
int64_t act_id() const;
void* comm_net_token() const;
void set_comm_net_token(void* token);
bool has_sole_empty_blob() const;
int64_t eord_regst_desc_id() const;
void AddUserData(uint8_t size, const void* data);
uint8_t user_data_size() const;
const void* user_data() const;
bool IsDataRegstMsgToConsumer() const;

// Serialize
template<typename StreamT>
Expand All @@ -73,6 +79,7 @@ class ActorMsg final {
void* comm_net_token;
RegstStatus regst_status;
bool has_sole_empty_blob;
bool is_data_regst_to_consumer;
};

int64_t src_actor_id_;
Expand All @@ -83,6 +90,8 @@ class ActorMsg final {
RegstWrapper regst_wrapper_;
int64_t eord_regst_desc_id_;
};
uint8_t user_data_size_;
unsigned char user_data_[kActorMsgUserDataMaxSize];
};

template<typename StreamT>
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/comm_network/comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class CommNet {
// we can use this token to use the "Read"
virtual void* RegisterMemory(void* ptr, size_t byte_size) = 0;
virtual void UnRegisterMemory(void* token) = 0;
virtual void RegisterMemoryDone() = 0;

// Stream
void* NewActorReadId();
Expand Down
7 changes: 3 additions & 4 deletions oneflow/core/comm_network/epoll/epoll_comm_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ EpollCommNet::~EpollCommNet() {
for (auto& pair : sockfd2helper_) { delete pair.second; }
}

void EpollCommNet::RegisterMemoryDone() {
// do nothing
}

void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) {
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
if (actor_msg.IsDataRegstMsgToConsumer()) {
msg.actor_msg.set_comm_net_token(actor_msg.regst()->comm_net_token());
}
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
}

Expand Down
4 changes: 1 addition & 3 deletions oneflow/core/comm_network/epoll/epoll_comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
OF_DISALLOW_COPY_AND_MOVE(EpollCommNet);
~EpollCommNet();

void RegisterMemoryDone() override;

void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
void SendTransportMsg(int64_t dst_machine_id, const TransportMsg& msg);
Expand All @@ -51,6 +49,6 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {

} // namespace oneflow

#endif // OF_PLATFORM_POSIX
#endif // __linux__

#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_
9 changes: 0 additions & 9 deletions oneflow/core/comm_network/ibverbs/ibverbs.proto
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
syntax = "proto2";
package oneflow;

message IBVerbsMemDescProto {
repeated uint64 mem_ptr = 1;
repeated uint32 mr_rkey = 2;
}

message IBVerbsConnectionInfo {
required uint32 lid = 1;
required uint32 qp_num = 2;
required uint64 subnet_prefix = 3;
required uint64 interface_id = 4;
}

message IBVerbsTokensMsg {
map<uint64, IBVerbsMemDescProto> token2mem_desc = 1;
}

55 changes: 28 additions & 27 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"
#include "oneflow/core/actor/actor_message_bus.h"

#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)

Expand Down Expand Up @@ -55,37 +56,37 @@ IBVerbsCommNet::~IBVerbsCommNet() {
CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0);
}

void IBVerbsCommNet::RegisterMemoryDone() {
int64_t this_machine_id = GlobalProcessCtx::Rank();
IBVerbsTokensMsg this_tokens_msg;
for (IBVerbsMemDesc* mem_desc : mem_descs()) {
this_tokens_msg.mutable_token2mem_desc()->insert(
{reinterpret_cast<uint64_t>(mem_desc), mem_desc->ToProto()});
}
// TODO(chengcheng): Use Global<Transport> to sync session tokens.
Global<CtrlClient>::Get()->PushKV(GenTokensMsgKey(this_machine_id), this_tokens_msg);
for (int64_t peer_id : peer_machine_id()) {
IBVerbsTokensMsg peer_tokens_msg;
Global<CtrlClient>::Get()->PullKV(GenTokensMsgKey(peer_id), &peer_tokens_msg);
for (const auto& pair : peer_tokens_msg.token2mem_desc()) {
CHECK(token2mem_desc_.at(peer_id)
.emplace(reinterpret_cast<void*>(pair.first), pair.second)
.second);
}
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
CHECK_EQ(msg.user_data_size(), 0);
auto* mem_desc = reinterpret_cast<IBVerbsMemDesc*>(msg.regst()->comm_net_token());
CHECK(mem_desc != nullptr);
IBVerbsCommNetRMADesc rma_desc{};
rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
rma_desc.mem_size = mem_desc->mem_size();
rma_desc.mr_rkey = mem_desc->mr()->rkey;
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
}
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
Global<CtrlClient>::Get()->ClearKV(GenTokensMsgKey(this_machine_id));
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
}

void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
qp_vec_.at(dst_machine_id)->PostSendRequest(msg);
void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
reinterpret_cast<uint64_t>(msg.regst()))];
if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); }
CHECK_EQ(msg.user_data_size(), sizeof(IBVerbsCommNetRMADesc));
std::memcpy(desc.get(), msg.user_data(), sizeof(IBVerbsCommNetRMADesc));
new_msg.set_comm_net_token(desc.get());
}
Global<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
}

IBVerbsCommNet::IBVerbsCommNet()
: CommNetIf(),
token2mem_desc_(Global<ResourceDesc, ForEnv>::Get()->process_ranks().size()),
poll_exit_flag_(ATOMIC_FLAG_INIT) {
IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {
ibv_device** device_list = ibv::wrapper.ibv_get_device_list(nullptr);
PCHECK(device_list);
ibv_device* device = device_list[0];
Expand Down Expand Up @@ -135,7 +136,7 @@ IBVerbsCommNet::IBVerbsCommNet()
void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) {
qp_vec_.at(src_machine_id)
->PostReadRequest(token2mem_desc_.at(src_machine_id).at(src_token),
->PostReadRequest(*reinterpret_cast<IBVerbsCommNetRMADesc*>(src_token),
*static_cast<const IBVerbsMemDesc*>(dst_token), read_id);
}

Expand Down
13 changes: 10 additions & 3 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,19 @@ limitations under the License.

namespace oneflow {

struct IBVerbsCommNetRMADesc {
uint64_t mem_ptr;
uint64_t mem_size;
uint32_t mr_rkey;
};

class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);
~IBVerbsCommNet();

void RegisterMemoryDone() override;

void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void RecvActorMsg(const ActorMsg& msg);

private:
friend class Global<IBVerbsCommNet>;
Expand All @@ -50,13 +55,15 @@ class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {

static const int32_t max_poll_wc_num_;

std::vector<HashMap<void*, IBVerbsMemDescProto>> token2mem_desc_;
ibv_context* context_;
ibv_pd* pd_;
ibv_cq* cq_;
std::vector<IBVerbsQP*> qp_vec_;
std::atomic_flag poll_exit_flag_;
std::thread poll_thread_;
HashMap<std::pair<int64_t, uint64_t>, std::shared_ptr<IBVerbsCommNetRMADesc>>
remote_regst2rma_desc_;
std::mutex remote_regst2rma_desc_mutex_;
};

} // namespace oneflow
Expand Down
44 changes: 7 additions & 37 deletions oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,22 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"

#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)

namespace oneflow {

IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size) {
CHECK_GE(byte_size, 1);
size_t block_num =
(byte_size - 1) / Global<ResourceDesc, ForSession>::Get()->rdma_mem_block_byte() + 1;
sge_vec_.reserve(block_num);
mr_vec_.reserve(block_num);
char* ch_mem_ptr = reinterpret_cast<char*>(mem_ptr);
while (byte_size > 0) {
size_t cur_size =
std::min<size_t>(byte_size, Global<ResourceDesc, ForSession>::Get()->rdma_mem_block_byte());
ibv_mr* cur_mr = ibv::wrapper.ibv_reg_mr_wrap(
pd, ch_mem_ptr, cur_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
CHECK(cur_mr);
mr_vec_.push_back(cur_mr);
ibv_sge cur_sge{};
cur_sge.addr = reinterpret_cast<uint64_t>(ch_mem_ptr);
cur_sge.length = cur_size;
cur_sge.lkey = cur_mr->lkey;
sge_vec_.push_back(cur_sge);
ch_mem_ptr += cur_size;
byte_size -= cur_size;
}
CHECK_EQ(byte_size, 0);
CHECK_EQ(block_num, sge_vec_.size());
CHECK_EQ(block_num, mr_vec_.size());
IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size)
: mem_ptr_(mem_ptr), mem_size_(byte_size) {
mr_ = ibv::wrapper.ibv_reg_mr_wrap(
pd, mem_ptr, byte_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
CHECK(mr_);
}

IBVerbsMemDesc::~IBVerbsMemDesc() {
for (ibv_mr* mr : mr_vec_) { CHECK_EQ(ibv::wrapper.ibv_dereg_mr(mr), 0); }
}

IBVerbsMemDescProto IBVerbsMemDesc::ToProto() {
IBVerbsMemDescProto proto;
for (const ibv_sge& sge : sge_vec_) { proto.add_mem_ptr(sge.addr); }
for (ibv_mr* mr : mr_vec_) { proto.add_mr_rkey(mr->rkey); }
return proto;
}
IBVerbsMemDesc::~IBVerbsMemDesc() { CHECK_EQ(ibv::wrapper.ibv_dereg_mr(mr_), 0); }

} // namespace oneflow

Expand Down
11 changes: 7 additions & 4 deletions oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ class IBVerbsMemDesc final {
IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size);
~IBVerbsMemDesc();

const std::vector<ibv_sge>& sge_vec() const { return sge_vec_; }
void* mem_ptr() const { return mem_ptr_; }

IBVerbsMemDescProto ToProto();
size_t mem_size() const { return mem_size_; }

const ibv_mr* mr() const { return mr_; }

private:
std::vector<ibv_sge> sge_vec_;
std::vector<ibv_mr*> mr_vec_;
ibv_mr* mr_;
void* mem_ptr_;
uint64_t mem_size_;
};

} // namespace oneflow
Expand Down

0 comments on commit c428531

Please sign in to comment.