Skip to content

Commit

Permalink
feat: Add new GetNumberOfConnectedPeers() API
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisklein authored and rbx committed Oct 19, 2021
1 parent 8796ce5 commit fda8126
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 16 deletions.
8 changes: 8 additions & 0 deletions fairmq/Channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ class Channel
/// @return true/false, true if automatic binding is enabled
bool GetAutoBind() const { return fAutoBind; }

/// Query the number of peers connected on this channel's socket.
/// @return the count reported by the underlying socket, or 0 if the channel has no socket
/// @par Thread Safety
/// * @e Distinct @e objects: Safe.@n
/// * @e Shared @e objects: Unsafe.
auto GetNumberOfConnectedPeers() const
{
    // A channel without a socket cannot have any peers
    if (!fSocket) {
        return 0UL;
    }
    return fSocket->GetNumberOfConnectedPeers();
}

/// Set channel name
/// @param name Arbitrary channel name
void UpdateName(const std::string& name) { fName = name; Invalidate(); }
Expand Down
2 changes: 2 additions & 0 deletions fairmq/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ struct Socket
virtual unsigned long GetMessagesTx() const = 0;
virtual unsigned long GetMessagesRx() const = 0;

/// Number of peers currently connected to this socket.
/// How the value is obtained is up to the transport implementation; an implementation
/// that does not support the query may throw instead of returning a value.
virtual unsigned long GetNumberOfConnectedPeers() const = 0;

TransportFactory* GetTransport() { return fTransport; }
void SetTransport(TransportFactory* transport) { fTransport = transport; }

Expand Down
5 changes: 5 additions & 0 deletions fairmq/ofi/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class Socket final : public fair::mq::Socket
auto GetMessagesTx() const -> unsigned long override { return fMessagesTx; }
auto GetMessagesRx() const -> unsigned long override { return fMessagesRx; }

/// Connected-peers query for this transport.
/// Not supported yet: every call throws instead of returning a count.
/// @throws SocketError always
auto GetNumberOfConnectedPeers() const -> unsigned long override
{
throw SocketError("not yet implemented");
}

static auto GetConstant(const std::string& constant) -> int;

~Socket() override;
Expand Down
29 changes: 21 additions & 8 deletions fairmq/shmem/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ class Socket final : public fair::mq::Socket
public:
Socket(Manager& manager, const std::string& type, const std::string& name, const std::string& id, void* context, FairMQTransportFactory* fac = nullptr)
: fair::mq::Socket(fac)
, fSocket(nullptr)
, fManager(manager)
, fId(id + "." + name + "." + type)
, fSocket(nullptr)
, fMonitorSocket(nullptr)
, fBytesTx(0)
, fBytesRx(0)
, fMessagesTx(0)
, fMessagesRx(0)
, fTimeout(100)
, fConnectedPeersCount(0)
{
assert(context);

Expand All @@ -70,6 +72,7 @@ class Socket final : public fair::mq::Socket
}

fSocket = zmq_socket(context, zmq::getConstant(type));
fMonitorSocket = zmq::makeMonitorSocket(context, fSocket, fId);

