From 6908616f64334013360de20a36f0fb17dfbd8814 Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Mon, 22 Aug 2022 12:25:46 +0200 Subject: [PATCH] WIP multipart send --- fairmq/shmem/Socket.h | 114 ++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 54 deletions(-) diff --git a/fairmq/shmem/Socket.h b/fairmq/shmem/Socket.h index c458587de..e663b4700 100644 --- a/fairmq/shmem/Socket.h +++ b/fairmq/shmem/Socket.h @@ -238,8 +238,9 @@ class Socket final : public fair::mq::Socket } template - auto SendMeta(ZMsg& msg, int flags, int timeout, SuccessHandler&& successHandler) -> int64_t { + auto SendMeta(ZMsg& msg, int timeout, SuccessHandler&& successHandler) -> int64_t { auto elapsed = 0; + auto const flags = (timeout == 0) ? ZMQ_DONTWAIT : 0; while (true) { auto const nbytes = zmq_msg_send(msg.Msg(), fSocket, flags); @@ -261,18 +262,19 @@ class Socket final : public fair::mq::Socket return static_cast(TransferCode::error); } - auto SendShm(shmem::Message& msg, int timeout) -> int64_t { + auto SendShm(shmem::Message& msg, int timeout) -> int64_t + { auto metaMsg = MakeMetaMsg(msg); - - auto const flags = (timeout == 0) ? ZMQ_DONTWAIT : 0; - - return static_cast(SendMeta(metaMsg, flags, timeout, [this, &msg](std::size_t) { - msg.fQueued = true; - ++fMessagesTx; - auto const size = msg.GetSize(); - fBytesTx += size; - return static_cast(size); - })); + return static_cast(SendMeta( + metaMsg, + timeout, + [this, &msg](std::size_t) { + msg.fQueued = true; + ++fMessagesTx; + auto const size = msg.GetSize(); + fBytesTx += size; + return static_cast(size); + })); } public: @@ -322,59 +324,63 @@ class Socket final : public fair::mq::Socket } } - int64_t Send(std::vector& msgVec, int timeout = -1) override + private: + // TODO In C++20 we should use a std::ranges::sized_range + auto MakeMetaMsg(std::vector& mqMsgs) noexcept { - auto const flags = (timeout == 0) ? ZMQ_DONTWAIT : 0; - auto elapsed = 0; + auto const n = mqMsgs.size(); - // put it into zmq message - const unsigned int vecSize = msgVec.size(); - ZMsg zmqMsg(vecSize * sizeof(MetaHeader)); +#if FAIRMQ_HAS_STD_PMR + // TODO padded case +#endif // FAIRMQ_HAS_STD_PMR - // prepare the message with shm metas - MetaHeader* metas = static_cast(zmqMsg.Data()); + std::unique_ptr buffer(new MetaHeader[n]); + auto metas = buffer.get(); + for (auto& mqMsg : mqMsgs) { + auto const& shmMsg = *static_cast(mqMsg.get()); // NOLINT + // copy the shmMsg::fMeta field into the zmq data buffer + std::memcpy(metas++, &(shmMsg.fMeta), sizeof(MetaHeader)); // NOLINT + } + return ZMsg( + buffer.release(), + sizeof(MetaHeader) * n, + [](void* data, void*) { + delete[] static_cast(data); + }); + } - for (auto& msg : msgVec) { + public: + int64_t Send(std::vector& msgVec, int timeout = -1) override + { + for (auto const& msg : msgVec) { auto msgPtr = msg.get(); if (!msgPtr) { return static_cast(TransferCode::error); } assertm(dynamic_cast(msgPtr), "given mq::Message is a shmem::Message"); // NOLINT - auto shmMsg = static_cast(msgPtr); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast) - std::memcpy(metas++, &(shmMsg->fMeta), sizeof(MetaHeader)); - } - - while (true) { - int64_t totalSize = 0; - int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags); - if (nbytes > 0) { - assert(static_cast(nbytes) == (vecSize * sizeof(MetaHeader))); // all or nothing - - for (auto& msg : msgVec) { - Message* shmMsg = static_cast(msg.get()); - shmMsg->fQueued = true; - totalSize += shmMsg->fMeta.fSize; - } - - // store statistics on how many messages have been sent - fMessagesTx++; - fBytesTx += totalSize; - - return totalSize; - } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) { - if (fManager.Interrupted()) { - return static_cast(TransferCode::interrupted); - } else if (zmq::ShouldRetry(flags, fTimeout, timeout, elapsed)) { - continue; - } else { - return static_cast(TransferCode::timeout); - } - } else { - return zmq::HandleErrors(fId); - } } - return static_cast(TransferCode::error); + auto metaMsg = MakeMetaMsg(msgVec); + return static_cast(SendMeta( + metaMsg, + timeout, + [this, &msgVec](std::size_t nbytes) { + auto const size = msgVec.size(); + int64_t totalSize = 0; + assertm(static_cast(nbytes) == (size * sizeof(MetaHeader)), "all or nothing"); // NOLINT + + for (auto& msg : msgVec) { + auto& shmMsg = *static_cast(msg.get()); // NOLINT + shmMsg.fQueued = true; + totalSize += static_cast(shmMsg.fMeta.fSize); + } + + // store statistics on how many messages have been sent + fMessagesTx++; + fBytesTx += totalSize; + + return totalSize; + })); } int64_t Receive(std::vector& msgVec, int timeout = -1) override