From fda8126a43e2294d18d4d6542d7d2217463090dd Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Tue, 12 Oct 2021 18:14:33 +0200 Subject: [PATCH] feat: Add new GetNumberOfConnectedPeers() API --- fairmq/Channel.h | 8 +++ fairmq/Socket.h | 2 + fairmq/ofi/Socket.h | 5 ++ fairmq/shmem/Socket.h | 29 +++++++--- fairmq/zeromq/Common.h | 115 ++++++++++++++++++++++++++++++++++++++ fairmq/zeromq/Socket.h | 30 +++++++--- test/channel/_channel.cxx | 43 ++++++++++++++ 7 files changed, 216 insertions(+), 16 deletions(-) diff --git a/fairmq/Channel.h b/fairmq/Channel.h index 04100ec1a..272e7b7a1 100644 --- a/fairmq/Channel.h +++ b/fairmq/Channel.h @@ -186,6 +186,14 @@ class Channel /// @return true/false, true if automatic binding is enabled bool GetAutoBind() const { return fAutoBind; } + /// @par Thread Safety + /// * @e Distinct @e objects: Safe.@n + /// * @e Shared @e objects: Unsafe. + auto GetNumberOfConnectedPeers() const + { + return fSocket ? fSocket->GetNumberOfConnectedPeers() : 0; + } + /// Set channel name /// @param name Arbitrary channel name void UpdateName(const std::string& name) { fName = name; Invalidate(); } diff --git a/fairmq/Socket.h b/fairmq/Socket.h index 3f1e38aba..5c529558f 100644 --- a/fairmq/Socket.h +++ b/fairmq/Socket.h @@ -72,6 +72,8 @@ struct Socket virtual unsigned long GetMessagesTx() const = 0; virtual unsigned long GetMessagesRx() const = 0; + virtual unsigned long GetNumberOfConnectedPeers() const = 0; + TransportFactory* GetTransport() { return fTransport; } void SetTransport(TransportFactory* transport) { fTransport = transport; } diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index d10679f43..444d9a84d 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -74,6 +74,11 @@ class Socket final : public fair::mq::Socket auto GetMessagesTx() const -> unsigned long override { return fMessagesTx; } auto GetMessagesRx() const -> unsigned long override { return fMessagesRx; } + auto GetNumberOfConnectedPeers() const -> unsigned long override + { + throw SocketError("not yet implemented"); + } + static auto GetConstant(const std::string& constant) -> int; ~Socket() override; diff --git a/fairmq/shmem/Socket.h b/fairmq/shmem/Socket.h index ced656e1b..04a013d18 100644 --- a/fairmq/shmem/Socket.h +++ b/fairmq/shmem/Socket.h @@ -53,14 +53,16 @@ class Socket final : public fair::mq::Socket public: Socket(Manager& manager, const std::string& type, const std::string& name, const std::string& id, void* context, FairMQTransportFactory* fac = nullptr) : fair::mq::Socket(fac) - , fSocket(nullptr) , fManager(manager) , fId(id + "." + name + "." + type) + , fSocket(nullptr) + , fMonitorSocket(nullptr) , fBytesTx(0) , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) , fTimeout(100) + , fConnectedPeersCount(0) { assert(context); @@ -70,6 +72,7 @@ class Socket final : public fair::mq::Socket } fSocket = zmq_socket(context, zmq::getConstant(type)); + fMonitorSocket = zmq::makeMonitorSocket(context, fSocket, fId); if (fSocket == nullptr) { LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno); @@ -349,15 +352,17 @@ class Socket final : public fair::mq::Socket { // LOG(debug) << "Closing socket " << fId; - if (fSocket == nullptr) { - return; + if (fSocket && zmq_close(fSocket) != 0) { + LOG(error) << "Failed closing data socket " << fId + << ", reason: " << zmq_strerror(errno); } + fSocket = nullptr; - if (zmq_close(fSocket) != 0) { - LOG(error) << "Failed closing socket " << fId << ", reason: " << zmq_strerror(errno); + if (fMonitorSocket && zmq_close(fMonitorSocket) != 0) { + LOG(error) << "Failed closing monitor socket " << fId + << ", reason: " << zmq_strerror(errno); } - - fSocket = nullptr; + fMonitorSocket = nullptr; } void SetOption(const std::string& option, const void* value, size_t valueSize) override @@ -465,6 +470,12 @@ class Socket final : public fair::mq::Socket return value; } + unsigned long GetNumberOfConnectedPeers() const override + { + fConnectedPeersCount = zmq::updateNumberOfConnectedPeers(fConnectedPeersCount, fMonitorSocket); + return fConnectedPeersCount; + } + unsigned long GetBytesTx() const override { return fBytesTx; } unsigned long GetBytesRx() const override { return fBytesRx; } unsigned long GetMessagesTx() const override { return fMessagesTx; } @@ -476,15 +487,17 @@ class Socket final : public fair::mq::Socket ~Socket() override { Close(); } private: - void* fSocket; Manager& fManager; std::string fId; + void* fSocket; + void* fMonitorSocket; std::atomic fBytesTx; std::atomic fBytesRx; std::atomic fMessagesTx; std::atomic fMessagesRx; int fTimeout; + mutable unsigned long fConnectedPeersCount; }; } // namespace fair::mq::shmem diff --git a/fairmq/zeromq/Common.h b/fairmq/zeromq/Common.h index 912a716b4..c433d0fdc 100644 --- a/fairmq/zeromq/Common.h +++ b/fairmq/zeromq/Common.h @@ -8,6 +8,8 @@ #ifndef FAIR_MQ_ZMQ_COMMON_H #define FAIR_MQ_ZMQ_COMMON_H +#include +#include #include #include #include @@ -53,6 +55,119 @@ inline auto getConstant(std::string_view constant) -> int throw Error(tools::ToString("getConstant called with an invalid argument: ", constant)); } +/// Create a zmq event monitor socket pair, and configure/connect the reading socket +/// @return reading monitor socket +inline auto makeMonitorSocket(void* zmqCtx, void* socketToMonitor, std::string_view id) -> void* +{ + assertm(zmqCtx, "Given zmq context exists"); // NOLINT + + if (!socketToMonitor) { // nothing to do in this case + return nullptr; + } + + auto const address(tools::ToString("inproc://", id)); + { // Create writing monitor socket on socket to be monitored and subscribe + // to all relevant events needed to compute connected peers + // from http://api.zeromq.org/master:zmq-socket-monitor: + // + // ZMQ_EVENT_CONNECTED - The socket has successfully connected to a remote peer. The event + // value is the file descriptor (FD) of the underlying network socket. Warning: there is + // no guarantee that the FD is still valid by the time your code receives this event. + // ZMQ_EVENT_ACCEPTED - The socket has accepted a connection from a remote peer. The event + // value is the FD of the underlying network socket. Warning: there is no guarantee that + // the FD is still valid by the time your code receives this event. + // ZMQ_EVENT_DISCONNECTED - The socket was disconnected unexpectedly. The event value is the + // FD of the underlying network socket. Warning: this socket will be closed. + auto const rc = + zmq_socket_monitor(socketToMonitor, + address.c_str(), + ZMQ_EVENT_CONNECTED | ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_DISCONNECTED); + assertm(rc == 0, "Creating writing monitor socket succeeded"); // NOLINT + } + // Create reading monitor socket + auto mon(zmq_socket(zmqCtx, ZMQ_PAIR)); + assertm(mon, "Creating reading monitor socker succeeded"); // NOLINT + { // Set receive queue size to unlimited on reading socket + // Rationale: In the current implementation this is needed for correctness, because + // we do not have any thread that emptys the receive queue regularly. + // Progress only happens, when a user calls GetNumberOfConnectedPeers()`. + // The assumption here is, that not too many events will pile up anyways. + int const unlimited(0); + auto const rc = zmq_setsockopt(mon, ZMQ_RCVHWM, &unlimited, sizeof(unlimited)); + assertm(rc == 0, "Setting rcv queue size to unlimited succeeded"); // NOLINT + } + { // Connect the reading monitor socket + auto const rc = zmq_connect(mon, address.c_str()); + assertm(rc == 0, "Connecting reading monitor socket succeeded"); // NOLINT + } + return mon; +} + +/// Read pending zmq monitor event in a non-blocking fashion. +/// @return event id or -1 for no event pending +inline auto getMonitorEvent(void* monitorSocket) -> int +{ + assertm(monitorSocket, "zmq monitor socket exists"); // NOLINT + + // First frame in message contains event id + zmq_msg_t msg; + zmq_msg_init(&msg); + { + auto const size = zmq_msg_recv(&msg, monitorSocket, ZMQ_DONTWAIT); + if (size == -1) { + return -1; // no event pending + } + assertm(size >= 2, "At least two bytes were received"); // NOLINT + } + + // Unpack event id + auto const event = *static_cast(zmq_msg_data(&msg)); + + // No unpacking of the event value needed for now + + // Second frame in message contains event address + assertm(zmq_msg_more(&msg), "A second frame is pending"); // NOLINT + zmq_msg_init(&msg); + { + auto const rc = zmq_msg_recv(&msg, monitorSocket, 0); + assertm(rc >= 0, "second monitor event frame successfully received"); // NOLINT + } + assertm(!zmq_msg_more(&msg), "No more frames are pending"); // NOLINT + // No unpacking of the event address needed for now + + return event; +} + +/// Compute updated connected peers count by consuming pending events from a zmq monitor socket +/// @return updated connected peers count +inline auto updateNumberOfConnectedPeers(unsigned long count, void* monitorSocket) -> unsigned long +{ + if (monitorSocket == nullptr) { + return count; + } + + int event = getMonitorEvent(monitorSocket); + while (event >= 0) { + switch (event) { + case ZMQ_EVENT_CONNECTED: + case ZMQ_EVENT_ACCEPTED: + ++count; + break; + case ZMQ_EVENT_DISCONNECTED: + if (count > 0) { + --count; + } else { + LOG(warn) << "Computing connected peers would result in negative count! Some event was missed!"; + } + break; + default: + break; + } + event = getMonitorEvent(monitorSocket); + } + return count; +} + } // namespace fair::mq::zmq #endif /* FAIR_MQ_ZMQ_COMMON_H */ diff --git a/fairmq/zeromq/Socket.h b/fairmq/zeromq/Socket.h index c5a139430..ed6465e90 100644 --- a/fairmq/zeromq/Socket.h +++ b/fairmq/zeromq/Socket.h @@ -20,7 +20,9 @@ #include #include +#include #include // unique_ptr, make_unique +#include namespace fair::mq::zmq { @@ -31,13 +33,15 @@ class Socket final : public fair::mq::Socket Socket(Context& ctx, const std::string& type, const std::string& name, const std::string& id, FairMQTransportFactory* factory = nullptr) : fair::mq::Socket(factory) , fCtx(ctx) - , fSocket(zmq_socket(fCtx.GetZmqCtx(), getConstant(type))) , fId(id + "." + name + "." + type) + , fSocket(zmq_socket(fCtx.GetZmqCtx(), getConstant(type))) + , fMonitorSocket(makeMonitorSocket(fCtx.GetZmqCtx(), fSocket, fId)) , fBytesTx(0) , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) , fTimeout(100) + , fConnectedPeersCount(0) { if (fSocket == nullptr) { LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno); @@ -301,15 +305,17 @@ class Socket final : public fair::mq::Socket { // LOG(debug) << "Closing socket " << fId; - if (fSocket == nullptr) { - return; + if (fSocket && zmq_close(fSocket) != 0) { + LOG(error) << "Failed closing data socket " << fId + << ", reason: " << zmq_strerror(errno); } + fSocket = nullptr; - if (zmq_close(fSocket) != 0) { - LOG(error) << "Failed closing socket " << fId << ", reason: " << zmq_strerror(errno); + if (fMonitorSocket && zmq_close(fMonitorSocket) != 0) { + LOG(error) << "Failed closing monitor socket " << fId + << ", reason: " << zmq_strerror(errno); } - - fSocket = nullptr; + fMonitorSocket = nullptr; } void SetOption(const std::string& option, const void* value, size_t valueSize) override @@ -417,6 +423,12 @@ class Socket final : public fair::mq::Socket return value; } + unsigned long GetNumberOfConnectedPeers() const override + { + fConnectedPeersCount = updateNumberOfConnectedPeers(fConnectedPeersCount, fMonitorSocket); + return fConnectedPeersCount; + } + unsigned long GetBytesTx() const override { return fBytesTx; } unsigned long GetBytesRx() const override { return fBytesRx; } unsigned long GetMessagesTx() const override { return fMessagesTx; } @@ -429,14 +441,16 @@ class Socket final : public fair::mq::Socket private: Context& fCtx; - void* fSocket; std::string fId; + void* fSocket; + void* fMonitorSocket; std::atomic fBytesTx; std::atomic fBytesRx; std::atomic fMessagesTx; std::atomic fMessagesRx; int fTimeout; + mutable unsigned long fConnectedPeersCount; }; } // namespace fair::mq::zmq diff --git a/test/channel/_channel.cxx b/test/channel/_channel.cxx index b953142aa..7ab2323c8 100644 --- a/test/channel/_channel.cxx +++ b/test/channel/_channel.cxx @@ -6,9 +6,14 @@ * copied verbatim in the file "LICENSE" * ********************************************************************************/ +#include #include +#include +#include +#include #include #include +#include namespace { @@ -91,4 +96,42 @@ TEST(Channel, Validation) ASSERT_EQ(channel2.Validate(), true); } +auto testConnectedPeers(std::string const& transport) +{ + using namespace std::chrono_literals; + + ProgOptions config; + config.SetProperty("session", tools::Uuid()); + string const address(tools::ToString("ipc://", config.GetProperty("session"))); + unsigned long constexpr zero(0), one(1); + auto factory(TransportFactory::CreateTransportFactory(transport, tools::Uuid(), &config)); + + Channel ch1("ch1", "pair", factory); + ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), zero); + ch1.Bind(address); + ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), zero); + + { + Channel ch2("ch2", "pair", factory); + ASSERT_EQ(ch2.GetNumberOfConnectedPeers(), zero); + ch2.Connect(address); + std::this_thread::sleep_for(10ms); + ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), one); + ASSERT_EQ(ch2.GetNumberOfConnectedPeers(), one); + } + + std::this_thread::sleep_for(10ms); + ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), zero); +} + +TEST(Channel, GetNumberOfConnectedPeers_zeromq) +{ + testConnectedPeers("zeromq"); +} + +TEST(Channel, GetNumberOfConnectedPeers_shmem) +{ + testConnectedPeers("shmem"); +} + } /* namespace */