Skip to content

Commit

Permalink
WIP multipart send
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisklein committed Aug 22, 2022
1 parent a69e9a4 commit 6908616
Showing 1 changed file with 60 additions and 54 deletions.
114 changes: 60 additions & 54 deletions fairmq/shmem/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,9 @@ class Socket final : public fair::mq::Socket
}

template <typename SuccessHandler>
auto SendMeta(ZMsg& msg, int flags, int timeout, SuccessHandler&& successHandler) -> int64_t {
auto SendMeta(ZMsg& msg, int timeout, SuccessHandler&& successHandler) -> int64_t {
auto elapsed = 0;
auto const flags = (timeout == 0) ? ZMQ_DONTWAIT : 0;

while (true) {
auto const nbytes = zmq_msg_send(msg.Msg(), fSocket, flags);
Expand All @@ -261,18 +262,19 @@ class Socket final : public fair::mq::Socket
return static_cast<int>(TransferCode::error);
}

auto SendShm(shmem::Message& msg, int timeout) -> int64_t {
auto SendShm(shmem::Message& msg, int timeout) -> int64_t
{
auto metaMsg = MakeMetaMsg(msg);

auto const flags = (timeout == 0) ? ZMQ_DONTWAIT : 0;

return static_cast<int64_t>(SendMeta(metaMsg, flags, timeout, [this, &msg](std::size_t) {
msg.fQueued = true;
++fMessagesTx;
auto const size = msg.GetSize();
fBytesTx += size;
return static_cast<int64_t>(size);
}));
return static_cast<int64_t>(SendMeta(
metaMsg,
timeout,
[this, &msg](std::size_t) {
msg.fQueued = true;
++fMessagesTx;
auto const size = msg.GetSize();
fBytesTx += size;
return static_cast<int64_t>(size);
}));
}

public:
Expand Down Expand Up @@ -322,59 +324,63 @@ class Socket final : public fair::mq::Socket
}
}

int64_t Send(std::vector<MessagePtr>& msgVec, int timeout = -1) override
private:
// TODO In C++20 we should use a std::ranges::sized_range<shmem::Message const&>
auto MakeMetaMsg(std::vector<mq::MessagePtr>& mqMsgs) noexcept
{
auto const flags = (timeout == 0) ? ZMQ_DONTWAIT : 0;
auto elapsed = 0;
auto const n = mqMsgs.size();

// put it into zmq message
const unsigned int vecSize = msgVec.size();
ZMsg zmqMsg(vecSize * sizeof(MetaHeader));
#if FAIRMQ_HAS_STD_PMR
// TODO padded case
#endif // FAIRMQ_HAS_STD_PMR

// prepare the message with shm metas
MetaHeader* metas = static_cast<MetaHeader*>(zmqMsg.Data());
std::unique_ptr<MetaHeader[]> buffer(new MetaHeader[n]);
auto metas = buffer.get();
for (auto& mqMsg : mqMsgs) {
auto const& shmMsg = *static_cast<shmem::Message*>(mqMsg.get()); // NOLINT
// copy the shmMsg::fMeta field into the zmq data buffer
std::memcpy(metas++, &(shmMsg.fMeta), sizeof(MetaHeader)); // NOLINT
}
return ZMsg(
buffer.release(),
sizeof(MetaHeader) * n,
[](void* data, void*) {
delete[] static_cast<MetaHeader**>(data);
});
}

for (auto& msg : msgVec) {
public:
int64_t Send(std::vector<mq::MessagePtr>& msgVec, int timeout = -1) override
{
for (auto const& msg : msgVec) {
auto msgPtr = msg.get();
if (!msgPtr) {
return static_cast<int>(TransferCode::error);
}
assertm(dynamic_cast<shmem::Message*>(msgPtr), "given mq::Message is a shmem::Message"); // NOLINT
auto shmMsg = static_cast<shmem::Message*>(msgPtr); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
std::memcpy(metas++, &(shmMsg->fMeta), sizeof(MetaHeader));
}

while (true) {
int64_t totalSize = 0;
int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags);
if (nbytes > 0) {
assert(static_cast<unsigned int>(nbytes) == (vecSize * sizeof(MetaHeader))); // all or nothing

for (auto& msg : msgVec) {
Message* shmMsg = static_cast<Message*>(msg.get());
shmMsg->fQueued = true;
totalSize += shmMsg->fMeta.fSize;
}

// store statistics on how many messages have been sent
fMessagesTx++;
fBytesTx += totalSize;

return totalSize;
} else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
if (fManager.Interrupted()) {
return static_cast<int>(TransferCode::interrupted);
} else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) {
continue;
} else {
return static_cast<int>(TransferCode::timeout);
}
} else {
return zmq::HandleErrors(fId);
}
}

return static_cast<int>(TransferCode::error);
auto metaMsg = MakeMetaMsg(msgVec);
return static_cast<int64_t>(SendMeta(
metaMsg,
timeout,
[this, &msgVec](std::size_t nbytes) {
auto const size = msgVec.size();
int64_t totalSize = 0;
assertm(static_cast<unsigned int>(nbytes) == (size * sizeof(MetaHeader)), "all or nothing"); // NOLINT

for (auto& msg : msgVec) {
auto& shmMsg = *static_cast<shmem::Message*>(msg.get()); // NOLINT
shmMsg.fQueued = true;
totalSize += static_cast<int64_t>(shmMsg.fMeta.fSize);
}

// store statistics on how many messages have been sent
fMessagesTx++;
fBytesTx += totalSize;

return totalSize;
}));
}

int64_t Receive(std::vector<MessagePtr>& msgVec, int timeout = -1) override
Expand Down

0 comments on commit 6908616

Please sign in to comment.