if (fSocket == nullptr) {
LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno);
Expand Down Expand Up @@ -349,15 +352,17 @@ class Socket final : public fair::mq::Socket
{
// LOG(debug) << "Closing socket " << fId;

if (fSocket == nullptr) {
return;
if (fSocket && zmq_close(fSocket) != 0) {
LOG(error) << "Failed closing data socket " << fId
<< ", reason: " << zmq_strerror(errno);
}
fSocket = nullptr;

if (zmq_close(fSocket) != 0) {
LOG(error) << "Failed closing socket " << fId << ", reason: " << zmq_strerror(errno);
if (fMonitorSocket && zmq_close(fMonitorSocket) != 0) {
LOG(error) << "Failed closing monitor socket " << fId
<< ", reason: " << zmq_strerror(errno);
}

fSocket = nullptr;
fMonitorSocket = nullptr;
}

void SetOption(const std::string& option, const void* value, size_t valueSize) override
Expand Down Expand Up @@ -465,6 +470,12 @@ class Socket final : public fair::mq::Socket
return value;
}

/// Refresh and return the connected peers count.
/// Consumes all monitor events pending on fMonitorSocket and stores the refreshed
/// value in the (mutable) cached counter.
/// @return updated number of connected peers
unsigned long GetNumberOfConnectedPeers() const override
{
    auto const refreshed = zmq::updateNumberOfConnectedPeers(fConnectedPeersCount, fMonitorSocket);
    fConnectedPeersCount = refreshed;
    return refreshed;
}

unsigned long GetBytesTx() const override { return fBytesTx; }
unsigned long GetBytesRx() const override { return fBytesRx; }
unsigned long GetMessagesTx() const override { return fMessagesTx; }
Expand All @@ -476,15 +487,17 @@ class Socket final : public fair::mq::Socket
~Socket() override { Close(); }

private:
void* fSocket;
Manager& fManager;
std::string fId;
void* fSocket;
void* fMonitorSocket;
std::atomic<unsigned long> fBytesTx;
std::atomic<unsigned long> fBytesRx;
std::atomic<unsigned long> fMessagesTx;
std::atomic<unsigned long> fMessagesRx;

int fTimeout;
mutable unsigned long fConnectedPeersCount;
};

} // namespace fair::mq::shmem
Expand Down
115 changes: 115 additions & 0 deletions fairmq/zeromq/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#ifndef FAIR_MQ_ZMQ_COMMON_H
#define FAIR_MQ_ZMQ_COMMON_H

#include <fairlogger/Logger.h>
#include <fairmq/Error.h>
#include <fairmq/tools/Strings.h>
#include <stdexcept>
#include <string_view>
Expand Down Expand Up @@ -53,6 +55,119 @@ inline auto getConstant(std::string_view constant) -> int
throw Error(tools::ToString("getConstant called with an invalid argument: ", constant));
}

/// Create a zmq event monitor socket pair, and configure/connect the reading socket.
/// The writing end is registered on @a socketToMonitor via an inproc endpoint derived
/// from @a id; the reading end is a new ZMQ_PAIR socket connected to that endpoint.
/// @param zmqCtx zmq context in which the reading socket is created
/// @param socketToMonitor socket whose connection events shall be observed (may be nullptr)
/// @param id identifier used to build the inproc endpoint address
/// @return reading monitor socket, or nullptr if @a socketToMonitor is nullptr
inline auto makeMonitorSocket(void* zmqCtx, void* socketToMonitor, std::string_view id) -> void*
{
    assertm(zmqCtx, "Given zmq context exists"); // NOLINT

    // Without a socket to observe there is nothing to set up
    if (socketToMonitor == nullptr) {
        return nullptr;
    }

    auto const endpoint = tools::ToString("inproc://", id);

    // Events needed to compute the number of connected peers
    // (see http://api.zeromq.org/master:zmq-socket-monitor):
    //   ZMQ_EVENT_CONNECTED    - the socket has successfully connected to a remote peer
    //   ZMQ_EVENT_ACCEPTED     - the socket has accepted a connection from a remote peer
    //   ZMQ_EVENT_DISCONNECTED - the socket was disconnected unexpectedly
    // The event value is the FD of the underlying network socket in all three cases;
    // there is no guarantee that it is still valid when the event is received.
    int const observedEvents = ZMQ_EVENT_CONNECTED | ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_DISCONNECTED;

    // Create the writing monitor socket on the socket to be monitored
    auto const monitorRc = zmq_socket_monitor(socketToMonitor, endpoint.c_str(), observedEvents);
    assertm(monitorRc == 0, "Creating writing monitor socket succeeded"); // NOLINT

    // Create the reading monitor socket
    void* const reader = zmq_socket(zmqCtx, ZMQ_PAIR);
    assertm(reader, "Creating reading monitor socker succeeded"); // NOLINT

    // Set the receive queue size of the reading socket to unlimited.
    // Rationale: in the current implementation this is needed for correctness, because
    //            no thread empties the receive queue regularly. Progress only happens
    //            when a user calls GetNumberOfConnectedPeers(). The assumption is that
    //            not too many events will pile up anyway.
    int const noLimit = 0;
    auto const hwmRc = zmq_setsockopt(reader, ZMQ_RCVHWM, &noLimit, sizeof(noLimit));
    assertm(hwmRc == 0, "Setting rcv queue size to unlimited succeeded"); // NOLINT

    // Connect the reading monitor socket
    auto const connectRc = zmq_connect(reader, endpoint.c_str());
    assertm(connectRc == 0, "Connecting reading monitor socket succeeded"); // NOLINT

    return reader;
}

