Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CommNet dynamic register memory #5281

Merged
merged 7 commits into from
Jun 30, 2021
37 changes: 28 additions & 9 deletions oneflow/core/actor/actor_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,23 @@ bool IsSoleBlobAndDynamicEmpty(Regst* regst) {

ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
Regst* regst_raw_ptr) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = producer;
msg.dst_actor_id_ = consumer;
msg.msg_type_ = ActorMsgType::kRegstMsg;
msg.regst_wrapper_.regst = regst_raw_ptr;
if (Global<IDMgr>::Get()->MachineId4ActorId(consumer) == GlobalProcessCtx::Rank()) {
msg.regst_wrapper_.comm_net_token = nullptr;
} else {
msg.regst_wrapper_.comm_net_token = regst_raw_ptr->comm_net_token();
}
msg.regst_wrapper_.comm_net_token = nullptr;
msg.regst_wrapper_.regst_status = regst_raw_ptr->status();
msg.regst_wrapper_.regst_status.regst_desc_id = regst_raw_ptr->regst_desc_id();
msg.regst_wrapper_.has_sole_empty_blob = IsSoleBlobAndDynamicEmpty(regst_raw_ptr);
msg.regst_wrapper_.is_data_regst_to_consumer =
regst_raw_ptr->regst_desc()->regst_desc_type().has_data_regst_desc();
return msg;
}

ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer,
Regst* regst_raw_ptr) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = consumer;
msg.dst_actor_id_ = producer;
msg.msg_type_ = ActorMsgType::kRegstMsg;
Expand All @@ -64,11 +62,12 @@ ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t consumer, int64_t producer,
msg.regst_wrapper_.comm_net_token = nullptr;
// you can NOT access the regst ptr when multi nodes, because the address is in another machine
msg.regst_wrapper_.has_sole_empty_blob = false;
msg.regst_wrapper_.is_data_regst_to_consumer = false;
return msg;
}

ActorMsg ActorMsg::BuildEordMsg(int64_t consumer, int64_t regst_desc_id) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = -1;
msg.dst_actor_id_ = consumer;
msg.msg_type_ = ActorMsgType::kEordMsg;
Expand All @@ -77,7 +76,7 @@ ActorMsg ActorMsg::BuildEordMsg(int64_t consumer, int64_t regst_desc_id) {
}

