diff --git a/fairmq/FairMQChannel.cxx b/fairmq/FairMQChannel.cxx index 42f87d899..a2666609e 100644 --- a/fairmq/FairMQChannel.cxx +++ b/fairmq/FairMQChannel.cxx @@ -18,90 +18,65 @@ #include // join/split #include -#include // std::move +#include using namespace std; mutex FairMQChannel::fChannelMutex; FairMQChannel::FairMQChannel() - : fSocket(nullptr) - , fType("unspecified") - , fMethod("unspecified") - , fAddress("unspecified") - , fTransportType(fair::mq::Transport::DEFAULT) - , fSndBufSize(1000) - , fRcvBufSize(1000) - , fSndKernelSize(0) - , fRcvKernelSize(0) - , fLinger(500) - , fRateLogging(1) - , fName("") - , fIsValid(false) - , fTransportFactory(nullptr) - , fMultipart(false) - , fModified(true) - , fReset(false) -{ -} + : FairMQChannel("", "unspecified", "unspecified", "unspecified", nullptr) +{} FairMQChannel::FairMQChannel(const string& type, const string& method, const string& address) - : fSocket(nullptr) - , fType(type) - , fMethod(method) - , fAddress(address) - , fTransportType(fair::mq::Transport::DEFAULT) - , fSndBufSize(1000) - , fRcvBufSize(1000) - , fSndKernelSize(0) - , fRcvKernelSize(0) - , fLinger(500) - , fRateLogging(1) - , fName("") - , fIsValid(false) - , fTransportFactory(nullptr) - , fMultipart(false) - , fModified(true) - , fReset(false) -{ -} + : FairMQChannel("", type, method, address, nullptr) +{} FairMQChannel::FairMQChannel(const string& name, const string& type, shared_ptr factory) - : fSocket(factory->CreateSocket(type, name)) + : FairMQChannel(name, type, "unspecified", "unspecified", factory) +{} + +FairMQChannel::FairMQChannel(const string& name, const string& type, const string& method, const string& address, shared_ptr factory) + : fTransportFactory(factory) + , fTransportType(factory ? factory->GetType() : fair::mq::Transport::DEFAULT) + , fSocket(factory ? factory->CreateSocket(type, name) : nullptr) , fType(type) - , fMethod("unspecified") - , fAddress("unspecified") - , fTransportType(factory->GetType()) + , fMethod(method) + , fAddress(address) , fSndBufSize(1000) , fRcvBufSize(1000) , fSndKernelSize(0) , fRcvKernelSize(0) , fLinger(500) , fRateLogging(1) + , fPortRangeMin(22000) + , fPortRangeMax(23000) + , fAutoBind(true) , fName(name) , fIsValid(false) - , fTransportFactory(factory) , fMultipart(false) , fModified(true) , fReset(false) -{ -} +{} FairMQChannel::FairMQChannel(const FairMQChannel& chan) - : fSocket(nullptr) + : fTransportFactory(nullptr) + , fTransportType(chan.fTransportType) + , fSocket(nullptr) , fType(chan.fType) , fMethod(chan.fMethod) , fAddress(chan.fAddress) - , fTransportType(chan.fTransportType) , fSndBufSize(chan.fSndBufSize) , fRcvBufSize(chan.fRcvBufSize) , fSndKernelSize(chan.fSndKernelSize) , fRcvKernelSize(chan.fRcvKernelSize) , fLinger(chan.fLinger) , fRateLogging(chan.fRateLogging) + , fPortRangeMin(chan.fPortRangeMin) + , fPortRangeMax(chan.fPortRangeMax) + , fAutoBind(chan.fAutoBind) , fName(chan.fName) , fIsValid(false) - , fTransportFactory(nullptr) , fMultipart(chan.fMultipart) , fModified(chan.fModified) , fReset(false) @@ -109,20 +84,23 @@ FairMQChannel::FairMQChannel(const FairMQChannel& chan) FairMQChannel& FairMQChannel::operator=(const FairMQChannel& chan) { + fTransportFactory = nullptr; + fTransportType = chan.fTransportType; fSocket = nullptr; fType = chan.fType; fMethod = chan.fMethod; fAddress = chan.fAddress; - fTransportType = chan.fTransportType; fSndBufSize = chan.fSndBufSize; fRcvBufSize = chan.fRcvBufSize; fSndKernelSize = chan.fSndKernelSize; fRcvKernelSize = chan.fRcvKernelSize; fLinger = chan.fLinger; fRateLogging = chan.fRateLogging; + fPortRangeMin = chan.fPortRangeMin; + fPortRangeMax = chan.fPortRangeMax; + fAutoBind = chan.fAutoBind; fName = chan.fName; fIsValid = false; - fTransportFactory = nullptr; fMultipart = chan.fMultipart; fModified = chan.fModified; fReset = false; @@ -136,20 +114,23 @@ FairMQSocket & FairMQChannel::GetSocket() const return *fSocket; } -string FairMQChannel::GetChannelName() const +string FairMQChannel::GetName() const { + lock_guard lock(fChannelMutex); return fName; } -string FairMQChannel::GetChannelPrefix() const +string FairMQChannel::GetPrefix() const { + lock_guard lock(fChannelMutex); string prefix = fName; prefix = prefix.erase(fName.rfind('[')); return prefix; } -string FairMQChannel::GetChannelIndex() const +string FairMQChannel::GetIndex() const { + lock_guard lock(fChannelMutex); string indexStr = fName; indexStr.erase(indexStr.rfind(']')); indexStr.erase(0, indexStr.rfind('[') + 1); @@ -246,6 +227,33 @@ try { throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); } +int FairMQChannel::GetPortRangeMin() const +try { + lock_guard lock(fChannelMutex); + return fPortRangeMin; +} catch (exception& e) { + LOG(error) << "Exception caught in FairMQChannel::GetPortRangeMin: " << e.what(); + throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); +} + +int FairMQChannel::GetPortRangeMax() const +try { + lock_guard lock(fChannelMutex); + return fPortRangeMax; +} catch (exception& e) { + LOG(error) << "Exception caught in FairMQChannel::GetPortRangeMax: " << e.what(); + throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); +} + +bool FairMQChannel::GetAutoBind() const +try { + lock_guard lock(fChannelMutex); + return fAutoBind; +} catch (exception& e) { + LOG(error) << "Exception caught in FairMQChannel::GetAutoBind: " << e.what(); + throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); +} + void FairMQChannel::UpdateType(const string& type) try { lock_guard lock(fChannelMutex); @@ -356,6 +364,39 @@ try { throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); } +void FairMQChannel::UpdatePortRangeMin(const int minPort) +try { + lock_guard lock(fChannelMutex); + fIsValid = false; + fPortRangeMin = minPort; + fModified = true; +} catch (exception& e) { + LOG(error) << "Exception caught in FairMQChannel::UpdatePortRangeMin: " << e.what(); + throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); +} + +void FairMQChannel::UpdatePortRangeMax(const int maxPort) +try { + lock_guard lock(fChannelMutex); + fIsValid = false; + fPortRangeMax = maxPort; + fModified = true; +} catch (exception& e) { + LOG(error) << "Exception caught in FairMQChannel::UpdatePortRangeMax: " << e.what(); + throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); +} + +void FairMQChannel::UpdateAutoBind(const bool autobind) +try { + lock_guard lock(fChannelMutex); + fIsValid = false; + fAutoBind = autobind; + fModified = true; +} catch (exception& e) { + LOG(error) << "Exception caught in FairMQChannel::UpdateAutoBind: " << e.what(); + throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); +} + auto FairMQChannel::SetModified(const bool modified) -> void try { lock_guard lock(fChannelMutex); @@ -365,14 +406,14 @@ try { throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); } -void FairMQChannel::UpdateChannelName(const string& name) +void FairMQChannel::UpdateName(const string& name) try { lock_guard lock(fChannelMutex); fIsValid = false; fName = name; fModified = true; } catch (exception& e) { - LOG(error) << "Exception caught in FairMQChannel::UpdateChannelName: " << e.what(); + LOG(error) << "Exception caught in FairMQChannel::UpdateName: " << e.what(); throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); } @@ -385,14 +426,13 @@ try { throw ChannelConfigurationError(fair::mq::tools::ToString("failed to acquire lock: ", e.what())); } -bool FairMQChannel::ValidateChannel() +bool FairMQChannel::Validate() try { lock_guard lock(fChannelMutex); stringstream ss; ss << "Validating channel '" << fName << "'... "; - if (fIsValid) - { + if (fIsValid) { ss << "ALREADY VALID"; LOG(debug) << ss.str(); return true; @@ -400,8 +440,7 @@ try { // validate socket type const set socketTypes{ "sub", "pub", "pull", "push", "req", "rep", "xsub", "xpub", "dealer", "router", "pair" }; - if (socketTypes.find(fType) == socketTypes.end()) - { + if (socketTypes.find(fType) == socketTypes.end()) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "Invalid channel type: '" << fType << "'"; @@ -409,30 +448,22 @@ try { } // validate socket address - if (fAddress == "unspecified" || fAddress == "") - { + if (fAddress == "unspecified" || fAddress == "") { ss << "INVALID"; LOG(debug) << ss.str(); LOG(debug) << "invalid channel address: '" << fAddress << "'"; return false; - } - else - { + } else { vector endpoints; boost::algorithm::split(endpoints, fAddress, boost::algorithm::is_any_of(";")); - for (const auto endpoint : endpoints) - { + for (const auto endpoint : endpoints) { string address; - if (endpoint[0] == '@' || endpoint[0] == '+' || endpoint[0] == '>') - { + if (endpoint[0] == '@' || endpoint[0] == '+' || endpoint[0] == '>') { address = endpoint.substr(1); - } - else - { + } else { // we don't have a method modifier, check if the default method is set const set socketMethods{ "bind", "connect" }; - if (socketMethods.find(fMethod) == socketMethods.end()) - { + if (socketMethods.find(fMethod) == socketMethods.end()) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "Invalid endpoint connection method: '" << fMethod << "' for " << endpoint; @@ -441,56 +472,43 @@ try { address = endpoint; } // check if address is a tcp or ipc address - if (address.compare(0, 6, "tcp://") == 0) - { + if (address.compare(0, 6, "tcp://") == 0) { // check if TCP address contains port delimiter string addressString = address.substr(6); - if (addressString.find(':') == string::npos) - { + if (addressString.find(':') == string::npos) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel address: '" << address << "' (missing port?)"; return false; } - } - else if (address.compare(0, 6, "ipc://") == 0) - { + } else if (address.compare(0, 6, "ipc://") == 0) { // check if IPC address is not empty string addressString = address.substr(6); - if (addressString == "") - { + if (addressString == "") { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel address: '" << address << "' (empty IPC address?)"; return false; } - } - else if (address.compare(0, 9, "inproc://") == 0) - { + } else if (address.compare(0, 9, "inproc://") == 0) { // check if IPC address is not empty string addressString = address.substr(9); - if (addressString == "") - { + if (addressString == "") { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel address: '" << address << "' (empty inproc address?)"; return false; } - } - else if (address.compare(0, 8, "verbs://") == 0) - { + } else if (address.compare(0, 8, "verbs://") == 0) { // check if IPC address is not empty string addressString = address.substr(8); - if (addressString == "") - { + if (addressString == "") { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel address: '" << address << "' (empty verbs address?)"; return false; } - } - else - { + } else { // if neither TCP or IPC is specified, return invalid ss << "INVALID"; LOG(debug) << ss.str(); @@ -501,8 +519,7 @@ try { } // validate socket buffer size for sending - if (fSndBufSize < 0) - { + if (fSndBufSize < 0) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel send buffer size (cannot be negative): '" << fSndBufSize << "'"; @@ -510,8 +527,7 @@ try { } // validate socket buffer size for receiving - if (fRcvBufSize < 0) - { + if (fRcvBufSize < 0) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel receive buffer size (cannot be negative): '" << fRcvBufSize << "'"; @@ -519,8 +535,7 @@ try { } // validate socket kernel transmit size for sending - if (fSndKernelSize < 0) - { + if (fSndKernelSize < 0) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel send kernel transmit size (cannot be negative): '" << fSndKernelSize << "'"; @@ -528,8 +543,7 @@ try { } // validate socket kernel transmit size for receiving - if (fRcvKernelSize < 0) - { + if (fRcvKernelSize < 0) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid channel receive kernel transmit size (cannot be negative): '" << fRcvKernelSize << "'"; @@ -537,8 +551,7 @@ try { } // validate socket rate logging interval - if (fRateLogging < 0) - { + if (fRateLogging < 0) { ss << "INVALID"; LOG(debug) << ss.str(); LOG(error) << "invalid socket rate logging interval (cannot be negative): '" << fRateLogging << "'"; @@ -554,143 +567,76 @@ try { throw ChannelConfigurationError(fair::mq::tools::ToString(e.what())); } -void FairMQChannel::InitTransport(shared_ptr factory) -{ - fTransportFactory = factory; - fTransportType = factory->GetType(); -} - -void FairMQChannel::ResetChannel() +void FairMQChannel::Init() { lock_guard lock(fChannelMutex); - fIsValid = false; - // TODO: implement channel resetting -} - -int FairMQChannel::Send(unique_ptr& msg, int sndTimeoutInMs) -{ - CheckSendCompatibility(msg); - return fSocket->Send(msg, sndTimeoutInMs); -} -int FairMQChannel::Receive(unique_ptr& msg, int rcvTimeoutInMs) -{ - CheckReceiveCompatibility(msg); - return fSocket->Receive(msg, rcvTimeoutInMs); -} + fSocket = fTransportFactory->CreateSocket(fType, fName); -int FairMQChannel::SendAsync(unique_ptr& msg) -{ - CheckSendCompatibility(msg); - return fSocket->Send(msg, 0); -} + // set linger duration (how long socket should wait for outstanding transfers before shutdown) + fSocket->SetLinger(fLinger); -int FairMQChannel::ReceiveAsync(unique_ptr& msg) -{ - CheckReceiveCompatibility(msg); - return fSocket->Receive(msg, 0); -} + // set high water marks + fSocket->SetSndBufSize(fSndBufSize); + fSocket->SetRcvBufSize(fRcvBufSize); -int64_t FairMQChannel::Send(vector>& msgVec, int sndTimeoutInMs) -{ - CheckSendCompatibility(msgVec); - return fSocket->Send(msgVec, sndTimeoutInMs); -} - -int64_t FairMQChannel::Receive(vector>& msgVec, int rcvTimeoutInMs) -{ - CheckReceiveCompatibility(msgVec); - return fSocket->Receive(msgVec, rcvTimeoutInMs); -} - -int64_t FairMQChannel::SendAsync(vector>& msgVec) -{ - CheckSendCompatibility(msgVec); - return fSocket->Send(msgVec, 0); -} - -int64_t FairMQChannel::ReceiveAsync(vector>& msgVec) -{ - CheckReceiveCompatibility(msgVec); - return fSocket->Receive(msgVec, 0); -} - -FairMQChannel::~FairMQChannel() -{ + // set kernel transmit size (set it only if value is not the default value) + if (fSndKernelSize != 0) { + fSocket->SetSndKernelSize(fSndKernelSize); + } + if (fRcvKernelSize != 0) { + fSocket->SetRcvKernelSize(fRcvKernelSize); + } } -unsigned long FairMQChannel::GetBytesTx() const +bool FairMQChannel::ConnectEndpoint(const string& endpoint) { - return fSocket->GetBytesTx(); -} + lock_guard lock(fChannelMutex); -unsigned long FairMQChannel::GetBytesRx() const -{ - return fSocket->GetBytesRx(); + return fSocket->Connect(endpoint); } -unsigned long FairMQChannel::GetMessagesTx() const +bool FairMQChannel::BindEndpoint(string& endpoint) { - return fSocket->GetMessagesTx(); -} + lock_guard lock(fChannelMutex); -unsigned long FairMQChannel::GetMessagesRx() const -{ - return fSocket->GetMessagesRx(); -} + // try to bind to the configured port. If it fails, try random one (if AutoBind is on). + if (fSocket->Bind(endpoint)) { + return true; + } else { + if (fAutoBind) { + // number of attempts when choosing a random port + int numAttempts = 0; + int maxAttempts = 1000; + + // initialize random generator + default_random_engine generator(chrono::system_clock::now().time_since_epoch().count()); + uniform_int_distribution randomPort(fPortRangeMin, fPortRangeMax); + + do { + LOG(debug) << "Could not bind to configured (TCP) port, trying random port in range " << fPortRangeMin << "-" << fPortRangeMax; + ++numAttempts; + + if (numAttempts > maxAttempts) { + LOG(error) << "could not bind to any (TCP) port in the given range after " << maxAttempts << " attempts"; + return false; + } -void FairMQChannel::CheckSendCompatibility(FairMQMessagePtr& msg) -{ - if (fTransportType != msg->GetType()) - { - // LOG(debug) << "Channel type does not match message type. Creating wrapper"; - FairMQMessagePtr msgWrapper(NewMessage(msg->GetData(), - msg->GetSize(), - [](void* /*data*/, void* _msg) { delete static_cast(_msg); }, - msg.get() - )); - msg.release(); - msg = move(msgWrapper); - } -} + size_t pos = endpoint.rfind(':'); + endpoint = endpoint.substr(0, pos + 1) + fair::mq::tools::ToString(static_cast(randomPort(generator))); + } while (fSocket->Bind(endpoint)); -void FairMQChannel::CheckSendCompatibility(vector& msgVec) -{ - for (auto& msg : msgVec) - { - if (fTransportType != msg->GetType()) - { - // LOG(debug) << "Channel type does not match message type. Creating wrapper"; - FairMQMessagePtr msgWrapper(NewMessage(msg->GetData(), - msg->GetSize(), - [](void* /*data*/, void* _msg) { delete static_cast(_msg); }, - msg.get() - )); - msg.release(); - msg = move(msgWrapper); + return true; + } else { + return false; } } -} -void FairMQChannel::CheckReceiveCompatibility(FairMQMessagePtr& msg) -{ - if (fTransportType != msg->GetType()) - { - // LOG(debug) << "Channel type does not match message type. Creating wrapper"; - FairMQMessagePtr newMsg(NewMessage()); - msg = move(newMsg); - } } -void FairMQChannel::CheckReceiveCompatibility(vector& msgVec) +void FairMQChannel::ResetChannel() { - for (auto& msg : msgVec) - { - if (fTransportType != msg->GetType()) - { - // LOG(debug) << "Channel type does not match message type. Creating wrapper"; - FairMQMessagePtr newMsg(NewMessage()); - msg = move(newMsg); - } - } + lock_guard lock(fChannelMutex); + fIsValid = false; + // TODO: implement channel resetting } diff --git a/fairmq/FairMQChannel.h b/fairmq/FairMQChannel.h index a6bf3a318..804dc5faf 100644 --- a/fairmq/FairMQChannel.h +++ b/fairmq/FairMQChannel.h @@ -10,11 +10,12 @@ #define FAIRMQCHANNEL_H_ #include -#include // unique_ptr +#include // unique_ptr, shared_ptr #include #include #include #include +#include // std::move #include #include @@ -43,6 +44,14 @@ class FairMQChannel /// @param factory TransportFactory FairMQChannel(const std::string& name, const std::string& type, std::shared_ptr factory); + /// Constructor + /// @param name Channel name + /// @param type Socket type (push/pull/pub/sub/spub/xsub/pair/req/rep/dealer/router/) + /// @param method Socket method (bind/connect) + /// @param address Network address to bind/connect to (e.g. "tcp://127.0.0.1:5555" or "ipc://abc") + /// @param factory TransportFactory + FairMQChannel(const std::string& name, const std::string& type, const std::string& method, const std::string& address, std::shared_ptr factory); + /// Copy Constructor FairMQChannel(const FairMQChannel&); @@ -50,20 +59,20 @@ class FairMQChannel FairMQChannel& operator=(const FairMQChannel&); /// Default destructor - virtual ~FairMQChannel(); + virtual ~FairMQChannel() {} struct ChannelConfigurationError : std::runtime_error { using std::runtime_error::runtime_error; }; FairMQSocket& GetSocket() const; - auto Bind(const std::string& address) -> bool + bool Bind(const std::string& address) { fMethod = "bind"; fAddress = address; return fSocket->Bind(address); } - auto Connect(const std::string& address) -> void + bool Connect(const std::string& address) { fMethod = "connect"; fAddress = address; @@ -72,15 +81,18 @@ class FairMQChannel /// Get channel name /// @return Returns full channel name (e.g. "data[0]") - std::string GetChannelName() const; + std::string GetChannelName() const { return GetName(); } // TODO: deprecate this in favor of following + std::string GetName() const; /// Get channel prefix /// @return Returns channel prefix (e.g. "data" in "data[0]") - std::string GetChannelPrefix() const; + std::string GetChannelPrefix() const { return GetPrefix(); } // TODO: deprecate this in favor of following + std::string GetPrefix() const; /// Get channel index /// @return Returns channel index (e.g. 0 in "data[0]") - std::string GetChannelIndex() const; + std::string GetChannelIndex() const { return GetPrefix(); } // TODO: deprecate this in favor of following + std::string GetIndex() const; /// Get socket type /// @return Returns socket type (push/pull/pub/sub/spub/xsub/pair/req/rep/dealer/router/) @@ -122,6 +134,18 @@ class FairMQChannel /// @return Returns socket rate logging interval (in seconds) int GetRateLogging() const; + /// Get start of the port range for automatic binding + /// @return start of the port range + int GetPortRangeMin() const; + + /// Get end of the port range for automatic binding + /// @return end of the port range + int GetPortRangeMax() const; + + /// Set automatic binding (pick random port if bind fails) + /// @return true/false, true if automatic binding is enabled + bool GetAutoBind() const; + /// Set socket type /// @param type Socket type (push/pull/pub/sub/spub/xsub/pair/req/rep/dealer/router/) void UpdateType(const std::string& type); @@ -162,9 +186,22 @@ class FairMQChannel /// @param rateLogging Socket rate logging interval (in seconds) void UpdateRateLogging(const int rateLogging); + /// Set start of the port range for automatic binding + /// @param minPort start of the port range + void UpdatePortRangeMin(const int minPort); + + /// Set end of the port range for automatic binding + /// @param maxPort end of the port range + void UpdatePortRangeMax(const int maxPort); + + /// Set automatic binding (pick random port if bind fails) + /// @param autobind true/false, true to enable automatic binding + void UpdateAutoBind(const bool autobind); + /// Set channel name /// @param name Arbitrary channel name - void UpdateChannelName(const std::string& name); + void UpdateChannelName(const std::string& name) { UpdateName(name); } // TODO: deprecate this in favor of following + void UpdateName(const std::string& name); /// Checks if the configured channel settings are valid (checks the validity parameter, without running full validation (as oposed to ValidateChannel())) /// @return true if channel settings are valid, false otherwise. @@ -172,7 +209,20 @@ class FairMQChannel /// Validates channel configuration /// @return true if channel settings are valid, false otherwise. - bool ValidateChannel(); + bool ValidateChannel() // TODO: deprecate this + { + return Validate(); + } + + /// Validates channel configuration + /// @return true if channel settings are valid, false otherwise. + bool Validate(); + + void Init(); + + bool ConnectEndpoint(const std::string& endpoint); + + bool BindEndpoint(std::string& endpoint); /// Resets the channel (requires validation to be used again). void ResetChannel(); @@ -181,31 +231,63 @@ class FairMQChannel /// @param msg Constant reference of unique_ptr to a FairMQMessage /// @param sndTimeoutInMs send timeout in ms. -1 will wait forever (or until interrupt (e.g. via state change)), 0 will not wait (return immediately if cannot send) /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. -1 if there was an error. - int Send(FairMQMessagePtr& msg, int sndTimeoutInMs = -1); + int Send(FairMQMessagePtr& msg, int sndTimeoutInMs = -1) + { + CheckSendCompatibility(msg); + return fSocket->Send(msg, sndTimeoutInMs); + } /// Receives a message from the socket queue. /// @param msg Constant reference of unique_ptr to a FairMQMessage /// @param rcvTimeoutInMs receive timeout in ms. -1 will wait forever (or until interrupt (e.g. via state change)), 0 will not wait (return immediately if cannot receive) /// @return Number of bytes that have been received. -2 if reading from the queue was not possible or timed out. -1 if there was an error. - int Receive(FairMQMessagePtr& msg, int rcvTimeoutInMs = -1); + int Receive(FairMQMessagePtr& msg, int rcvTimeoutInMs = -1) + { + CheckReceiveCompatibility(msg); + return fSocket->Receive(msg, rcvTimeoutInMs); + } - int SendAsync(FairMQMessagePtr& msg) __attribute__((deprecated("For non-blocking Send, use timeout version with timeout of 0: Send(msg, timeout);"))); - int ReceiveAsync(FairMQMessagePtr& msg) __attribute__((deprecated("For non-blocking Receive, use timeout version with timeout of 0: Receive(msg, timeout);"))); + int SendAsync(FairMQMessagePtr& msg) __attribute__((deprecated("For non-blocking Send, use timeout version with timeout of 0: Send(msg, timeout);"))) + { + CheckSendCompatibility(msg); + return fSocket->Send(msg, 0); + } + int ReceiveAsync(FairMQMessagePtr& msg) __attribute__((deprecated("For non-blocking Receive, use timeout version with timeout of 0: Receive(msg, timeout);"))) + { + CheckReceiveCompatibility(msg); + return fSocket->Receive(msg, 0); + } /// Send a vector of messages /// @param msgVec message vector reference /// @param sndTimeoutInMs send timeout in ms. -1 will wait forever (or until interrupt (e.g. via state change)), 0 will not wait (return immediately if cannot send) /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. -1 if there was an error. - int64_t Send(std::vector& msgVec, int sndTimeoutInMs = -1); + int64_t Send(std::vector& msgVec, int sndTimeoutInMs = -1) + { + CheckSendCompatibility(msgVec); + return fSocket->Send(msgVec, sndTimeoutInMs); + } /// Receive a vector of messages /// @param msgVec message vector reference /// @param rcvTimeoutInMs receive timeout in ms. -1 will wait forever (or until interrupt (e.g. via state change)), 0 will not wait (return immediately if cannot receive) /// @return Number of bytes that have been received. -2 if reading from the queue was not possible or timed out. -1 if there was an error. - int64_t Receive(std::vector& msgVec, int rcvTimeoutInMs = -1); + int64_t Receive(std::vector& msgVec, int rcvTimeoutInMs = -1) + { + CheckReceiveCompatibility(msgVec); + return fSocket->Receive(msgVec, rcvTimeoutInMs); + } - int64_t SendAsync(std::vector& msgVec) __attribute__((deprecated("For non-blocking Send, use timeout version with timeout of 0: Send(msgVec, timeout);"))); - int64_t ReceiveAsync(std::vector& msgVec) __attribute__((deprecated("For non-blocking Receive, use timeout version with timeout of 0: Receive(msgVec, timeout);"))); + int64_t SendAsync(std::vector& msgVec) __attribute__((deprecated("For non-blocking Send, use timeout version with timeout of 0: Send(msgVec, timeout);"))) + { + CheckSendCompatibility(msgVec); + return fSocket->Send(msgVec, 0); + } + int64_t ReceiveAsync(std::vector& msgVec) __attribute__((deprecated("For non-blocking Receive, use timeout version with timeout of 0: Receive(msgVec, timeout);"))) + { + CheckReceiveCompatibility(msgVec); + return fSocket->Receive(msgVec, 0); + } /// Send FairMQParts /// @param parts FairMQParts reference @@ -235,10 +317,10 @@ class FairMQChannel return Receive(parts.fParts, 0); } - unsigned long GetBytesTx() const; - unsigned long GetBytesRx() const; - unsigned long GetMessagesTx() const; - unsigned long GetMessagesRx() const; + unsigned long GetBytesTx() const { return fSocket->GetBytesTx(); } + unsigned long GetBytesRx() const { return fSocket->GetBytesRx(); } + unsigned long GetMessagesTx() const { return fSocket->GetMessagesTx(); } + unsigned long GetMessagesRx() const { return fSocket->GetMessagesRx(); } auto Transport() -> FairMQTransportFactory* { @@ -264,31 +346,26 @@ class FairMQChannel } private: + std::shared_ptr fTransportFactory; + fair::mq::Transport fTransportType; std::unique_ptr fSocket; std::string fType; std::string fMethod; std::string fAddress; - fair::mq::Transport fTransportType; int fSndBufSize; int fRcvBufSize; int fSndKernelSize; int fRcvKernelSize; int fLinger; int fRateLogging; + int fPortRangeMin; + int fPortRangeMax; + bool fAutoBind; std::string fName; std::atomic fIsValid; - std::shared_ptr fTransportFactory; - - void CheckSendCompatibility(FairMQMessagePtr& msg); - void CheckSendCompatibility(std::vector& msgVec); - void CheckReceiveCompatibility(FairMQMessagePtr& msg); - void CheckReceiveCompatibility(std::vector& msgVec); - - void InitTransport(std::shared_ptr factory); - // use static mutex to make the class easily copyable // implication: same mutex is used for all instances of the class // this does not hurt much, because mutex is used only during initialization with very low contention @@ -297,8 +374,66 @@ class FairMQChannel bool fMultipart; bool fModified; - auto SetModified(const bool modified) -> void; bool fReset; + + void CheckSendCompatibility(FairMQMessagePtr& msg) + { + if (fTransportType != msg->GetType()) { + // LOG(debug) << "Channel type does not match message type. Creating wrapper"; + FairMQMessagePtr msgWrapper(NewMessage( + msg->GetData(), + msg->GetSize(), + [](void* /*data*/, void* _msg) { delete static_cast(_msg); }, + msg.get() + )); + msg.release(); + msg = move(msgWrapper); + } + } + + void CheckSendCompatibility(std::vector& msgVec) + { + for (auto& msg : msgVec) { + if (fTransportType != msg->GetType()) { + // LOG(debug) << "Channel type does not match message type. Creating wrapper"; + FairMQMessagePtr msgWrapper(NewMessage( + msg->GetData(), + msg->GetSize(), + [](void* /*data*/, void* _msg) { delete static_cast(_msg); }, + msg.get() + )); + msg.release(); + msg = move(msgWrapper); + } + } + } + + void CheckReceiveCompatibility(FairMQMessagePtr& msg) + { + if (fTransportType != msg->GetType()) { + // LOG(debug) << "Channel type does not match message type. Creating wrapper"; + FairMQMessagePtr newMsg(NewMessage()); + msg = move(newMsg); + } + } + + void CheckReceiveCompatibility(std::vector& msgVec) + { + for (auto& msg : msgVec) { + if (fTransportType != msg->GetType()) { + // LOG(debug) << "Channel type does not match message type. Creating wrapper"; + FairMQMessagePtr newMsg(NewMessage()); + msg = move(newMsg); + } + } + } + + void InitTransport(std::shared_ptr factory) + { + fTransportFactory = factory; + fTransportType = factory->GetType(); + } + auto SetModified(const bool modified) -> void; }; #endif /* FAIRMQCHANNEL_H_ */ diff --git a/fairmq/FairMQDevice.cxx b/fairmq/FairMQDevice.cxx index 3e08f29bf..fbf8a6e20 100644 --- a/fairmq/FairMQDevice.cxx +++ b/fairmq/FairMQDevice.cxx @@ -16,7 +16,6 @@ #include #include -#include #include #include #include @@ -55,8 +54,6 @@ FairMQDevice::FairMQDevice(FairMQProgOptions* config, const fair::mq::tools::Ver , fInternalConfig(config ? nullptr : fair::mq::tools::make_unique()) , fConfig(config ? config : fInternalConfig.get()) , fId() - , fPortRangeMin(22000) - , fPortRangeMax(32000) , fDefaultTransportType(fair::mq::Transport::DEFAULT) , fDataCallbacks(false) , fMsgInputs() @@ -80,8 +77,6 @@ void FairMQDevice::InitWrapper() { fId = fConfig->GetValue("id"); fRate = fConfig->GetValue("rate"); - fPortRangeMin = fConfig->GetValue("port-range-min"); - fPortRangeMax = fConfig->GetValue("port-range-max"); try { fDefaultTransportType = fair::mq::TransportTypes.at(fConfig->GetValue("transport")); @@ -91,15 +86,11 @@ void FairMQDevice::InitWrapper() throw; } - for (auto& c : fConfig->GetFairMQMap()) - { - if (fChannels.find(c.first) == fChannels.end()) - { + for (auto& c : fConfig->GetFairMQMap()) { + if (fChannels.find(c.first) == fChannels.end()) { LOG(debug) << "Inserting new device channel from config: " << c.first; fChannels.insert(c); - } - else - { + } else { LOG(debug) << "Updating existing device channel from config: " << c.first; fChannels[c.first] = c.second; } @@ -115,50 +106,44 @@ void FairMQDevice::InitWrapper() string networkInterface = fConfig->GetValue("network-interface"); // Fill the uninitialized channel containers - for (auto& mi : fChannels) - { - for (auto vi = mi.second.begin(); vi != mi.second.end(); ++vi) - { - // if (vi->fModified) - // { - // if (vi->fReset) - // { - // vi->fSocket.reset(); - // } - // set channel name: name + vector index - vi->fName = fair::mq::tools::ToString(mi.first, "[", vi - (mi.second).begin(), "]"); + for (auto& mi : fChannels) { + int subChannelIndex = 0; + for (auto& vi : mi.second) { + // set channel name: name + vector index + vi.fName = fair::mq::tools::ToString(mi.first, "[", subChannelIndex, "]"); + + // set channel transport + if (vi.fTransportType == fair::mq::Transport::DEFAULT || vi.fTransportType == fTransportFactory->GetType()) { + LOG(debug) << vi.fName << ": using default transport"; + vi.InitTransport(fTransportFactory); + } else { + LOG(debug) << vi.fName << ": channel transport (" << fair::mq::TransportNames.at(fDefaultTransportType) << ") overriden to " << fair::mq::TransportNames.at(vi.fTransportType); + vi.InitTransport(AddTransport(vi.fTransportType)); + } - if (vi->fMethod == "bind") - { - // if binding address is not specified, try getting it from the configured network interface - if (vi->fAddress == "unspecified" || vi->fAddress == "") - { - // if the configured network interface is default, get its name from the default route - if (networkInterface == "default") - { - networkInterface = fair::mq::tools::getDefaultRouteNetworkInterface(); - } - vi->fAddress = "tcp://" + fair::mq::tools::getInterfaceIP(networkInterface) + ":1"; + if (vi.fMethod == "bind") { + // if binding address is not specified, try getting it from the configured network interface + if (vi.fAddress == "unspecified" || vi.fAddress == "") { + // if the configured network interface is default, get its name from the default route + if (networkInterface == "default") { + networkInterface = fair::mq::tools::getDefaultRouteNetworkInterface(); } - // fill the uninitialized list - uninitializedBindingChannels.push_back(&(*vi)); - } - else if (vi->fMethod == "connect") - { - // fill the uninitialized list - uninitializedConnectingChannels.push_back(&(*vi)); - } - else if (vi->fAddress.find_first_of("@+>") != string::npos) - { - // fill the uninitialized list - uninitializedConnectingChannels.push_back(&(*vi)); - } - else - { - LOG(error) << "Cannot update configuration. Socket method (bind/connect) for channel '" << vi->fName << "' not specified."; - throw runtime_error(fair::mq::tools::ToString("Cannot update configuration. Socket method (bind/connect) for channel ", vi->fName, " not specified.")); + vi.fAddress = "tcp://" + fair::mq::tools::getInterfaceIP(networkInterface) + ":1"; } - // } + // fill the uninitialized list + uninitializedBindingChannels.push_back(&vi); + } else if (vi.fMethod == "connect") { + // fill the uninitialized list + uninitializedConnectingChannels.push_back(&vi); + } else if (vi.fAddress.find_first_of("@+>") != string::npos) { + // fill the uninitialized list + uninitializedConnectingChannels.push_back(&vi); + } else { + LOG(error) << "Cannot update configuration. Socket method (bind/connect) for channel '" << vi.fName << "' not specified."; + throw runtime_error(fair::mq::tools::ToString("Cannot update configuration. Socket method (bind/connect) for channel ", vi.fName, " not specified.")); + } + + subChannelIndex++; } } @@ -166,8 +151,7 @@ void FairMQDevice::InitWrapper() // If necessary this could be handled in the same way as the connecting channels AttachChannels(uninitializedBindingChannels); - if (!uninitializedBindingChannels.empty()) - { + if (!uninitializedBindingChannels.empty()) { LOG(error) << uninitializedBindingChannels.size() << " of the binding channels could not initialize. Initial configuration incomplete."; throw runtime_error(fair::mq::tools::ToString(uninitializedBindingChannels.size(), " of the binding channels could not initialize. Initial configuration incomplete.")); } @@ -183,22 +167,18 @@ void FairMQDevice::InitWrapper() // first attempt AttachChannels(uninitializedConnectingChannels); // if not all channels could be connected, update their address values from config and retry - while (!uninitializedConnectingChannels.empty()) - { + while (!uninitializedConnectingChannels.empty()) { this_thread::sleep_for(chrono::milliseconds(sleepTimeInMS)); - for (auto& chan : uninitializedConnectingChannels) - { + for (auto& chan : uninitializedConnectingChannels) { string key{"chans." + chan->GetChannelPrefix() + "." + chan->GetChannelIndex() + ".address"}; string newAddress = fConfig->GetValue(key); - if (newAddress != chan->GetAddress()) - { + if (newAddress != chan->GetAddress()) { chan->UpdateAddress(newAddress); } } - if (numAttempts++ > maxAttempts) - { + if (numAttempts++ > maxAttempts) { LOG(error) << "could not connect all channels after " << initializationTimeoutInS << " attempts"; throw runtime_error(fair::mq::tools::ToString("could not connect all channels after ", initializationTimeoutInS, " attempts")); } @@ -223,106 +203,54 @@ void FairMQDevice::AttachChannels(vector& chans) { auto itr = chans.begin(); - while (itr != chans.end()) - { - if ((*itr)->ValidateChannel()) - { - if (AttachChannel(**itr)) - { + while (itr != chans.end()) { + if ((*itr)->ValidateChannel()) { + (*itr)->Init(); + if (AttachChannel(**itr)) { (*itr)->SetModified(false); + // remove the channel from the uninitialized container itr = chans.erase(itr); - } - else - { + } else { LOG(error) << "failed to attach channel " << (*itr)->fName << " (" << (*itr)->fMethod << ")"; ++itr; } - } - else - { + } else { ++itr; } } } -bool FairMQDevice::AttachChannel(FairMQChannel& ch) +bool FairMQDevice::AttachChannel(FairMQChannel& chan) { - if (ch.fTransportType == fair::mq::Transport::DEFAULT || ch.fTransportType == fTransportFactory->GetType()) - { - LOG(debug) << ch.fName << ": using default transport"; - ch.InitTransport(fTransportFactory); - } - else - { - LOG(debug) << ch.fName << ": channel transport (" << fair::mq::TransportNames.at(fDefaultTransportType) << ") overriden to " << fair::mq::TransportNames.at(ch.fTransportType); - ch.InitTransport(AddTransport(ch.fTransportType)); - } - vector endpoints; - boost::algorithm::split(endpoints, ch.fAddress, boost::algorithm::is_any_of(",")); - for (auto& endpoint : endpoints) - { - //(re-)init socket - if (!ch.fSocket) - { - try - { - ch.fSocket = ch.fTransportFactory->CreateSocket(ch.fType, ch.fName); - } - catch (fair::mq::SocketError& se) - { - LOG(error) << se.what(); - return false; - } - } - - // set linger duration (how long socket should wait for outstanding transfers before shutdown) - ch.fSocket->SetLinger(ch.fLinger); - - // set high water marks - ch.fSocket->SetSndBufSize(ch.fSndBufSize); - ch.fSocket->SetRcvBufSize(ch.fRcvBufSize); - - // set kernel transmit size (set it only if value is not the default value) - if (ch.fSndKernelSize != 0) - { - ch.fSocket->SetSndKernelSize(ch.fSndKernelSize); - } - if (ch.fRcvKernelSize != 0) - { - ch.fSocket->SetRcvKernelSize(ch.fRcvKernelSize); - } + string chanAddress = chan.GetAddress(); + boost::algorithm::split(endpoints, chanAddress, boost::algorithm::is_any_of(",")); + for (auto& endpoint : endpoints) { // attach - bool bind = (ch.fMethod == "bind"); + bool bind = (chan.GetMethod() == "bind"); bool connectionModifier = false; string address = endpoint; // check if the default fMethod is overridden by a modifier - if (endpoint[0] == '+' || endpoint[0] == '>') - { + if (endpoint[0] == '+' || endpoint[0] == '>') { connectionModifier = true; bind = false; address = endpoint.substr(1); - } - else if (endpoint[0] == '@') - { + } else if (endpoint[0] == '@') { connectionModifier = true; bind = true; address = endpoint.substr(1); } - if (address.compare(0, 6, "tcp://") == 0) - { + if (address.compare(0, 6, "tcp://") == 0) { string addressString = address.substr(6); auto pos = addressString.find(':'); string hostPart = addressString.substr(0, pos); - if (!(bind && hostPart == "*")) - { + if (!(bind && hostPart == "*")) { string portPart = addressString.substr(pos + 1); string resolvedHost = fair::mq::tools::getIpFromHostname(hostPart); - if (resolvedHost == "") - { + if (resolvedHost == "") { return false; } address.assign("tcp://" + resolvedHost + ":" + portPart); @@ -331,76 +259,35 @@ bool FairMQDevice::AttachChannel(FairMQChannel& ch) bool success = true; // make the connection - if (bind) - { - success = BindEndpoint(*ch.fSocket, address); - } - else - { - success = ConnectEndpoint(*ch.fSocket, address); + if (bind) { + success = chan.BindEndpoint(address); + } else { + success = chan.ConnectEndpoint(address); } // bind might bind to an address different than requested, // put the actual address back in the config endpoint.clear(); - if (connectionModifier) - { + if (connectionModifier) { endpoint.push_back(bind?'@':'+'); } endpoint += address; - LOG(debug) << "Attached channel " << ch.fName << " to " << endpoint << (bind ? " (bind) " : " (connect) ") << "(" << ch.fType << ")"; - // after the book keeping is done, exit in case of errors - if (!success) - { + if (!success) { return success; + } else { + LOG(debug) << "Attached channel " << chan.GetName() << " to " << endpoint << (bind ? " (bind) " : " (connect) ") << "(" << chan.GetType() << ")"; } } // put the (possibly) modified address back in the channel object and config - string newAddress{boost::algorithm::join(endpoints, ",")}; - if (newAddress != ch.fAddress) - { - ch.UpdateAddress(newAddress); - string key{"chans." + ch.GetChannelPrefix() + "." + ch.GetChannelIndex() + ".address"}; - fConfig->SetValue(key, newAddress); - } - - return true; -} + string newAddress(boost::algorithm::join(endpoints, ",")); + if (newAddress != chanAddress) { + chan.UpdateAddress(newAddress); -bool FairMQDevice::ConnectEndpoint(FairMQSocket& socket, string& endpoint) -{ - socket.Connect(endpoint); - - return true; -} - -bool FairMQDevice::BindEndpoint(FairMQSocket& socket, string& endpoint) -{ - // number of attempts when choosing a random port - int maxAttempts = 1000; - int numAttempts = 0; - - // initialize random generator - default_random_engine generator(chrono::system_clock::now().time_since_epoch().count()); - uniform_int_distribution randomPort(fPortRangeMin, fPortRangeMax); - - // try to bind to the saved port. In case of failure, try random one. - while (!socket.Bind(endpoint)) - { - LOG(debug) << "Could not bind to configured (TCP) port, trying random port in range " << fPortRangeMin << "-" << fPortRangeMax; - ++numAttempts; - - if (numAttempts > maxAttempts) - { - LOG(error) << "could not bind to any (TCP) port in the given range after " << maxAttempts << " attempts"; - return false; - } - - size_t pos = endpoint.rfind(':'); - endpoint = endpoint.substr(0, pos + 1) + fair::mq::tools::ToString(static_cast(randomPort(generator))); + // update address in the config, it could have been modified during binding + fConfig->SetValue({"chans." + chan.GetPrefix() + "." + chan.GetIndex() + ".address"}, newAddress); } return true; @@ -445,27 +332,6 @@ void FairMQDevice::SortChannel(const string& name, const bool reindex) } } -void FairMQDevice::PrintChannel(const string& name) -{ - if (fChannels.find(name) != fChannels.end()) - { - for (const auto& vi : fChannels[name]) - { - LOG(info) << vi.fName << ": " - << vi.fType << " | " - << vi.fMethod << " | " - << vi.fAddress << " | " - << vi.fSndBufSize << " | " - << vi.fRcvBufSize << " | " - << vi.fRateLogging; - } - } - else - { - LOG(error) << "Printing failed: no channel with the name \"" << name << "\"."; - } -} - void FairMQDevice::RunWrapper() { CallStateChangeCallbacks(RUNNING); diff --git a/fairmq/FairMQDevice.h b/fairmq/FairMQDevice.h index 771d9c79d..2e87b7bc0 100644 --- a/fairmq/FairMQDevice.h +++ b/fairmq/FairMQDevice.h @@ -81,10 +81,6 @@ class FairMQDevice : public FairMQStateMachine /// @param reindex Should reindexing be done void SortChannel(const std::string& name, const bool reindex = true); - /// Prints channel configuration - /// @param name Name of the channel - void PrintChannel(const std::string& name); - template void Serialize(FairMQMessage& msg, DataType&& data, Args&&... args) const { @@ -373,12 +369,6 @@ class FairMQDevice : public FairMQStateMachine void SetNumIoThreads(int numIoThreads) { fConfig->SetValue("io-threads", numIoThreads);} int GetNumIoThreads() const { return fConfig->GetValue("io-threads"); } - void SetPortRangeMin(int portRangeMin) { fConfig->SetValue("port-range-min", portRangeMin); } - int GetPortRangeMin() const { return fConfig->GetValue("port-range-min"); } - - void SetPortRangeMax(int portRangeMax) { fConfig->SetValue("port-range-max", portRangeMax); } - int GetPortRangeMax() const { return fConfig->GetValue("port-range-max"); } - void SetNetworkInterface(const std::string& networkInterface) { fConfig->SetValue("network-interface", networkInterface); } std::string GetNetworkInterface() const { return fConfig->GetValue("network-interface"); } @@ -458,9 +448,6 @@ class FairMQDevice : public FairMQStateMachine virtual void Reset(); private: - int fPortRangeMin; ///< Minimum value for the port range (if dynamic) - int fPortRangeMax; ///< Maximum value for the port range (if dynamic) - fair::mq::Transport fDefaultTransportType; ///< Default transport for the device /// Handles the initialization and the Init() method @@ -484,15 +471,6 @@ class FairMQDevice : public FairMQStateMachine /// Attach (bind/connect) channels in the list void AttachChannels(std::vector& chans); - - /// Sets up and connects/binds a socket to an endpoint - /// return a string with the actual endpoint if it happens - /// to stray from default. - bool ConnectEndpoint(FairMQSocket& socket, std::string& endpoint); - bool BindEndpoint(FairMQSocket& socket, std::string& endpoint); - /// Attaches the channel to all listed endpoints - /// the list is comma separated; the default method (bind/connect) is used. - /// to override default: prepend "@" to bind, "+" or ">" to connect endpoint. bool AttachChannel(FairMQChannel& ch); void HandleSingleChannelInput(); diff --git a/fairmq/FairMQSocket.h b/fairmq/FairMQSocket.h index 40f70843c..2174c97f2 100644 --- a/fairmq/FairMQSocket.h +++ b/fairmq/FairMQSocket.h @@ -23,7 +23,7 @@ class FairMQSocket virtual std::string GetId() = 0; virtual bool Bind(const std::string& address) = 0; - virtual void Connect(const std::string& address) = 0; + virtual bool Connect(const std::string& address) = 0; virtual int Send(FairMQMessagePtr& msg, int timeout = -1) = 0; virtual int Receive(FairMQMessagePtr& msg, int timeout = -1) = 0; diff --git a/fairmq/nanomsg/FairMQSocketNN.cxx b/fairmq/nanomsg/FairMQSocketNN.cxx index dbbf28a7d..d759cb0be 100644 --- a/fairmq/nanomsg/FairMQSocketNN.cxx +++ b/fairmq/nanomsg/FairMQSocketNN.cxx @@ -99,24 +99,26 @@ bool FairMQSocketNN::Bind(const string& address) { // LOG(info) << "bind socket " << fId << " on " << address; - int eid = nn_bind(fSocket, address.c_str()); - if (eid < 0) + if (nn_bind(fSocket, address.c_str()) < 0) { LOG(error) << "failed binding socket " << fId << ", reason: " << nn_strerror(errno); return false; } + return true; } -void FairMQSocketNN::Connect(const string& address) +bool FairMQSocketNN::Connect(const string& address) { // LOG(info) << "connect socket " << fId << " to " << address; - int eid = nn_connect(fSocket, address.c_str()); - if (eid < 0) + if (nn_connect(fSocket, address.c_str()) < 0) { LOG(error) << "failed connecting socket " << fId << ", reason: " << nn_strerror(errno); + return false; } + + return true; } int FairMQSocketNN::Send(FairMQMessagePtr& msg, const int timeout) diff --git a/fairmq/nanomsg/FairMQSocketNN.h b/fairmq/nanomsg/FairMQSocketNN.h index e24df2139..525266b06 100644 --- a/fairmq/nanomsg/FairMQSocketNN.h +++ b/fairmq/nanomsg/FairMQSocketNN.h @@ -25,7 +25,7 @@ class FairMQSocketNN final : public FairMQSocket std::string GetId() override; bool Bind(const std::string& address) override; - void Connect(const std::string& address) override; + bool Connect(const std::string& address) override; int Send(FairMQMessagePtr& msg, const int timeout = -1) override; int Receive(FairMQMessagePtr& msg, const int timeout = -1) override; diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index d07ea5ffe..61a6537f5 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -107,13 +107,14 @@ catch (const SocketError& e) return false; } -auto Socket::Connect(const string& address) -> void +auto Socket::Connect(const string& address) -> bool { auto addr = Context::VerifyAddress(address); ConnectControlSocket(addr); fContext.InitOfi(ConnectionType::Connect, addr); InitDataEndpoint(); fWaitingForControlPeer = true; + return true; } auto Socket::BindControlSocket(Context::Address address) -> void diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index f380cbdc6..9cc93adfc 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -42,7 +42,7 @@ class Socket final : public fair::mq::Socket auto GetId() -> std::string { return fId; } auto Bind(const std::string& address) -> bool override; - auto Connect(const std::string& address) -> void override; + auto Connect(const std::string& address) -> bool override; auto Send(MessagePtr& msg, int timeout = 0) -> int override; auto Receive(MessagePtr& msg, int timeout = 0) -> int override; diff --git a/fairmq/options/FairMQParser.cxx b/fairmq/options/FairMQParser.cxx index 0d44aaaed..70c91b5e6 100644 --- a/fairmq/options/FairMQParser.cxx +++ b/fairmq/options/FairMQParser.cxx @@ -32,8 +32,7 @@ namespace parser // function that convert property tree (given the json structure) to FairMQChannelMap FairMQChannelMap ptreeToMQMap(const boost::property_tree::ptree& pt, const string& id, const string& rootNode) { - if (id == "") - { + if (id == "") { throw ParserError("no device ID provided. Provide with `--id` cmd option"); } @@ -44,8 +43,7 @@ FairMQChannelMap ptreeToMQMap(const boost::property_tree::ptree& pt, const strin // Extract value from boost::property_tree Helper::DeviceParser(pt.get_child(rootNode), channelMap, id); - if (channelMap.empty()) - { + if (channelMap.empty()) { LOG(warn) << "---- No channel keys found for " << id; LOG(warn) << "---- Check the JSON inputs and/or command line inputs"; } @@ -68,20 +66,14 @@ void PrintDeviceList(const boost::property_tree::ptree& tree) string deviceIdKey; // do a first loop just to print the device-id in json input - for (const auto& p : tree) - { - if (p.first == "devices") - { - for (const auto& q : p.second.get_child("")) - { + for (const auto& p : tree) { + if (p.first == "devices") { + for (const auto& q : p.second.get_child("")) { string key = q.second.get("key", ""); - if (key != "") - { + if (key != "") { deviceIdKey = key; LOG(debug) << "Found config for device key '" << deviceIdKey << "' in JSON input"; - } - else - { + } else { deviceIdKey = q.second.get("id"); LOG(debug) << "Found config for device id '" << deviceIdKey << "' in JSON input"; } @@ -95,33 +87,26 @@ void DeviceParser(const boost::property_tree::ptree& tree, FairMQChannelMap& cha string deviceIdKey; // For each node in fairMQOptions - for (const auto& p : tree) - { - if (p.first == "devices") - { - for (const auto& q : p.second) - { + for (const auto& p : tree) { + if (p.first == "devices") { + for (const auto& q : p.second) { // check if key is provided, otherwise use id string key = q.second.get("key", ""); - if (key != "") - { + if (key != "") { deviceIdKey = key; - // LOG(debug) << "Found config for device key '" << deviceIdKey << "' in JSON input"; - } - else - { + // LOG(trace) << "Found config for device key '" << deviceIdKey << "' in JSON input"; + } else { deviceIdKey = q.second.get("id"); - // LOG(debug) << "Found config for device id '" << deviceIdKey << "' in JSON input"; + // LOG(trace) << "Found config for device id '" << deviceIdKey << "' in JSON input"; } // if not correct device id, do not fill MQMap - if (deviceId != deviceIdKey) - { + if (deviceId != deviceIdKey) { continue; } - LOG(debug) << "Found following channels for device ID '" << deviceId << "' :"; + LOG(trace) << "Found following channels for device ID '" << deviceId << "' :"; ChannelParser(q.second, channelMap); } @@ -133,12 +118,9 @@ void ChannelParser(const boost::property_tree::ptree& tree, FairMQChannelMap& ch { string channelKey; - for (const auto& p : tree) - { - if (p.first == "channels") - { - for (const auto& q : p.second) - { + for (const auto& p : tree) { + if (p.first == "channels") { + for (const auto& q : p.second) { channelKey = q.second.get("name"); int numSockets = q.second.get("numSockets", 0); @@ -154,36 +136,37 @@ void ChannelParser(const boost::property_tree::ptree& tree, FairMQChannelMap& ch commonChannel.UpdateSndKernelSize(q.second.get("sndKernelSize", commonChannel.GetSndKernelSize())); commonChannel.UpdateRcvKernelSize(q.second.get("rcvKernelSize", commonChannel.GetRcvKernelSize())); commonChannel.UpdateLinger(q.second.get("linger", commonChannel.GetLinger())); - commonChannel.UpdateRateLogging(q.second.get("rateLogging", commonChannel.GetRateLogging())); + commonChannel.UpdatePortRangeMin(q.second.get("portRangeMin", commonChannel.GetPortRangeMin())); + commonChannel.UpdatePortRangeMax(q.second.get("portRangeMax", commonChannel.GetPortRangeMax())); + commonChannel.UpdateAutoBind(q.second.get("autoBind", commonChannel.GetAutoBind())); // temporary FairMQChannel container vector channelList; - if (numSockets > 0) - { - LOG(debug) << "" << channelKey << ":"; - LOG(debug) << "\tnumSockets of " << numSockets << " specified,"; - LOG(debug) << "\tapplying common settings to each:"; - - LOG(debug) << "\ttype = " << commonChannel.GetType(); - LOG(debug) << "\tmethod = " << commonChannel.GetMethod(); - LOG(debug) << "\taddress = " << commonChannel.GetAddress(); - LOG(debug) << "\ttransport = " << commonChannel.GetTransportName(); - LOG(debug) << "\tsndBufSize = " << commonChannel.GetSndBufSize(); - LOG(debug) << "\trcvBufSize = " << commonChannel.GetRcvBufSize(); - LOG(debug) << "\tsndKernelSize = " << commonChannel.GetSndKernelSize(); - LOG(debug) << "\trcvKernelSize = " << commonChannel.GetRcvKernelSize(); - LOG(debug) << "\tlinger = " << commonChannel.GetLinger(); - LOG(debug) << "\trateLogging = " << commonChannel.GetRateLogging(); - - for (int i = 0; i < numSockets; ++i) - { + if (numSockets > 0) { + LOG(trace) << "" << channelKey << ":"; + LOG(trace) << "\tnumSockets of " << numSockets << " specified,"; + LOG(trace) << "\tapplying common settings to each:"; + + LOG(trace) << "\ttype = " << commonChannel.GetType(); + LOG(trace) << "\tmethod = " << commonChannel.GetMethod(); + LOG(trace) << "\taddress = " << commonChannel.GetAddress(); + LOG(trace) << "\ttransport = " << commonChannel.GetTransportName(); + LOG(trace) << "\tsndBufSize = " << commonChannel.GetSndBufSize(); + LOG(trace) << "\trcvBufSize = " << commonChannel.GetRcvBufSize(); + LOG(trace) << "\tsndKernelSize = " << commonChannel.GetSndKernelSize(); + LOG(trace) << "\trcvKernelSize = " << commonChannel.GetRcvKernelSize(); + LOG(trace) << "\tlinger = " << commonChannel.GetLinger(); + LOG(trace) << "\trateLogging = " << commonChannel.GetRateLogging(); + LOG(trace) << "\tportRangeMin = " << commonChannel.GetPortRangeMin(); + LOG(trace) << "\tportRangeMax = " << commonChannel.GetPortRangeMax(); + LOG(trace) << "\tautoBind = " << commonChannel.GetAutoBind(); + + for (int i = 0; i < numSockets; ++i) { FairMQChannel channel(commonChannel); channelList.push_back(channel); } - } - else - { + } else { SocketParser(q.second.get_child(""), channelList, channelKey, commonChannel); } @@ -198,12 +181,9 @@ void SocketParser(const boost::property_tree::ptree& tree, vector // for each socket in channel int socketCounter = 0; - for (const auto& p : tree) - { - if (p.first == "sockets") - { - for (const auto& q : p.second) - { + for (const auto& p : tree) { + if (p.first == "sockets") { + for (const auto& q : p.second) { // create new channel and apply setting from the common channel FairMQChannel channel(commonChannel); @@ -218,18 +198,24 @@ void SocketParser(const boost::property_tree::ptree& tree, vector channel.UpdateRcvKernelSize(q.second.get("rcvKernelSize", channel.GetRcvKernelSize())); channel.UpdateLinger(q.second.get("linger", channel.GetLinger())); channel.UpdateRateLogging(q.second.get("rateLogging", channel.GetRateLogging())); - - LOG(debug) << "" << channelName << "[" << socketCounter << "]:"; - LOG(debug) << "\ttype = " << channel.GetType(); - LOG(debug) << "\tmethod = " << channel.GetMethod(); - LOG(debug) << "\taddress = " << channel.GetAddress(); - LOG(debug) << "\ttransport = " << channel.GetTransportName(); - LOG(debug) << "\tsndBufSize = " << channel.GetSndBufSize(); - LOG(debug) << "\trcvBufSize = " << channel.GetRcvBufSize(); - LOG(debug) << "\tsndKernelSize = " << channel.GetSndKernelSize(); - LOG(debug) << "\trcvKernelSize = " << channel.GetRcvKernelSize(); - LOG(debug) << "\tlinger = " << channel.GetLinger(); - LOG(debug) << "\trateLogging = " << channel.GetRateLogging(); + channel.UpdatePortRangeMin(q.second.get("portRangeMin", channel.GetPortRangeMin())); + channel.UpdatePortRangeMax(q.second.get("portRangeMax", channel.GetPortRangeMax())); + channel.UpdateAutoBind(q.second.get("autoBind", channel.GetAutoBind())); + + LOG(trace) << "" << channelName << "[" << socketCounter << "]:"; + LOG(trace) << "\ttype = " << channel.GetType(); + LOG(trace) << "\tmethod = " << channel.GetMethod(); + LOG(trace) << "\taddress = " << channel.GetAddress(); + LOG(trace) << "\ttransport = " << channel.GetTransportName(); + LOG(trace) << "\tsndBufSize = " << channel.GetSndBufSize(); + LOG(trace) << "\trcvBufSize = " << channel.GetRcvBufSize(); + LOG(trace) << "\tsndKernelSize = " << channel.GetSndKernelSize(); + LOG(trace) << "\trcvKernelSize = " << channel.GetRcvKernelSize(); + LOG(trace) << "\tlinger = " << channel.GetLinger(); + LOG(trace) << "\trateLogging = " << channel.GetRateLogging(); + LOG(trace) << "\tportRangeMin = " << channel.GetPortRangeMin(); + LOG(trace) << "\tportRangeMax = " << channel.GetPortRangeMax(); + LOG(trace) << "\tautoBind = " << channel.GetAutoBind(); channelList.push_back(channel); ++socketCounter; @@ -237,28 +223,28 @@ void SocketParser(const boost::property_tree::ptree& tree, vector } } // end socket loop - if (socketCounter) - { - LOG(debug) << "Found " << socketCounter << " socket(s) in channel."; - } - else - { - LOG(debug) << "" << channelName << ":"; - LOG(debug) << "\tNo sockets specified,"; - LOG(debug) << "\tapplying common settings to the channel:"; + if (socketCounter) { + LOG(trace) << "Found " << socketCounter << " socket(s) in channel."; + } else { + LOG(trace) << "" << channelName << ":"; + LOG(trace) << "\tNo sockets specified,"; + LOG(trace) << "\tapplying common settings to the channel:"; FairMQChannel channel(commonChannel); - LOG(debug) << "\ttype = " << channel.GetType(); - LOG(debug) << "\tmethod = " << channel.GetMethod(); - LOG(debug) << "\taddress = " << channel.GetAddress(); - LOG(debug) << "\ttransport = " << channel.GetTransportName(); - LOG(debug) << "\tsndBufSize = " << channel.GetSndBufSize(); - LOG(debug) << "\trcvBufSize = " << channel.GetRcvBufSize(); - LOG(debug) << "\tsndKernelSize = " << channel.GetSndKernelSize(); - LOG(debug) << "\trcvKernelSize = " << channel.GetRcvKernelSize(); - LOG(debug) << "\tlinger = " << channel.GetLinger(); - LOG(debug) << "\trateLogging = " << channel.GetRateLogging(); + LOG(trace) << "\ttype = " << channel.GetType(); + LOG(trace) << "\tmethod = " << channel.GetMethod(); + LOG(trace) << "\taddress = " << channel.GetAddress(); + LOG(trace) << "\ttransport = " << channel.GetTransportName(); + LOG(trace) << "\tsndBufSize = " << channel.GetSndBufSize(); + LOG(trace) << "\trcvBufSize = " << channel.GetRcvBufSize(); + LOG(trace) << "\tsndKernelSize = " << channel.GetSndKernelSize(); + LOG(trace) << "\trcvKernelSize = " << channel.GetRcvKernelSize(); + LOG(trace) << "\tlinger = " << channel.GetLinger(); + LOG(trace) << "\trateLogging = " << channel.GetRateLogging(); + LOG(trace) << "\tportRangeMin = " << channel.GetPortRangeMin(); + LOG(trace) << "\tportRangeMax = " << channel.GetPortRangeMax(); + LOG(trace) << "\tautoBind = " << channel.GetAutoBind(); channelList.push_back(channel); } diff --git a/fairmq/options/FairMQProgOptions.cxx b/fairmq/options/FairMQProgOptions.cxx index 0a85aaa45..55cfd3e49 100644 --- a/fairmq/options/FairMQProgOptions.cxx +++ b/fairmq/options/FairMQProgOptions.cxx @@ -62,8 +62,6 @@ FairMQProgOptions::FairMQProgOptions() ("network-interface", po::value()->default_value("default"), "Network interface to bind on (e.g. eth0, ib0..., default will try to detect the interface of the default route).") ("config-key", po::value(), "Use provided value instead of device id for fetching the configuration from the config file.") ("initialization-timeout", po::value()->default_value(120), "Timeout for the initialization in seconds (when expecting dynamic initialization).") - ("port-range-min", po::value()->default_value(22000), "Start of the port range for dynamic initialization.") - ("port-range-max", po::value()->default_value(32000), "End of the port range for dynamic initialization.") ("print-channels", po::value()->implicit_value(true), "Print registered channel endpoints in a machine-readable format (::)") ("shm-segment-size", po::value()->default_value(2000000000), "Shared memory: size of the shared memory segment (in bytes).") ("shm-monitor", po::value()->default_value(true), "Shared memory: run monitor daemon.") @@ -170,7 +168,6 @@ int FairMQProgOptions::ParseAll(const int argc, char const* const* argv, bool al { LOG(warn) << "--" << p->canonical_display_name(); } - LOG(warn) << "No channels will be created (You can create them manually)."; } } catch (exception& e) @@ -270,12 +267,10 @@ void FairMQProgOptions::UpdateChannelInfo() // create key for variable map as follow : channelName.index.memberName void FairMQProgOptions::UpdateMQValues() { - for (const auto& p : fFairMQChannelMap) - { + for (const auto& p : fFairMQChannelMap) { int index = 0; - for (const auto& channel : p.second) - { + for (const auto& channel : p.second) { string typeKey = "chans." + p.first + "." + to_string(index) + ".type"; string methodKey = "chans." + p.first + "." + to_string(index) + ".method"; string addressKey = "chans." + p.first + "." + to_string(index) + ".address"; @@ -286,6 +281,9 @@ void FairMQProgOptions::UpdateMQValues() string rcvKernelSizeKey = "chans." + p.first + "." + to_string(index) + ".rcvKernelSize"; string lingerKey = "chans." + p.first + "." + to_string(index) + ".linger"; string rateLoggingKey = "chans." + p.first + "." + to_string(index) + ".rateLogging"; + string portRangeMinKey = "chans." + p.first + "." + to_string(index) + ".portRangeMin"; + string portRangeMaxKey = "chans." + p.first + "." + to_string(index) + ".portRangeMax"; + string autoBindKey = "chans." + p.first + "." + to_string(index) + ".autoBind"; fChannelKeyMap[typeKey] = ChannelKey{p.first, index, "type"}; fChannelKeyMap[methodKey] = ChannelKey{p.first, index, "method"}; @@ -297,6 +295,9 @@ void FairMQProgOptions::UpdateMQValues() fChannelKeyMap[rcvKernelSizeKey] = ChannelKey{p.first, index, "rcvkernelSize"}; fChannelKeyMap[lingerKey] = ChannelKey{p.first, index, "linger"}; fChannelKeyMap[rateLoggingKey] = ChannelKey{p.first, index, "rateLogging"}; + fChannelKeyMap[portRangeMinKey] = ChannelKey{p.first, index, "portRangeMin"}; + fChannelKeyMap[portRangeMaxKey] = ChannelKey{p.first, index, "portRangeMax"}; + fChannelKeyMap[autoBindKey] = ChannelKey{p.first, index, "autoBind"}; UpdateVarMap(typeKey, channel.GetType()); UpdateVarMap(methodKey, channel.GetMethod()); @@ -308,86 +309,67 @@ void FairMQProgOptions::UpdateMQValues() UpdateVarMap(rcvKernelSizeKey, channel.GetRcvKernelSize()); UpdateVarMap(lingerKey, channel.GetLinger()); UpdateVarMap(rateLoggingKey, channel.GetRateLogging()); + UpdateVarMap(portRangeMinKey, channel.GetPortRangeMin()); + UpdateVarMap(portRangeMaxKey, channel.GetPortRangeMax()); + UpdateVarMap(autoBindKey, channel.GetAutoBind()); index++; } + UpdateVarMap("chans." + p.first + ".numSockets", index); } } int FairMQProgOptions::UpdateChannelValue(const string& channelName, int index, const string& member, const string& val) { - if (member == "type") - { + if (member == "type") { fFairMQChannelMap.at(channelName).at(index).UpdateType(val); - return 0; - } - - if (member == "method") - { + } else if (member == "method") { fFairMQChannelMap.at(channelName).at(index).UpdateMethod(val); - return 0; - } - - if (member == "address") - { + } else if (member == "address") { fFairMQChannelMap.at(channelName).at(index).UpdateAddress(val); - return 0; - } - - if (member == "transport") - { + } else if (member == "transport") { fFairMQChannelMap.at(channelName).at(index).UpdateTransport(val); - return 0; - } - else - { - //if we get there it means something is wrong + } else { LOG(error) << "update of FairMQChannel map failed for the following key: " << channelName << "." << index << "." << member; return 1; } + + return 0; } int FairMQProgOptions::UpdateChannelValue(const string& channelName, int index, const string& member, int val) { - if (member == "sndBufSize") - { + if (member == "sndBufSize") { fFairMQChannelMap.at(channelName).at(index).UpdateSndBufSize(val); - return 0; - } - - if (member == "rcvBufSize") - { + } else if (member == "rcvBufSize") { fFairMQChannelMap.at(channelName).at(index).UpdateRcvBufSize(val); - return 0; - } - - if (member == "sndKernelSize") - { + } else if (member == "sndKernelSize") { fFairMQChannelMap.at(channelName).at(index).UpdateSndKernelSize(val); - return 0; - } - - if (member == "rcvKernelSize") - { + } else if (member == "rcvKernelSize") { fFairMQChannelMap.at(channelName).at(index).UpdateRcvKernelSize(val); - return 0; - } - - if (member == "linger") - { + } else if (member == "linger") { fFairMQChannelMap.at(channelName).at(index).UpdateLinger(val); - return 0; + } else if (member == "rateLogging") { + fFairMQChannelMap.at(channelName).at(index).UpdateRateLogging(val); + } else if (member == "portRangeMin") { + fFairMQChannelMap.at(channelName).at(index).UpdatePortRangeMin(val); + } else if (member == "portRangeMax") { + fFairMQChannelMap.at(channelName).at(index).UpdatePortRangeMax(val); + } else { + LOG(error) << "update of FairMQChannel map failed for the following key: " << channelName << "." << index << "." << member; + return 1; } - if (member == "rateLogging") - { - fFairMQChannelMap.at(channelName).at(index).UpdateRateLogging(val); + return 0; +} + +int FairMQProgOptions::UpdateChannelValue(const string& channelName, int index, const string& member, bool val) +{ + if (member == "autoBind") { + fFairMQChannelMap.at(channelName).at(index).UpdateAutoBind(val); return 0; - } - else - { - // if we get there it means something is wrong + } else { LOG(error) << "update of FairMQChannel map failed for the following key: " << channelName << "." << index << "." << member; return 1; } diff --git a/fairmq/options/FairMQProgOptions.h b/fairmq/options/FairMQProgOptions.h index e01f15435..1f73e72bd 100644 --- a/fairmq/options/FairMQProgOptions.h +++ b/fairmq/options/FairMQProgOptions.h @@ -218,6 +218,7 @@ class FairMQProgOptions } int UpdateChannelValue(const std::string& channelName, int index, const std::string& member, const std::string& val); int UpdateChannelValue(const std::string& channelName, int index, const std::string& member, int val); + int UpdateChannelValue(const std::string& channelName, int index, const std::string& member, bool val); void UpdateChannelInfo(); diff --git a/fairmq/options/FairMQSuboptParser.h b/fairmq/options/FairMQSuboptParser.h index 7d6377edb..3a27a6cff 100644 --- a/fairmq/options/FairMQSuboptParser.h +++ b/fairmq/options/FairMQSuboptParser.h @@ -58,6 +58,9 @@ struct SUBOPT RCVKERNELSIZE, LINGER, RATELOGGING, // logging rate + PORTRANGEMIN, + PORTRANGEMAX, + AUTOBIND, NUMSOCKETS, lastsocketkey }; @@ -74,6 +77,9 @@ struct SUBOPT /*[RCVKERNELSIZE] = */ "rcvKernelSize", /*[LINGER] = */ "linger", /*[RATELOGGING] = */ "rateLogging", + /*[PORTRANGEMIN] = */ "portRangeMin", + /*[PORTRANGEMAX] = */ "portRangeMax", + /*[AUTOBIND] = */ "autoBind", /*[NUMSOCKETS] = */ "numSockets", nullptr }; diff --git a/fairmq/shmem/FairMQSocketSHM.cxx b/fairmq/shmem/FairMQSocketSHM.cxx index 97c26303f..46b0d2d6b 100644 --- a/fairmq/shmem/FairMQSocketSHM.cxx +++ b/fairmq/shmem/FairMQSocketSHM.cxx @@ -99,16 +99,17 @@ bool FairMQSocketSHM::Bind(const string& address) return true; } -void FairMQSocketSHM::Connect(const string& address) +bool FairMQSocketSHM::Connect(const string& address) { // LOG(info) << "connect socket " << fId << " on " << address; if (zmq_connect(fSocket, address.c_str()) != 0) { LOG(error) << "Failed connecting socket " << fId << ", reason: " << zmq_strerror(errno); - // error here means incorrect configuration. exit if it happens. - exit(EXIT_FAILURE); + return false; } + + return true; } int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int timeout) diff --git a/fairmq/shmem/FairMQSocketSHM.h b/fairmq/shmem/FairMQSocketSHM.h index 0f5810283..39d8064aa 100644 --- a/fairmq/shmem/FairMQSocketSHM.h +++ b/fairmq/shmem/FairMQSocketSHM.h @@ -26,7 +26,7 @@ class FairMQSocketSHM final : public FairMQSocket std::string GetId() override { return fId; } bool Bind(const std::string& address) override; - void Connect(const std::string& address) override; + bool Connect(const std::string& address) override; int Send(FairMQMessagePtr& msg, const int timeout = -1) override; int Receive(FairMQMessagePtr& msg, const int timeout = -1) override; diff --git a/fairmq/tools/Process.cxx b/fairmq/tools/Process.cxx index 942691261..da5468e32 100644 --- a/fairmq/tools/Process.cxx +++ b/fairmq/tools/Process.cxx @@ -13,6 +13,7 @@ #include #include #include +#include using namespace std; namespace bp = boost::process; diff --git a/fairmq/zeromq/FairMQSocketZMQ.cxx b/fairmq/zeromq/FairMQSocketZMQ.cxx index de54994a0..59cdfc5f9 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.cxx +++ b/fairmq/zeromq/FairMQSocketZMQ.cxx @@ -91,19 +91,21 @@ bool FairMQSocketZMQ::Bind(const string& address) LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno); return false; } + return true; } -void FairMQSocketZMQ::Connect(const string& address) +bool FairMQSocketZMQ::Connect(const string& address) { // LOG(info) << "connect socket " << fId << " on " << address; if (zmq_connect(fSocket, address.c_str()) != 0) { LOG(error) << "Failed connecting socket " << fId << ", reason: " << zmq_strerror(errno); - // error here means incorrect configuration. exit if it happens. - exit(EXIT_FAILURE); + return false; } + + return true; } int FairMQSocketZMQ::Send(FairMQMessagePtr& msg, const int timeout) diff --git a/fairmq/zeromq/FairMQSocketZMQ.h b/fairmq/zeromq/FairMQSocketZMQ.h index 6d6a26076..a5d7d21b7 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.h +++ b/fairmq/zeromq/FairMQSocketZMQ.h @@ -26,7 +26,7 @@ class FairMQSocketZMQ final : public FairMQSocket std::string GetId() override; bool Bind(const std::string& address) override; - void Connect(const std::string& address) override; + bool Connect(const std::string& address) override; int Send(FairMQMessagePtr& msg, const int timeout = -1) override; int Receive(FairMQMessagePtr& msg, const int timeout = -1) override; diff --git a/test/channel/_channel.cxx b/test/channel/_channel.cxx index ce13a358e..024b5b643 100644 --- a/test/channel/_channel.cxx +++ b/test/channel/_channel.cxx @@ -21,76 +21,76 @@ using namespace fair::mq; TEST(Channel, Validation) { FairMQChannel channel; - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateType("pair"); - ASSERT_EQ(channel.ValidateChannel(), false); + ASSERT_EQ(channel.Validate(), false); ASSERT_EQ(channel.IsValid(), false); channel.UpdateAddress("bla"); - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateMethod("connect"); - ASSERT_EQ(channel.ValidateChannel(), false); + ASSERT_EQ(channel.Validate(), false); ASSERT_EQ(channel.IsValid(), false); channel.UpdateAddress("ipc://"); - ASSERT_EQ(channel.ValidateChannel(), false); + ASSERT_EQ(channel.Validate(), false); ASSERT_EQ(channel.IsValid(), false); channel.UpdateAddress("verbs://"); - ASSERT_EQ(channel.ValidateChannel(), false); + ASSERT_EQ(channel.Validate(), false); ASSERT_EQ(channel.IsValid(), false); channel.UpdateAddress("inproc://"); - ASSERT_EQ(channel.ValidateChannel(), false); + ASSERT_EQ(channel.Validate(), false); ASSERT_EQ(channel.IsValid(), false); channel.UpdateAddress("tcp://"); - ASSERT_EQ(channel.ValidateChannel(), false); + ASSERT_EQ(channel.Validate(), false); ASSERT_EQ(channel.IsValid(), false); channel.UpdateAddress("tcp://localhost:5555"); - ASSERT_EQ(channel.ValidateChannel(), true); + ASSERT_EQ(channel.Validate(), true); ASSERT_EQ(channel.IsValid(), true); channel.UpdateSndBufSize(-1); - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateSndBufSize(1000); - ASSERT_NO_THROW(channel.ValidateChannel()); + ASSERT_NO_THROW(channel.Validate()); channel.UpdateRcvBufSize(-1); - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateRcvBufSize(1000); - ASSERT_NO_THROW(channel.ValidateChannel()); + ASSERT_NO_THROW(channel.Validate()); channel.UpdateSndKernelSize(-1); - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateSndKernelSize(1000); - ASSERT_NO_THROW(channel.ValidateChannel()); + ASSERT_NO_THROW(channel.Validate()); channel.UpdateRcvKernelSize(-1); - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateRcvKernelSize(1000); - ASSERT_NO_THROW(channel.ValidateChannel()); + ASSERT_NO_THROW(channel.Validate()); channel.UpdateRateLogging(-1); - ASSERT_THROW(channel.ValidateChannel(), FairMQChannel::ChannelConfigurationError); + ASSERT_THROW(channel.Validate(), FairMQChannel::ChannelConfigurationError); channel.UpdateRateLogging(1); - ASSERT_NO_THROW(channel.ValidateChannel()); + ASSERT_NO_THROW(channel.Validate()); FairMQChannel channel2 = channel; - ASSERT_NO_THROW(channel2.ValidateChannel()); - ASSERT_EQ(channel2.ValidateChannel(), true); + ASSERT_NO_THROW(channel2.Validate()); + ASSERT_EQ(channel2.Validate(), true); ASSERT_EQ(channel2.IsValid(), true); - ASSERT_EQ(channel2.ValidateChannel(), true); + ASSERT_EQ(channel2.Validate(), true); channel2.UpdateChannelName("Kanal"); ASSERT_EQ(channel2.GetChannelName(), "Kanal"); channel2.ResetChannel(); ASSERT_EQ(channel2.IsValid(), false); - ASSERT_EQ(channel2.ValidateChannel(), true); + ASSERT_EQ(channel2.Validate(), true); } } /* namespace */ diff --git a/test/protocols/_push_pull_multipart.cxx b/test/protocols/_push_pull_multipart.cxx index 6d5b0380b..0d07d2624 100644 --- a/test/protocols/_push_pull_multipart.cxx +++ b/test/protocols/_push_pull_multipart.cxx @@ -39,8 +39,8 @@ auto RunSingleThreadedMultipart(string transport, string address) -> void { // TODO validate that fTransportFactory is not nullptr // TODO validate that fSocket is not nullptr - ASSERT_TRUE(push.ValidateChannel()); - ASSERT_TRUE(pull.ValidateChannel()); + ASSERT_TRUE(push.Validate()); + ASSERT_TRUE(pull.Validate()); { auto sentMsg = FairMQParts{}; @@ -76,7 +76,7 @@ auto RunMultiThreadedMultipart(string transport, string address) -> void pull.Connect(address); auto pusher = thread{[&push](){ - ASSERT_TRUE(push.ValidateChannel()); + ASSERT_TRUE(push.Validate()); auto sentMsg = FairMQParts{}; sentMsg.AddPart(push.NewSimpleMessage("1")); @@ -87,7 +87,7 @@ auto RunMultiThreadedMultipart(string transport, string address) -> void }}; auto puller = thread{[&pull](){ - ASSERT_TRUE(pull.ValidateChannel()); + ASSERT_TRUE(pull.Validate()); auto receivedMsg = FairMQParts{}; ASSERT_GE(pull.Receive(receivedMsg), 0);