/// Read pending zmq monitor event in a non-blocking fashion.
/// @return event id or -1 for no event pending
inline auto getMonitorEvent(void* monitorSocket) -> int
{
assertm(monitorSocket, "zmq monitor socket exists"); // NOLINT

// First frame in message contains event id
zmq_msg_t msg;
zmq_msg_init(&msg);
{
auto const size = zmq_msg_recv(&msg, monitorSocket, ZMQ_DONTWAIT);
if (size == -1) {
return -1; // no event pending
}
assertm(size >= 2, "At least two bytes were received"); // NOLINT
}

// Unpack event id
auto const event = *static_cast<uint16_t*>(zmq_msg_data(&msg));

// No unpacking of the event value needed for now

// Second frame in message contains event address
assertm(zmq_msg_more(&msg), "A second frame is pending"); // NOLINT
zmq_msg_init(&msg);
{
auto const rc = zmq_msg_recv(&msg, monitorSocket, 0);
assertm(rc >= 0, "second monitor event frame successfully received"); // NOLINT
}
assertm(!zmq_msg_more(&msg), "No more frames are pending"); // NOLINT
// No unpacking of the event address needed for now

return event;
}

/// Compute the updated connected peers count by draining all pending events
/// from a zmq monitor socket.
/// @param count connected peers count known so far
/// @param monitorSocket reading monitor socket; if nullptr, @a count is returned unchanged
/// @return updated connected peers count
inline auto updateNumberOfConnectedPeers(unsigned long count, void* monitorSocket) -> unsigned long
{
    if (monitorSocket == nullptr) {
        return count;
    }

    // getMonitorEvent() yields a negative value once no more events are pending
    for (int event = getMonitorEvent(monitorSocket); event >= 0; event = getMonitorEvent(monitorSocket)) {
        if (event == ZMQ_EVENT_CONNECTED || event == ZMQ_EVENT_ACCEPTED) {
            ++count;
        } else if (event == ZMQ_EVENT_DISCONNECTED) {
            if (count == 0) {
                // never let the unsigned counter wrap around
                LOG(warn) << "Computing connected peers would result in negative count! Some event was missed!";
            } else {
                --count;
            }
        }
        // all other events do not influence the count
    }

    return count;
}

} // namespace fair::mq::zmq

#endif /* FAIR_MQ_ZMQ_COMMON_H */
30 changes: 22 additions & 8 deletions fairmq/zeromq/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include <zmq.h>

#include <atomic>
#include <functional>
#include <memory> // unique_ptr, make_unique
#include <string_view>