ActorMsg ActorMsg::BuildCommandMsg(int64_t dst_actor_id, ActorCmd cmd) {
ActorMsg msg;
ActorMsg msg{};
msg.src_actor_id_ = -1;
msg.dst_actor_id_ = dst_actor_id;
msg.msg_type_ = ActorMsgType::kCmdMsg;
Expand Down Expand Up @@ -123,6 +122,11 @@ void* ActorMsg::comm_net_token() const {
return regst_wrapper_.comm_net_token;
}

// Overwrites the comm-net token stored in the regst wrapper of this message.
// CHECK-fails unless the message type is kRegstMsg.
void ActorMsg::set_comm_net_token(void* token) {
CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
regst_wrapper_.comm_net_token = token;
}

bool ActorMsg::has_sole_empty_blob() const {
CHECK_EQ(msg_type_, ActorMsgType::kRegstMsg);
return regst_wrapper_.has_sole_empty_blob;
Expand All @@ -133,4 +137,19 @@ int64_t ActorMsg::eord_regst_desc_id() const {
return eord_regst_desc_id_;
}

void ActorMsg::AddUserData(uint8_t size, const void* data) {
CHECK_EQ(user_data_size_, 0);
CHECK_LE(size, kActorMsgUserDataMaxSize);
user_data_size_ = size;
std::memcpy(user_data_, data, size);
}

// Returns the value of user_data_size_, which AddUserData sets to the payload size.
uint8_t ActorMsg::user_data_size() const { return user_data_size_; }

// Returns a read-only pointer to the internal user_data_ buffer.
const void* ActorMsg::user_data() const { return user_data_; }

// Returns true only when this is a kRegstMsg whose regst wrapper has
// is_data_regst_to_consumer set.
bool ActorMsg::IsDataRegstMsgToConsumer() const {
  // The wrapper is consulted only after the message type has been confirmed.
  if (msg_type_ != ActorMsgType::kRegstMsg) { return false; }
  return regst_wrapper_.is_data_regst_to_consumer;
}

} // namespace oneflow
11 changes: 10 additions & 1 deletion oneflow/core/actor/actor_message.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ enum class ActorCmd {

enum class ActorMsgType { kRegstMsg = 0, kEordMsg, kCmdMsg };

constexpr uint8_t kActorMsgUserDataMaxSize = 32;

class ActorMsg final {
public:
// OF_DISALLOW_COPY_AND_MOVE(ActorMsg);
ActorMsg() = default;
~ActorMsg() = default;

Expand All @@ -54,8 +55,13 @@ class ActorMsg final {
int64_t piece_id() const;
int64_t act_id() const;
void* comm_net_token() const;
void set_comm_net_token(void* token);
bool has_sole_empty_blob() const;
int64_t eord_regst_desc_id() const;
void AddUserData(uint8_t size, const void* data);
uint8_t user_data_size() const;
const void* user_data() const;
bool IsDataRegstMsgToConsumer() const;

// Serialize
template<typename StreamT>
Expand All @@ -73,6 +79,7 @@ class ActorMsg final {
void* comm_net_token;
RegstStatus regst_status;
bool has_sole_empty_blob;
bool is_data_regst_to_consumer;
};

int64_t src_actor_id_;
Expand All @@ -83,6 +90,8 @@ class ActorMsg final {
RegstWrapper regst_wrapper_;
int64_t eord_regst_desc_id_;
};
uint8_t user_data_size_;
unsigned char user_data_[kActorMsgUserDataMaxSize];
};

template<typename StreamT>
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/comm_network/comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class CommNet {
// we can use this token to use the "Read"
virtual void* RegisterMemory(void* ptr, size_t byte_size) = 0;
virtual void UnRegisterMemory(void* token) = 0;
virtual void RegisterMemoryDone() = 0;

// Stream
void* NewActorReadId();
Expand Down
7 changes: 3 additions & 4 deletions oneflow/core/comm_network/epoll/epoll_comm_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ EpollCommNet::~EpollCommNet() {
for (auto& pair : sockfd2helper_) { delete pair.second; }
}

void EpollCommNet::RegisterMemoryDone() {
// do nothing
}

void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) {
SocketMsg msg;
msg.msg_type = SocketMsgType::kActor;
msg.actor_msg = actor_msg;
if (actor_msg.IsDataRegstMsgToConsumer()) {
msg.actor_msg.set_comm_net_token(actor_msg.regst()->comm_net_token());
}
GetSocketHelper(dst_machine_id)->AsyncWrite(msg);
}

Expand Down
4 changes: 1 addition & 3 deletions oneflow/core/comm_network/epoll/epoll_comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {
OF_DISALLOW_COPY_AND_MOVE(EpollCommNet);
~EpollCommNet();

void RegisterMemoryDone() override;

void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
void SendTransportMsg(int64_t dst_machine_id, const TransportMsg& msg);
Expand All @@ -51,6 +49,6 @@ class EpollCommNet final : public CommNetIf<SocketMemDesc> {

} // namespace oneflow

#endif // OF_PLATFORM_POSIX
#endif // __linux__

#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_EPOLL_COMM_NETWORK_H_
9 changes: 0 additions & 9 deletions oneflow/core/comm_network/ibverbs/ibverbs.proto
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
syntax = "proto2";
package oneflow;

message IBVerbsMemDescProto {
repeated uint64 mem_ptr = 1;
repeated uint32 mr_rkey = 2;
}

message IBVerbsConnectionInfo {
required uint32 lid = 1;
required uint32 qp_num = 2;
required uint64 subnet_prefix = 3;
required uint64 interface_id = 4;
}

message IBVerbsTokensMsg {
map<uint64, IBVerbsMemDescProto> token2mem_desc = 1;
}

55 changes: 28 additions & 27 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"
#include "oneflow/core/actor/actor_message_bus.h"

#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)

Expand Down Expand Up @@ -55,37 +56,37 @@ IBVerbsCommNet::~IBVerbsCommNet() {
CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0);
}

void IBVerbsCommNet::RegisterMemoryDone() {
int64_t this_machine_id = GlobalProcessCtx::Rank();
IBVerbsTokensMsg this_tokens_msg;
for (IBVerbsMemDesc* mem_desc : mem_descs()) {
this_tokens_msg.mutable_token2mem_desc()->insert(
{reinterpret_cast<uint64_t>(mem_desc), mem_desc->ToProto()});
}
// TODO(chengcheng): Use Global<Transport> to sync session tokens.
Global<CtrlClient>::Get()->PushKV(GenTokensMsgKey(this_machine_id), this_tokens_msg);
for (int64_t peer_id : peer_machine_id()) {
IBVerbsTokensMsg peer_tokens_msg;
Global<CtrlClient>::Get()->PullKV(GenTokensMsgKey(peer_id), &peer_tokens_msg);
for (const auto& pair : peer_tokens_msg.token2mem_desc()) {
CHECK(token2mem_desc_.at(peer_id)
.emplace(reinterpret_cast<void*>(pair.first), pair.second)
.second);
}
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
CHECK_EQ(msg.user_data_size(), 0);
auto* mem_desc = reinterpret_cast<IBVerbsMemDesc*>(msg.regst()->comm_net_token());
CHECK(mem_desc != nullptr);
IBVerbsCommNetRMADesc rma_desc{};
rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
rma_desc.mem_size = mem_desc->mem_size();
rma_desc.mr_rkey = mem_desc->mr()->rkey;
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
}
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
Global<CtrlClient>::Get()->ClearKV(GenTokensMsgKey(this_machine_id));
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
}

void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
qp_vec_.at(dst_machine_id)->PostSendRequest(msg);
void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
reinterpret_cast<uint64_t>(msg.regst()))];
if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); }
CHECK_EQ(msg.user_data_size(), sizeof(IBVerbsCommNetRMADesc));
std::memcpy(desc.get(), msg.user_data(), sizeof(IBVerbsCommNetRMADesc));
new_msg.set_comm_net_token(desc.get());
}
Global<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
}

