diff --git a/CMakeLists.txt b/CMakeLists.txt index db8f6bc2..d49fe66f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,17 +90,25 @@ ELSE () target_link_libraries(${PROJECT} pthread) ENDIF (WIN32) +# __CPP_REDIS_READ_SIZE IF (READ_SIZE) -set_target_properties(${PROJECT} - PROPERTIES - COMPILE_DEFINITIONS "__CPP_REDIS_READ_SIZE=${READ_SIZE}") +set_target_properties(${PROJECT} PROPERTIES COMPILE_DEFINITIONS "__CPP_REDIS_READ_SIZE=${READ_SIZE}") ENDIF (READ_SIZE) -IF (NO_LOGGING) -set_target_properties(${PROJECT} - PROPERTIES - COMPILE_DEFINITIONS "__CPP_REDIS_NO_LOGGING=${NO_LOGGING}") -ENDIF (NO_LOGGING) +# __CPP_REDIS_LOGGING_ENABLED +IF (LOGGING_ENABLED) +set_target_properties(${PROJECT} PROPERTIES COMPILE_DEFINITIONS "__CPP_REDIS_LOGGING_ENABLED=${LOGGING_ENABLED}") +ENDIF (LOGGING_ENABLED) + +# _CPP_REDIS_MAX_NB_FDS +IF (MAX_NB_FDS) +set_target_properties(${PROJECT} PROPERTIES COMPILE_DEFINITIONS "_CPP_REDIS_MAX_NB_FDS=${MAX_NB_FDS}") +ENDIF (MAX_NB_FDS) + +# __CPP_REDIS_DEFAULT_NB_IO_SERVICE_WORKERS +IF (DEFAULT_NB_IO_SERVICE_WORKERS) +set_target_properties(${PROJECT} PROPERTIES COMPILE_DEFINITIONS "__CPP_REDIS_DEFAULT_NB_IO_SERVICE_WORKERS=${DEFAULT_NB_IO_SERVICE_WORKERS}") +ENDIF (DEFAULT_NB_IO_SERVICE_WORKERS) ### diff --git a/includes/cpp_redis/logger.hpp b/includes/cpp_redis/logger.hpp index 281a87d2..22bbb285 100644 --- a/includes/cpp_redis/logger.hpp +++ b/includes/cpp_redis/logger.hpp @@ -67,11 +67,10 @@ void warn(const std::string& msg, const std::string& file, unsigned int line); void error(const std::string& msg, const std::string& file, unsigned int line); //! convenience macro to log with file and line information -//! if __CPP_REDIS_NO_LOGGING, all logging related lines are removed from source code -#ifdef __CPP_REDIS_NO_LOGGING -#define __CPP_REDIS_LOG(level, msg) -#else +#ifdef __CPP_REDIS_LOGGING_ENABLED #define __CPP_REDIS_LOG(level, msg) cpp_redis::level(msg, __FILE__, __LINE__); -#endif /* __CPP_REDIS_NO_LOGGING */ +#else +#define __CPP_REDIS_LOG(level, msg) +#endif /* __CPP_REDIS_LOGGING_ENABLED */ } //! cpp_redis diff --git a/includes/cpp_redis/network/io_service.hpp b/includes/cpp_redis/network/io_service.hpp new file mode 100644 index 00000000..e1779484 --- /dev/null +++ b/includes/cpp_redis/network/io_service.hpp @@ -0,0 +1,71 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#ifndef __CPP_REDIS_DEFAULT_NB_IO_SERVICE_WORKERS +#define __CPP_REDIS_DEFAULT_NB_IO_SERVICE_WORKERS 16 +#endif /* __CPP_REDIS_DEFAULT_NB_IO_SERVICE_WORKERS */ + +namespace cpp_redis { + +namespace network { + +class io_service { +public: + //! get default global instance + static const std::shared_ptr& get_global_instance(void); + //! set default global instance + static void set_global_instance(const std::shared_ptr& instance); + +public: + //! ctor & dtor + io_service(size_t nb_workers); + virtual ~io_service(void) = default; + + //! copy ctor & assignment operator + io_service(const io_service&) = default; + io_service& operator=(const io_service&) = default; + +public: + //! disconnection handler declaration + typedef std::function disconnection_handler_t; + + //! add or remove a given fd from the io service + //! untrack should never be called from inside a callback + virtual void track(_sock_t sock, const disconnection_handler_t& handler) = 0; + virtual void untrack(_sock_t sock) = 0; + + //! asynchronously read read_size bytes and append them to the given buffer + //! on completion, call the read_callback to notify of the success or failure of the operation + //! return false if another async_read operation is in progress or fd is not registered + typedef std::function read_callback_t; + virtual bool async_read(_sock_t sock, std::vector& buffer, std::size_t read_size, const read_callback_t& callback) = 0; + + //! asynchronously write write_size bytes from buffer to the specified fd + //! on completion, call the write_callback to notify of the success or failure of the operation + //! return false if another async_write operation is in progress or fd is not registered + typedef std::function write_callback_t; + virtual bool async_write(_sock_t sock, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback) = 0; + +public: + size_t get_nb_workers(void) const; + +private: + //! listen for incoming events and notify + virtual void process_io(void) = 0; + +private: + size_t m_nb_workers; +}; + +//! multi-platform instance builder +std::shared_ptr create_io_service(size_t nb_workers = __CPP_REDIS_DEFAULT_NB_IO_SERVICE_WORKERS); + +} //! network + +} //! cpp_redis diff --git a/includes/cpp_redis/network/redis_connection.hpp b/includes/cpp_redis/network/redis_connection.hpp index fe6d5ec6..fa357165 100644 --- a/includes/cpp_redis/network/redis_connection.hpp +++ b/includes/cpp_redis/network/redis_connection.hpp @@ -5,13 +5,8 @@ #include #include -#ifdef _MSC_VER -#include -#else -#include -#endif /* _MSC_VER */ - #include +#include namespace cpp_redis { diff --git a/includes/cpp_redis/network/socket.hpp b/includes/cpp_redis/network/socket.hpp new file mode 100644 index 00000000..91809b57 --- /dev/null +++ b/includes/cpp_redis/network/socket.hpp @@ -0,0 +1,20 @@ +#pragma once + +namespace cpp_redis { + +namespace network { + +#ifdef _WIN32 +#include +typedef SOCKET _sock_t; +#else +typedef int _sock_t; +#endif /* _WIN32 */ + +#ifndef INVALID_SOCKET +#define INVALID_SOCKET -1 +#endif /* INVALID_SOCKET */ + +} //! network + +} //! cpp_redis diff --git a/includes/cpp_redis/network/unix/tcp_client.hpp b/includes/cpp_redis/network/tcp_client.hpp similarity index 90% rename from includes/cpp_redis/network/unix/tcp_client.hpp rename to includes/cpp_redis/network/tcp_client.hpp index 2a564fe0..631b7b63 100644 --- a/includes/cpp_redis/network/unix/tcp_client.hpp +++ b/includes/cpp_redis/network/tcp_client.hpp @@ -1,13 +1,15 @@ #pragma once #include +#include #include #include #include #include #include -#include +#include +#include #include #ifndef __CPP_REDIS_READ_SIZE @@ -56,20 +58,21 @@ class tcp_client { void reset_state(void); void clear_buffer(void); + void setup_socket(void); + private: //! io service instance std::shared_ptr m_io_service; //! socket fd - int m_fd; + _sock_t m_sock; //! is connected std::atomic_bool m_is_connected; //! buffers - static const unsigned int READ_SIZE = __CPP_REDIS_READ_SIZE; std::vector m_read_buffer; - std::vector m_write_buffer; + std::list> m_write_buffer; //! handlers receive_handler_t m_receive_handler; diff --git a/includes/cpp_redis/network/unix/io_service.hpp b/includes/cpp_redis/network/unix/io_service.hpp index 69b100e0..5e21a841 100644 --- a/includes/cpp_redis/network/unix/io_service.hpp +++ b/includes/cpp_redis/network/unix/io_service.hpp @@ -11,46 +11,34 @@ #include #include +#include + +#ifndef _CPP_REDIS_MAX_NB_FDS +#define _CPP_REDIS_MAX_NB_FDS 1024 +#endif /* _CPP_REDIS_MAX_NB_FDS */ + namespace cpp_redis { namespace network { -class io_service { -public: - //! instance getter (singleton pattern) - static const std::shared_ptr& get_instance(void); +namespace unix { - //! dtor +class io_service : public network::io_service { +public: + //! ctor & dtor + io_service(size_t nb_workers); ~io_service(void); -private: - //! ctor - io_service(void); - //! copy ctor & assignment operator io_service(const io_service&) = delete; io_service& operator=(const io_service&) = delete; public: - //! disconnection handler declaration - typedef std::function disconnection_handler_t; - - //! add or remove a given fd from the io service - //! untrack should never be called from inside a callback - void track(int fd, const disconnection_handler_t& handler); - void untrack(int fd); - - //! asynchronously read read_size bytes and append them to the given buffer - //! on completion, call the read_callback to notify of the success or failure of the operation - //! return false if another async_read operation is in progress or fd is not registered - typedef std::function read_callback_t; - bool async_read(int fd, std::vector& buffer, std::size_t read_size, const read_callback_t& callback); - - //! asynchronously write write_size bytes from buffer to the specified fd - //!on completion, call the write_callback to notify of the success or failure of the operation - //! return false if another async_write operation is in progress or fd is not registered - typedef std::function write_callback_t; - bool async_write(int fd, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback); + void track(_sock_t fd, const disconnection_handler_t& handler) override; + void untrack(_sock_t fd) override; + + bool async_read(_sock_t fd, std::vector& buffer, std::size_t read_size, const read_callback_t& callback) override; + bool async_write(_sock_t fd, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback) override; private: //! simple struct to keep track of ongoing operations on a given fd @@ -73,7 +61,7 @@ class io_service { private: //! listen for incoming events and notify - void listen(void); + void process_io(void) override; //! notify the poll call so that it can wake up to process new events void notify_poll(void); @@ -108,6 +96,8 @@ class io_service { std::recursive_mutex m_fds_mutex; }; +} //! unix + } //! network } //! cpp_redis diff --git a/includes/cpp_redis/network/windows/io_service.hpp b/includes/cpp_redis/network/windows/io_service.hpp index fcdf6d4d..441ede41 100644 --- a/includes/cpp_redis/network/windows/io_service.hpp +++ b/includes/cpp_redis/network/windows/io_service.hpp @@ -8,54 +8,44 @@ #include -#define MAX_BUFF_SIZE __CPP_REDIS_READ_SIZE -#define MAX_WORKER_THREADS 16 +#include namespace cpp_redis { namespace network { +namespace windows { + typedef enum _enIoOperation { //IO_OP_ACCEPT, IO_OP_READ, IO_OP_WRITE } enIoOperation; - -class io_service { +class io_service : public network::io_service { public: - //! instance getter (singleton pattern) - static const std::shared_ptr& get_instance(void); - io_service(size_t max_worker_threads = MAX_WORKER_THREADS); + //! ctor & dtor + io_service(size_t nb_workers); ~io_service(void); - void shutdown(); - private: //! copy ctor & assignment operator io_service(const io_service&) = delete; io_service& operator=(const io_service&) = delete; public: - //! disconnection handler declaration - typedef std::function disconnection_handler_t; - - //! add or remove a given socket from the io service - //! untrack should never be called from inside a callback - void track(SOCKET sock, const disconnection_handler_t& handler); - void untrack(SOCKET sock); - - //! asynchronously read read_size bytes and append them to the given buffer - //! on completion, call the read_callback to notify of the success or failure of the operation - //! return false if another async_read operation is in progress or socket is not registered - typedef std::function read_callback_t; - bool async_read(SOCKET socket, std::vector& buffer, std::size_t read_size, const read_callback_t& callback); - - //! asynchronously write write_size bytes from buffer to the specified fd - //!on completion, call the write_callback to notify of the success or failure of the operation - //! return false if another async_write operation is in progress or socket is not registered - typedef std::function write_callback_t; - bool async_write(SOCKET socket, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback); + void track(_sock_t sock, const disconnection_handler_t& handler) override; + void untrack(_sock_t sock) override; + + bool async_read(_sock_t socket, std::vector& buffer, std::size_t read_size, const read_callback_t& callback) override; + bool async_write(_sock_t socket, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback) override; + +private: + //! wait for incoming events and notify + void process_io(void) override; + + //! shutdown + void shutdown(void); private: struct io_context_info : OVERLAPPED { @@ -63,6 +53,7 @@ class io_service { enIoOperation eOperation; }; +private: //! simple struct to keep track of ongoing operations on a given sockeet class sock_info { public: @@ -78,12 +69,12 @@ class io_service { SOCKET hsock; std::size_t sent_bytes; - //Must protect the members of our structure from access by multiple threads during IO Completion + //! Must protect the members of our structure from access by multiple threads during IO Completion std::recursive_mutex sock_info_mutex; - //We keep a simple vector of io_context_info structs to reuse for overlapped WSARecv and WSASend operations - //Since each must have its OWN struct if we issue them at the same time. - //othewise things get tangled up and borked. + //! We keep a simple vector of io_context_info structs to reuse for overlapped WSARecv and WSASend operations + //! Since each must have its OWN struct if we issue them at the same time. + //! othewise things get tangled up and borked. std::vector io_contexts_pool; disconnection_handler_t disconnection_handler; @@ -117,14 +108,12 @@ class io_service { } }; - typedef std::function callback_t; - - //! wait for incoming events and notify - int process_io(void); - +private: + //! completion port HANDLE m_completion_port; - unsigned int m_worker_thread_pool_size; - std::vector m_worker_threads; //vector containing all the threads we start to service our i/o requests + + //! vector containing all the threads we start to service our i/o requests + std::vector m_worker_threads; private: //! whether the worker should terminate or not @@ -142,6 +131,8 @@ class io_service { std::recursive_mutex m_socket_mutex; }; +} //! windows + } //! network } //! cpp_redis diff --git a/includes/cpp_redis/network/windows/tcp_client.hpp b/includes/cpp_redis/network/windows/tcp_client.hpp deleted file mode 100644 index 8fc9f804..00000000 --- a/includes/cpp_redis/network/windows/tcp_client.hpp +++ /dev/null @@ -1,85 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#ifndef __CPP_REDIS_READ_SIZE -#define __CPP_REDIS_READ_SIZE 4096 -#endif /* __CPP_REDIS_READ_SIZE */ - -namespace cpp_redis { - -namespace network { - -//! tcp_client -//! async tcp client based on boost asio -class tcp_client { -public: - //! ctor & dtor - tcp_client(const std::shared_ptr& IO = nullptr); - ~tcp_client(void); - - //! assignment operator & copy ctor - tcp_client(const tcp_client&) = delete; - tcp_client& operator=(const tcp_client&) = delete; - - //! returns whether the client is connected or not - bool is_connected(void); - - //! handle connection & disconnection - typedef std::function disconnection_handler_t; - typedef std::function& buffer)> receive_handler_t; - void connect(const std::string& host, unsigned int port, - const disconnection_handler_t& disconnection_handler = nullptr, - const receive_handler_t& receive_handler = nullptr); - void disconnect(void); - - //! send data - void send(const std::string& buffer); - void send(const std::vector& buffer); - -private: - //! make async read and write operations - void async_read(void); - void async_write(void); - - //! io service callback - void io_service_disconnection_handler(io_service&); - - void reset_state(void); - void clear_buffer(void); - -private: - //! io service instance - const std::shared_ptr m_io_service; - - //! socket - SOCKET m_sock; - - //! is connected - std::atomic_bool m_is_connected; - - //! buffers - static const unsigned int READ_SIZE = __CPP_REDIS_READ_SIZE; - std::vector m_read_buffer; - std::list> m_write_buffer; - - //! handlers - receive_handler_t m_receive_handler; - disconnection_handler_t m_disconnection_handler; - - //! thread safety - std::mutex m_write_buffer_mutex; -}; - -} //! network - -} //! cpp_redis diff --git a/includes/cpp_redis/redis_client.hpp b/includes/cpp_redis/redis_client.hpp index a1f6cbaa..8bedca94 100644 --- a/includes/cpp_redis/redis_client.hpp +++ b/includes/cpp_redis/redis_client.hpp @@ -45,7 +45,7 @@ class redis_client { std::unique_lock lock_callback(m_callbacks_mutex); __CPP_REDIS_LOG(debug, "cpp_redis::redis_client waits for callbacks to complete"); - m_sync_condvar.wait_for(lock_callback, timeout, [=] { return m_callbacks.empty(); }); + m_sync_condvar.wait_for(lock_callback, timeout, [=] { return m_callbacks_running == 0 && m_callbacks.empty(); }); __CPP_REDIS_LOG(debug, "cpp_redis::redis_client finished to wait for callbacks completion (or timeout reached)"); return *this; @@ -282,6 +282,7 @@ class redis_client { std::mutex m_callbacks_mutex; std::mutex m_send_mutex; std::condition_variable m_sync_condvar; + std::atomic_uint m_callbacks_running; }; } //! cpp_redis diff --git a/sources/network/io_service.cpp b/sources/network/io_service.cpp new file mode 100644 index 00000000..0b14b037 --- /dev/null +++ b/sources/network/io_service.cpp @@ -0,0 +1,47 @@ +#include + +#ifdef _WIN32 +#include +#else +#include +#endif /* _WIN32 */ + +namespace cpp_redis { + +namespace network { + +static std::shared_ptr global_instance = nullptr; + +const std::shared_ptr& +io_service::get_global_instance(void) { + if (!global_instance) + global_instance = create_io_service(); + + return global_instance; +} + +void +io_service::set_global_instance(const std::shared_ptr& io_service) { + global_instance = io_service; +} + +io_service::io_service(size_t nb_workers) +: m_nb_workers(nb_workers) {} + +size_t +io_service::get_nb_workers(void) const { + return m_nb_workers; +} + +std::shared_ptr +create_io_service(size_t nb_workers) { +#ifdef _WIN32 + return std::make_shared(nb_workers); +#else + return std::make_shared(nb_workers); +#endif /* _WIN32 */ +} + +} //! network + +} //! cpp_redis diff --git a/sources/network/redis_connection.cpp b/sources/network/redis_connection.cpp index 7283b768..26b6d5dd 100644 --- a/sources/network/redis_connection.cpp +++ b/sources/network/redis_connection.cpp @@ -5,8 +5,8 @@ namespace cpp_redis { namespace network { -redis_connection::redis_connection(const std::shared_ptr& IO) -: m_client(IO) +redis_connection::redis_connection(const std::shared_ptr& io_service) +: m_client(io_service) , m_reply_callback(nullptr) , m_disconnection_handler(nullptr) { __CPP_REDIS_LOG(debug, "cpp_redis::network::redis_connection created"); diff --git a/sources/network/windows/tcp_client.cpp b/sources/network/tcp_client.cpp similarity index 67% rename from sources/network/windows/tcp_client.cpp rename to sources/network/tcp_client.cpp index 2230fda6..a9a0510b 100644 --- a/sources/network/windows/tcp_client.cpp +++ b/sources/network/tcp_client.cpp @@ -1,24 +1,27 @@ #include -#pragma warning(disable : 4996) //Disable "The POSIX name for this item is deprecated" warnings for gethostbyname() +//! Disable "The POSIX name for this item is deprecated" warnings for gethostbyname() +#ifdef _WIN32 +#pragma warning(disable : 4996) +#endif /* _WIN32 */ #include +#ifdef _WIN32 +#include +#else +#include +#include +#endif /* _WIN32 */ + #include -#include +#include namespace cpp_redis { namespace network { -//! note that we call io_service::get_instance in the init list -//! -//! this will force force io_service instance creation -//! this is a workaround to handle static object destructions order -//! -//! that way, any object containing a tcp_client has an attribute (or through its attributes) -//! is guaranteed to be destructed before the io_service is destructed, even if it is global -tcp_client::tcp_client(const std::shared_ptr& IO) -: m_io_service(IO ? IO : io_service::get_instance()) +tcp_client::tcp_client(const std::shared_ptr& io_service) +: m_io_service(io_service ? io_service : io_service::get_global_instance()) , m_sock(-1) , m_is_connected(false) , m_receive_handler(nullptr) @@ -31,6 +34,35 @@ tcp_client::~tcp_client(void) { __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client destroyed"); } +void +tcp_client::setup_socket(void) { +#ifdef _WIN32 + //! create the socket + //! Enable socket for overlapped i/o + m_sock = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (m_sock < 0) { + __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client could not create socket"); + throw redis_error("Can't open a socket"); + } + + //! Instruct the TCP stack to directly perform I/O using the buffer provided in our I/O call. + //! The advantage is performance because we save a buffer copy between the TCP stack buffer + //! and our user buffer for each I/O call. + //! BUT we have to make sure we don't access the buffer once it's submitted for overlapped operation and before the overlapped operation completes! + int nZero = 0; + if (setsockopt(m_sock, SOL_SOCKET, SO_SNDBUF, (char*) &nZero, sizeof(nZero))) { + __CPP_REDIS_LOG(warn, "cpp_redis::network::tcp_client could not disable buffering"); + } +#else + //! create the socket + m_sock = socket(AF_INET, SOCK_STREAM, 0); + if (m_sock < 0) { + __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client could not create socket"); + throw redis_error("Can't open a socket"); + } +#endif /* _WIN32 */ +} + void tcp_client::connect(const std::string& host, unsigned int port, const disconnection_handler_t& disconnection_handler, @@ -42,20 +74,7 @@ tcp_client::connect(const std::string& host, unsigned int port, return throw cpp_redis::redis_error("Client already connected"); } - //! create the socket - int nZero = 0; - //Enable socket for overlapped i/o - m_sock = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); - if (m_sock < 0) - throw redis_error("Can't open a socket"); - - //Instruct the TCP stack to directly perform I/O using the buffer provided in our I/O call. - //The advantage is performance because we save a buffer copy between the TCP stack buffer - //and our user buffer for each I/O call. - //BUT we have to make sure we don't access the buffer once it's submitted for overlapped operation and before the overlapped operation completes! - if (0 != setsockopt(m_sock, SOL_SOCKET, SO_SNDBUF, (char*) &nZero, sizeof(nZero))) { - return throw cpp_redis::redis_error("tcp_client::connect() setsockopt failed to disable buffering"); - } + setup_socket(); //! get the server's DNS entry struct hostent* server = gethostbyname(host.c_str()); @@ -73,17 +92,22 @@ tcp_client::connect(const std::string& host, unsigned int port, //! create a connection with the server if (::connect(m_sock, reinterpret_cast(&server_addr), sizeof(server_addr)) < 0) { + __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client could not connect"); throw redis_error("Fail to connect to " + host + ":" + std::to_string(port)); } - //! add socket to the io_service and set the disconnection & recv handlers - m_disconnection_handler = disconnection_handler; - m_receive_handler = receive_handler; - +#ifdef _WIN32 + //! Set socket to non blocking. + //! Must only be done once connected u_long ulValue = 1; - if (0 != ioctlsocket(m_sock, FIONBIO, &ulValue)) //Set socket to non blocking. - throw cpp_redis::redis_error("tcp_client::connect() setsockopt failed to set socket to non-blocking"); + if (ioctlsocket(m_sock, FIONBIO, &ulValue)) { + __CPP_REDIS_LOG(warn, "cpp_redis::network::tcp_client could not enable non-blocking mode on socket"); + } +#endif /* _WIN32 */ + //! add fd to the io_service and set the disconnection & recv handlers + m_disconnection_handler = disconnection_handler; + m_receive_handler = receive_handler; m_io_service->track(m_sock, std::bind(&tcp_client::io_service_disconnection_handler, this, std::placeholders::_1)); m_is_connected = true; @@ -95,18 +119,14 @@ tcp_client::connect(const std::string& host, unsigned int port, void tcp_client::disconnect(void) { + __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client attemps to disconnect"); + if (!m_is_connected) { __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client already disconnected"); return; } - m_is_connected = false; m_io_service->untrack(m_sock); - - closesocket(m_sock); - m_sock = INVALID_SOCKET; - - clear_buffer(); reset_state(); __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client disconnected"); @@ -119,6 +139,8 @@ tcp_client::send(const std::string& buffer) { void tcp_client::send(const std::vector& buffer) { + __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client attemps to send data"); + if (!m_is_connected) { __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client is not connected"); throw redis_error("Not connected"); @@ -148,22 +170,21 @@ void tcp_client::async_read(void) { __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client starts async_read"); - m_io_service->async_read(m_sock, m_read_buffer, READ_SIZE, + m_io_service->async_read(m_sock, m_read_buffer, __CPP_REDIS_READ_SIZE, [&](std::size_t length) { __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client received data"); if (m_receive_handler) { __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client calls receive_handler"); - if (!m_receive_handler(*this, - {m_read_buffer.begin(), m_read_buffer.begin() + length})) { + if (!m_receive_handler(*this, {m_read_buffer.begin(), m_read_buffer.begin() + length})) { __CPP_REDIS_LOG(warn, "cpp_redis::network::tcp_client has been asked for disconnection by receive_handler"); disconnect(); return; } } - //! clear read buffer and re-issue async read to receive more incoming bytes + //! clear read buffer keep waiting for incoming bytes m_read_buffer.clear(); if (m_is_connected) @@ -180,13 +201,13 @@ tcp_client::async_write(void) { __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client wrote data and cleans write_buffer"); std::lock_guard lock(m_write_buffer_mutex); - //Remove what has already been sent and see if we have any more to send + //! Remove what has already been sent and see if we have any more to send if (length >= m_write_buffer.front().size()) m_write_buffer.pop_front(); else m_write_buffer.front().erase(m_write_buffer.front().begin(), m_write_buffer.front().begin() + length); - //If we still have data to write the call ourselves recursivly until the buffer is completely sent + //! If we still have data to write the call ourselves recursivly until the buffer is completely sent if (m_is_connected && m_write_buffer.size()) async_write(); }); @@ -211,10 +232,16 @@ tcp_client::io_service_disconnection_handler(network::io_service&) { void tcp_client::reset_state(void) { - if (m_sock != INVALID_SOCKET) - closesocket(m_sock); m_is_connected = false; - m_sock = INVALID_SOCKET; + + if (m_sock != INVALID_SOCKET) { +#ifdef _WIN32 + closesocket(m_sock); +#else + close(m_sock); +#endif /* _WIN32 */ + m_sock = INVALID_SOCKET; + } clear_buffer(); } diff --git a/sources/network/unix/io_service.cpp b/sources/network/unix/io_service.cpp index 484b8f3f..922b4741 100644 --- a/sources/network/unix/io_service.cpp +++ b/sources/network/unix/io_service.cpp @@ -8,14 +8,11 @@ namespace cpp_redis { namespace network { -const std::shared_ptr& -io_service::get_instance(void) { - static std::shared_ptr instance = std::shared_ptr{new io_service}; - return instance; -} +namespace unix { -io_service::io_service(void) -: m_should_stop(false) +io_service::io_service(size_t nb_workers) +: network::io_service(nb_workers) +, m_should_stop(false) , m_notif_pipe_fds{1, 1} { if (pipe(m_notif_pipe_fds) == -1) { __CPP_REDIS_LOG(error, "cpp_redis::network::io_service could not create pipe"); @@ -28,7 +25,7 @@ io_service::io_service(void) throw cpp_redis::redis_error("Could not init cpp_redis::io_service, fcntl() failure"); } - m_worker = std::thread(&io_service::listen, this); + m_worker = std::thread(&io_service::process_io, this); __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service created"); } @@ -172,8 +169,8 @@ io_service::process_sets(struct pollfd* fds, unsigned int nfds) { } void -io_service::listen(void) { - struct pollfd fds[1024]; +io_service::process_io(void) { + struct pollfd fds[_CPP_REDIS_MAX_NB_FDS]; __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service starts poll loop in worker thread"); @@ -184,15 +181,16 @@ io_service::listen(void) { __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service woke up by poll"); process_sets(fds, nfds); } - else + else { __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service woke up by poll, but nothing to process"); + } } __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service ends poll loop in worker thread"); } void -io_service::track(int fd, const disconnection_handler_t& handler) { +io_service::track(_sock_t fd, const disconnection_handler_t& handler) { std::lock_guard lock(m_fds_mutex); auto& info = m_fds[fd]; @@ -206,7 +204,7 @@ io_service::track(int fd, const disconnection_handler_t& handler) { } void -io_service::untrack(int fd) { +io_service::untrack(_sock_t fd) { std::unique_lock lock(m_fds_mutex); __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service requests to untrack fd #" + std::to_string(fd)); @@ -228,7 +226,7 @@ io_service::untrack(int fd) { } bool -io_service::async_read(int fd, std::vector& buffer, std::size_t read_size, const read_callback_t& callback) { +io_service::async_read(_sock_t fd, std::vector& buffer, std::size_t read_size, const read_callback_t& callback) { std::lock_guard lock(m_fds_mutex); __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service is requested async_read for fd #" + std::to_string(fd)); @@ -256,7 +254,7 @@ io_service::async_read(int fd, std::vector& buffer, std::size_t read_size, } bool -io_service::async_write(int fd, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback) { +io_service::async_write(_sock_t fd, const std::vector& buffer, std::size_t write_size, const write_callback_t& callback) { std::lock_guard lock(m_fds_mutex); __CPP_REDIS_LOG(debug, "cpp_redis::network::io_service is requested async_write for fd #" + std::to_string(fd)); @@ -289,6 +287,8 @@ io_service::notify_poll(void) { (void) write(m_notif_pipe_fds[1], "a", 1); } +} //! unix + } //! network } //! cpp_redis diff --git a/sources/network/unix/tcp_client.cpp b/sources/network/unix/tcp_client.cpp deleted file mode 100644 index b66791a6..00000000 --- a/sources/network/unix/tcp_client.cpp +++ /dev/null @@ -1,212 +0,0 @@ -#include -#include -#include - -#include -#include - -namespace cpp_redis { - -namespace network { - -//! note that we call io_service::get_instance in the init list -//! -//! this will force force io_service instance creation -//! this is a workaround to handle static object destructions order -//! -//! that way, any object containing a tcp_client has an attribute (or through its attributes) -//! is guaranteed to be destructed before the io_service is destructed, even if it is global -tcp_client::tcp_client(const std::shared_ptr& IO) -: m_io_service(IO ? IO : io_service::get_instance()) -, m_fd(-1) -, m_is_connected(false) -, m_receive_handler(nullptr) -, m_disconnection_handler(nullptr) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client created"); -} - -tcp_client::~tcp_client(void) { - disconnect(); - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client destroyed"); -} - -void -tcp_client::connect(const std::string& host, unsigned int port, - const disconnection_handler_t& disconnection_handler, - const receive_handler_t& receive_handler) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client attempts to connect"); - - if (m_is_connected) { - __CPP_REDIS_LOG(warn, "cpp_redis::network::tcp_client is already connected"); - return throw cpp_redis::redis_error("Client already connected"); - } - - //! create the socket - m_fd = socket(AF_INET, SOCK_STREAM, 0); - if (m_fd < 0) { - __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client could not create socket"); - throw redis_error("Can't open a socket"); - } - - //! get the server's DNS entry - struct hostent* server = gethostbyname(host.c_str()); - if (!server) { - __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client could not resolve DNS"); - throw redis_error("No such host: " + host); - } - - //! build the server's Internet address - struct sockaddr_in server_addr; - std::memset(&server_addr, 0, sizeof(server_addr)); - std::memcpy(&server_addr.sin_addr.s_addr, server->h_addr, server->h_length); - server_addr.sin_port = htons(port); - server_addr.sin_family = AF_INET; - - //! create a connection with the server - if (::connect(m_fd, reinterpret_cast(&server_addr), sizeof(server_addr)) < 0) { - __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client could not connect"); - throw redis_error("Fail to connect to " + host + ":" + std::to_string(port)); - } - - //! add fd to the io_service and set the disconnection & recv handlers - m_disconnection_handler = disconnection_handler; - m_receive_handler = receive_handler; - m_io_service->track(m_fd, std::bind(&tcp_client::io_service_disconnection_handler, this, std::placeholders::_1)); - m_is_connected = true; - - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client connected"); - - //! start async read - async_read(); -} - -void -tcp_client::disconnect(void) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client attemps to disconnect"); - - if (!m_is_connected) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client already disconnected"); - return; - } - - m_io_service->untrack(m_fd); - reset_state(); - - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client disconnected"); -} - -void -tcp_client::send(const std::string& buffer) { - send(std::vector{buffer.begin(), buffer.end()}); -} - -void -tcp_client::send(const std::vector& buffer) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client attemps to send data"); - - if (!m_is_connected) { - __CPP_REDIS_LOG(error, "cpp_redis::network::tcp_client is not connected"); - throw redis_error("Not connected"); - } - - if (!buffer.size()) { - __CPP_REDIS_LOG(warn, "cpp_redis::network::tcp_client has nothing to send"); - return; - } - - std::lock_guard lock(m_write_buffer_mutex); - - bool bytes_in_buffer = m_write_buffer.size() > 0; - - //! concat buffer - m_write_buffer.insert(m_write_buffer.end(), buffer.begin(), buffer.end()); - - //! if there were already bytes in buffer, simply return - //! async_write callback will process the new buffer - if (bytes_in_buffer) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client is already processing an async_write"); - return; - } - - async_write(); -} - -void -tcp_client::async_read(void) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client starts async_read"); - - m_io_service->async_read(m_fd, m_read_buffer, READ_SIZE, - [&](std::size_t length) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client received data"); - - if (m_receive_handler) - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client calls receive_handler"); - if (!m_receive_handler(*this, {m_read_buffer.begin(), m_read_buffer.begin() + length})) { - __CPP_REDIS_LOG(warn, "cpp_redis::network::tcp_client has been asked for disconnection by receive_handler"); - disconnect(); - return; - } - - //! clear read buffer keep waiting for incoming bytes - m_read_buffer.clear(); - - if (m_is_connected) - async_read(); - }); -} - -void -tcp_client::async_write(void) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client starts async_write"); - - m_io_service->async_write(m_fd, m_write_buffer, m_write_buffer.size(), - [&](std::size_t length) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client wrote data and cleans write_buffer"); - std::lock_guard lock(m_write_buffer_mutex); - - m_write_buffer.erase(m_write_buffer.begin(), m_write_buffer.begin() + length); - - if (m_is_connected && m_write_buffer.size()) - async_write(); - }); -} - -bool -tcp_client::is_connected(void) { - return m_is_connected; -} - -void -tcp_client::io_service_disconnection_handler(network::io_service&) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client has been disconnected"); - - reset_state(); - - if (m_disconnection_handler) { - __CPP_REDIS_LOG(debug, "cpp_redis::network::tcp_client calls disconnection handler"); - m_disconnection_handler(*this); - } -} - -void -tcp_client::reset_state(void) { - m_is_connected = false; - - if (m_fd != -1) { - close(m_fd); - m_fd = -1; - } - - clear_buffer(); -} - -void -tcp_client::clear_buffer(void) { - std::lock_guard lock(m_write_buffer_mutex); - m_write_buffer.clear(); - m_read_buffer.clear(); -} - -} //! network - -} //! cpp_redis diff --git a/sources/network/windows/io_service.cpp b/sources/network/windows/io_service.cpp index db14d4b9..31a48dcc 100644 --- a/sources/network/windows/io_service.cpp +++ b/sources/network/windows/io_service.cpp @@ -1,42 +1,29 @@ -#include "cpp_redis/network/windows/io_service.hpp" -#include "cpp_redis/redis_error.hpp" +#include +#include namespace cpp_redis { namespace network { -const std::shared_ptr& -io_service::get_instance(void) { - static std::shared_ptr instance = std::shared_ptr{new io_service}; - return instance; -} - -io_service::io_service(size_t max_worker_threads) -: m_should_stop(false) { - //Determine the size of the thread pool dynamically. - //2 * number of processors in the system is our rule here. - SYSTEM_INFO info; - ::GetSystemInfo(&info); - m_worker_thread_pool_size = (info.dwNumberOfProcessors * 2); - - if (m_worker_thread_pool_size > max_worker_threads) - m_worker_thread_pool_size = max_worker_threads; +namespace windows { +io_service::io_service(size_t nb_workers) +: network::io_service(nb_workers) +, m_should_stop(false) { + //! Start winsock before any other socket calls. WSADATA wsaData; - int nRet = 0; - if ((nRet = WSAStartup(0x202, &wsaData)) != 0) //Start winsock before any other socket calls. + int nRet = WSAStartup(0x202, &wsaData); + if (nRet) throw cpp_redis::redis_error("Could not init cpp_redis::io_service, WSAStartup() failure"); //Create completion port. Pass 0 for parameter 4 to allow as many threads as there are processors in the system m_completion_port = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); - ; - if (INVALID_HANDLE_VALUE == m_completion_port) + if (m_completion_port == INVALID_HANDLE_VALUE) throw cpp_redis::redis_error("Could not init cpp_redis::io_service, CreateIoCompletionPort() failure"); - //Now startup worker thread pool which will service our async io requests - for (unsigned int i = 0; i < m_worker_thread_pool_size; i++) { + //! Now startup worker thread pool which will service our async io requests + for (unsigned int i = 0; i < get_nb_workers(); ++i) m_worker_threads.push_back(std::thread(&io_service::process_io, this)); - } } io_service::~io_service(void) { @@ -47,28 +34,24 @@ void io_service::shutdown() { m_should_stop = true; - //Iterate all of our sockets and shutdown any IO worker threads by posting a issuing a special - //message to the thread to tell them to wake up and shut down. - io_context_info* pInfo = NULL; - - auto sock_it = m_sockets.begin(); - while (sock_it != m_sockets.end()) { - auto& info = sock_it->second; - //Post for each of our worker threads. - int workers = m_worker_threads.size(); - for (int i = 0; i < workers; i++) - PostQueuedCompletionStatus(m_completion_port, 0, NULL, NULL); //Use NULL for the completion key to wake them up. - sock_it++; + //! Iterate all of our sockets and shutdown any IO worker threads by posting a issuing a special + //! message to the thread to tell them to wake up and shut down. + for (const auto& sock : m_sockets) { + //! Post for each of our worker threads. + for (size_t i = 0; i < get_nb_workers(); i++) { + //! Use nullptr for the completion key to wake them up. + PostQueuedCompletionStatus(m_completion_port, 0, NULL, NULL); + } } - // Wait for the threads to finish - for (auto& t : m_worker_threads) - t.join(); + //! Wait for the threads to finish + for (auto& worker : m_worker_threads) + worker.join(); - //close the completion port otherwise the worker threads will all be waiting on GetQueuedCompletionStatus() + //! close the completion port otherwise the worker threads will all be waiting on GetQueuedCompletionStatus() if (m_completion_port) { CloseHandle(m_completion_port); - m_completion_port = NULL; + m_completion_port = nullptr; } } @@ -172,7 +155,7 @@ io_service::async_write(SOCKET sock, const std::vector& buffer, std::size_ } //function used by worker thread(s) used to process io requests -int +void io_service::process_io(void) { BOOL bSuccess = FALSE; int nRet = 0; @@ -203,7 +186,7 @@ io_service::process_io(void) { continue; } if (m_should_stop) - return 0; + return; } //get the base address of the struct holding lpOverlapped (the io_context_info) pointer. @@ -215,7 +198,7 @@ io_service::process_io(void) { // Somebody used PostQueuedCompletionStatus to post an I/O packet with // a NULL CompletionKey (or if we get one for any reason). It is time to exit. if (!psock_info || !pOverlapped) - return 0; + return; e_op = pio_info->eOperation; @@ -259,9 +242,10 @@ io_service::process_io(void) { break; } //switch } //while - - return 0; } +} //! windows + } //! network + } //! cpp_redis diff --git a/sources/redis_client.cpp b/sources/redis_client.cpp index 96f888be..ad5d7146 100644 --- a/sources/redis_client.cpp +++ b/sources/redis_client.cpp @@ -3,8 +3,9 @@ namespace cpp_redis { -redis_client::redis_client(const std::shared_ptr& IO) -: m_client(IO) { +redis_client::redis_client(const std::shared_ptr& io_service) +: m_client(io_service) +, m_callbacks_running(0) { __CPP_REDIS_LOG(debug, "cpp_redis::redis_client created"); } @@ -65,7 +66,7 @@ redis_client::sync_commit(void) { std::unique_lock lock_callback(m_callbacks_mutex); __CPP_REDIS_LOG(debug, "cpp_redis::redis_client waits for callbacks to complete"); - m_sync_condvar.wait(lock_callback, [=] { return m_callbacks.empty(); }); + m_sync_condvar.wait(lock_callback, [=] { return m_callbacks_running == 0 && m_callbacks.empty(); }); __CPP_REDIS_LOG(debug, "cpp_redis::redis_client finished to wait for callbacks completion"); return *this; @@ -95,6 +96,7 @@ redis_client::connection_receive_handler(network::redis_connection&, reply& repl if (m_callbacks.size()) { callback = m_callbacks.front(); + m_callbacks_running += 1; m_callbacks.pop(); } } @@ -104,6 +106,7 @@ redis_client::connection_receive_handler(network::redis_connection&, reply& repl callback(reply); } + m_callbacks_running -= 1; m_sync_condvar.notify_all(); } diff --git a/sources/redis_subscriber.cpp b/sources/redis_subscriber.cpp index eefc954f..ababc985 100644 --- a/sources/redis_subscriber.cpp +++ b/sources/redis_subscriber.cpp @@ -4,8 +4,8 @@ namespace cpp_redis { -redis_subscriber::redis_subscriber(const std::shared_ptr& IO) -: m_client(IO) { +redis_subscriber::redis_subscriber(const std::shared_ptr& io_service) +: m_client(io_service) { __CPP_REDIS_LOG(debug, "cpp_redis::redis_subscriber created"); } diff --git a/sources/reply.cpp b/sources/reply.cpp index fbdd5f75..eaa92f23 100644 --- a/sources/reply.cpp +++ b/sources/reply.cpp @@ -1,5 +1,5 @@ -#include "cpp_redis/reply.hpp" -#include "cpp_redis/redis_error.hpp" +#include +#include namespace cpp_redis { diff --git a/tests/sources/spec/redis_subscriber_spec.cpp b/tests/sources/spec/redis_subscriber_spec.cpp index a0804a45..0cf6ac9f 100644 --- a/tests/sources/spec/redis_subscriber_spec.cpp +++ b/tests/sources/spec/redis_subscriber_spec.cpp @@ -298,6 +298,14 @@ TEST(RedisSubscriber, MultipleSubscribeSomethingPublished) { sub.connect(); client.connect(); + auto ack_callback = [&](int nb_chans) { + if (nb_chans == 2) { + client.publish("/chan_1", "hello"); + client.publish("/chan_2", "world"); + client.commit(); + } + }; + std::atomic_bool callback_1_run(false); std::atomic_bool callback_2_run(false); sub.subscribe("/chan_1", @@ -309,13 +317,7 @@ TEST(RedisSubscriber, MultipleSubscribeSomethingPublished) { if (callback_2_run) cv.notify_all(); }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan_1", "hello"); - client.publish("/chan_2", "world"); - client.commit(); - } - }); + ack_callback); sub.subscribe("/chan_2", [&](const std::string& channel, const std::string& message) { EXPECT_TRUE(channel == "/chan_2"); @@ -325,13 +327,7 @@ TEST(RedisSubscriber, MultipleSubscribeSomethingPublished) { if (callback_1_run) cv.notify_all(); }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan_1", "hello"); - client.publish("/chan_2", "world"); - client.commit(); - } - }); + ack_callback); sub.commit(); @@ -446,6 +442,14 @@ TEST(RedisSubscriber, MultiplePSubscribeSomethingPublished) { sub.connect(); client.connect(); + auto ack_callback = [&](int nb_chans) { + if (nb_chans == 2) { + client.publish("/chan/1", "hello"); + client.publish("/other_chan/2", "world"); + client.commit(); + } + }; + std::atomic_bool callback_1_run(false); std::atomic_bool callback_2_run(false); sub.psubscribe("/chan/*", @@ -457,13 +461,7 @@ TEST(RedisSubscriber, MultiplePSubscribeSomethingPublished) { if (callback_2_run) cv.notify_all(); }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan/1", "hello"); - client.publish("/other_chan/2", "world"); - client.commit(); - } - }); + ack_callback); sub.psubscribe("/other_chan/*", [&](const std::string& channel, const std::string& message) { EXPECT_TRUE(channel == "/other_chan/2"); @@ -473,13 +471,7 @@ TEST(RedisSubscriber, MultiplePSubscribeSomethingPublished) { if (callback_1_run) cv.notify_all(); }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan/1", "hello"); - client.publish("/other_chan/2", "world"); - client.commit(); - } - }); + ack_callback); sub.commit(); @@ -499,31 +491,27 @@ TEST(RedisSubscriber, Unsubscribe) { sub.connect(); client.connect(); + auto ack_callback = [&](int nb_chans) { + if (nb_chans == 2) { + client.publish("/chan_1", "hello"); + client.publish("/chan_2", "hello"); + client.commit(); + } + }; + std::atomic_bool callback_1_run(false); std::atomic_bool callback_2_run(false); sub.subscribe("/chan_1", [&](const std::string&, const std::string&) { callback_1_run = true; }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan_1", "hello"); - client.publish("/chan_2", "hello"); - client.commit(); - } - }); + ack_callback); sub.subscribe("/chan_2", [&](const std::string&, const std::string&) { callback_2_run = true; cv.notify_all(); }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan_1", "hello"); - client.publish("/chan_2", "hello"); - client.commit(); - } - }); + ack_callback); sub.unsubscribe("/chan_1"); sub.commit(); @@ -544,31 +532,27 @@ TEST(RedisSubscriber, PUnsubscribe) { sub.connect(); client.connect(); + auto ack_callback = [&](int nb_chans) { + if (nb_chans == 2) { + client.publish("/chan_1/hello", "hello"); + client.publish("/chan_2/hello", "hello"); + client.commit(); + } + }; + std::atomic_bool callback_1_run(false); std::atomic_bool callback_2_run(false); sub.psubscribe("/chan_1/*", [&](const std::string&, const std::string&) { callback_1_run = true; }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan_1/hello", "hello"); - client.publish("/chan_2/hello", "hello"); - client.commit(); - } - }); + ack_callback); sub.psubscribe("/chan_2/*", [&](const std::string&, const std::string&) { callback_2_run = true; cv.notify_all(); }, - [&](int nb_chans) { - if (nb_chans == 2) { - client.publish("/chan_1/hello", "hello"); - client.publish("/chan_2/hello", "hello"); - client.commit(); - } - }); + ack_callback); sub.punsubscribe("/chan_1/*"); sub.commit();