namespace fair::mq::zmq
{
Expand All @@ -31,13 +33,15 @@ class Socket final : public fair::mq::Socket
Socket(Context& ctx, const std::string& type, const std::string& name, const std::string& id, FairMQTransportFactory* factory = nullptr)
: fair::mq::Socket(factory)
, fCtx(ctx)
, fSocket(zmq_socket(fCtx.GetZmqCtx(), getConstant(type)))
, fId(id + "." + name + "." + type)
, fSocket(zmq_socket(fCtx.GetZmqCtx(), getConstant(type)))
, fMonitorSocket(makeMonitorSocket(fCtx.GetZmqCtx(), fSocket, fId))
, fBytesTx(0)
, fBytesRx(0)
, fMessagesTx(0)
, fMessagesRx(0)
, fTimeout(100)
, fConnectedPeersCount(0)
{
if (fSocket == nullptr) {
LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno);
Expand Down Expand Up @@ -301,15 +305,17 @@ class Socket final : public fair::mq::Socket
{
// LOG(debug) << "Closing socket " << fId;

if (fSocket == nullptr) {
return;
if (fSocket && zmq_close(fSocket) != 0) {
LOG(error) << "Failed closing data socket " << fId
<< ", reason: " << zmq_strerror(errno);
}
fSocket = nullptr;

if (zmq_close(fSocket) != 0) {
LOG(error) << "Failed closing socket " << fId << ", reason: " << zmq_strerror(errno);
if (fMonitorSocket && zmq_close(fMonitorSocket) != 0) {
LOG(error) << "Failed closing monitor socket " << fId
<< ", reason: " << zmq_strerror(errno);
}

fSocket = nullptr;
fMonitorSocket = nullptr;
}

void SetOption(const std::string& option, const void* value, size_t valueSize) override
Expand Down Expand Up @@ -417,6 +423,12 @@ class Socket final : public fair::mq::Socket
return value;
}

/// Refresh and return the connected peers count.
/// Consumes all monitor events pending on fMonitorSocket and stores the refreshed
/// value in the (mutable) cached counter.
/// @return updated number of connected peers
unsigned long GetNumberOfConnectedPeers() const override
{
    auto const refreshed = updateNumberOfConnectedPeers(fConnectedPeersCount, fMonitorSocket);
    fConnectedPeersCount = refreshed;
    return refreshed;
}

unsigned long GetBytesTx() const override { return fBytesTx; }
unsigned long GetBytesRx() const override { return fBytesRx; }
unsigned long GetMessagesTx() const override { return fMessagesTx; }
Expand All @@ -429,14 +441,16 @@ class Socket final : public fair::mq::Socket

private:
Context& fCtx;
void* fSocket;
std::string fId;
void* fSocket;
void* fMonitorSocket;
std::atomic<unsigned long> fBytesTx;
std::atomic<unsigned long> fBytesRx;
std::atomic<unsigned long> fMessagesTx;
std::atomic<unsigned long> fMessagesRx;

int fTimeout;
mutable unsigned long fConnectedPeersCount;
};

} // namespace fair::mq::zmq
Expand Down
43 changes: 43 additions & 0 deletions test/channel/_channel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
* copied verbatim in the file "LICENSE" *
********************************************************************************/

#include <chrono>
#include <fairmq/Channel.h>
#include <fairmq/ProgOptions.h>
#include <fairmq/Tools.h>
#include <fairmq/TransportFactory.h>
#include <gtest/gtest.h>
#include <string>
#include <thread>

namespace
{
Expand Down Expand Up @@ -91,4 +96,42 @@ TEST(Channel, Validation)
ASSERT_EQ(channel2.Validate(), true);
}

auto testConnectedPeers(std::string const& transport)
{
using namespace std::chrono_literals;

ProgOptions config;
config.SetProperty<string>("session", tools::Uuid());
string const address(tools::ToString("ipc://", config.GetProperty<string>("session")));
unsigned long constexpr zero(0), one(1);
auto factory(TransportFactory::CreateTransportFactory(transport, tools::Uuid(), &config));

Channel ch1("ch1", "pair", factory);
ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), zero);
ch1.Bind(address);
ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), zero);

{
Channel ch2("ch2", "pair", factory);
ASSERT_EQ(ch2.GetNumberOfConnectedPeers(), zero);
ch2.Connect(address);
std::this_thread::sleep_for(10ms);
ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), one);
ASSERT_EQ(ch2.GetNumberOfConnectedPeers(), one);
}

std::this_thread::sleep_for(10ms);
ASSERT_EQ(ch1.GetNumberOfConnectedPeers(), zero);
}

// Run the connected peers scenario with the "zeromq" transport
TEST(Channel, GetNumberOfConnectedPeers_zeromq)
{
testConnectedPeers("zeromq");
}

// Run the connected peers scenario with the "shmem" transport
TEST(Channel, GetNumberOfConnectedPeers_shmem)
{
testConnectedPeers("shmem");
}

} /* namespace */

0 comments on commit fda8126

Please sign in to comment.