IBVerbsCommNet::IBVerbsCommNet()
: CommNetIf(),
token2mem_desc_(Global<ResourceDesc, ForEnv>::Get()->process_ranks().size()),
poll_exit_flag_(ATOMIC_FLAG_INIT) {
IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {
ibv_device** device_list = ibv::wrapper.ibv_get_device_list(nullptr);
PCHECK(device_list);
ibv_device* device = device_list[0];
Expand Down Expand Up @@ -135,7 +136,7 @@ IBVerbsCommNet::IBVerbsCommNet()
void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) {
qp_vec_.at(src_machine_id)
->PostReadRequest(token2mem_desc_.at(src_machine_id).at(src_token),
->PostReadRequest(*reinterpret_cast<IBVerbsCommNetRMADesc*>(src_token),
*static_cast<const IBVerbsMemDesc*>(dst_token), read_id);
}

Expand Down
13 changes: 10 additions & 3 deletions oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,19 @@ limitations under the License.

namespace oneflow {

// Describes a registered memory region. IBVerbsCommNet::SendActorMsg fills it
// from an IBVerbsMemDesc and ships it as ActorMsg user data; DoRead later passes
// it to PostReadRequest as the read source.
struct IBVerbsCommNetRMADesc {
// Address of the region (IBVerbsMemDesc::mem_ptr() cast to an integer).
uint64_t mem_ptr;
// Size of the region as reported by IBVerbsMemDesc::mem_size().
uint64_t mem_size;
// rkey taken from the region's ibv_mr.
uint32_t mr_rkey;
};

class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);
~IBVerbsCommNet();

void RegisterMemoryDone() override;

void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void RecvActorMsg(const ActorMsg& msg);

private:
friend class Global<IBVerbsCommNet>;
Expand All @@ -50,13 +55,15 @@ class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {

static const int32_t max_poll_wc_num_;

std::vector<HashMap<void*, IBVerbsMemDescProto>> token2mem_desc_;
ibv_context* context_;
ibv_pd* pd_;
ibv_cq* cq_;
std::vector<IBVerbsQP*> qp_vec_;
std::atomic_flag poll_exit_flag_;
std::thread poll_thread_;
HashMap<std::pair<int64_t, uint64_t>, std::shared_ptr<IBVerbsCommNetRMADesc>>
remote_regst2rma_desc_;
std::mutex remote_regst2rma_desc_mutex_;
};

} // namespace oneflow
Expand Down
44 changes: 7 additions & 37 deletions oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,22 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"

#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)

namespace oneflow {

IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size) {
CHECK_GE(byte_size, 1);
size_t block_num =
(byte_size - 1) / Global<ResourceDesc, ForSession>::Get()->rdma_mem_block_byte() + 1;
sge_vec_.reserve(block_num);
mr_vec_.reserve(block_num);
char* ch_mem_ptr = reinterpret_cast<char*>(mem_ptr);
while (byte_size > 0) {
size_t cur_size =
std::min<size_t>(byte_size, Global<ResourceDesc, ForSession>::Get()->rdma_mem_block_byte());
ibv_mr* cur_mr = ibv::wrapper.ibv_reg_mr_wrap(
pd, ch_mem_ptr, cur_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
CHECK(cur_mr);
mr_vec_.push_back(cur_mr);
ibv_sge cur_sge{};
cur_sge.addr = reinterpret_cast<uint64_t>(ch_mem_ptr);
cur_sge.length = cur_size;
cur_sge.lkey = cur_mr->lkey;
sge_vec_.push_back(cur_sge);
ch_mem_ptr += cur_size;
byte_size -= cur_size;
}
CHECK_EQ(byte_size, 0);
CHECK_EQ(block_num, sge_vec_.size());
CHECK_EQ(block_num, mr_vec_.size());
// Records mem_ptr/byte_size and registers that range with `pd` as a single
// memory region via ibv_reg_mr_wrap, requesting local-write, remote-write and
// remote-read access. CHECK-fails if the returned ibv_mr pointer is null.
IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size)
: mem_ptr_(mem_ptr), mem_size_(byte_size) {
mr_ = ibv::wrapper.ibv_reg_mr_wrap(
pd, mem_ptr, byte_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
CHECK(mr_);
}

IBVerbsMemDesc::~IBVerbsMemDesc() {
for (ibv_mr* mr : mr_vec_) { CHECK_EQ(ibv::wrapper.ibv_dereg_mr(mr), 0); }
}

IBVerbsMemDescProto IBVerbsMemDesc::ToProto() {
IBVerbsMemDescProto proto;
for (const ibv_sge& sge : sge_vec_) { proto.add_mem_ptr(sge.addr); }
for (ibv_mr* mr : mr_vec_) { proto.add_mr_rkey(mr->rkey); }
return proto;
}
// Passes mr_ to ibv_dereg_mr; CHECK-fails if the call returns non-zero.
IBVerbsMemDesc::~IBVerbsMemDesc() { CHECK_EQ(ibv::wrapper.ibv_dereg_mr(mr_), 0); }

} // namespace oneflow

Expand Down
11 changes: 7 additions & 4 deletions oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ class IBVerbsMemDesc final {
IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size);
~IBVerbsMemDesc();

const std::vector<ibv_sge>& sge_vec() const { return sge_vec_; }
void* mem_ptr() const { return mem_ptr_; }

IBVerbsMemDescProto ToProto();
size_t mem_size() const { return mem_size_; }

const ibv_mr* mr() const { return mr_; }

private:
std::vector<ibv_sge> sge_vec_;
std::vector<ibv_mr*> mr_vec_;
ibv_mr* mr_;
void* mem_ptr_;
uint64_t mem_size_;
};

} // namespace oneflow
Expand Down