diff --git a/api/net/openssl/tls_stream.hpp b/api/net/openssl/tls_stream.hpp index 81225eb254..ea60df6d46 100644 --- a/api/net/openssl/tls_stream.hpp +++ b/api/net/openssl/tls_stream.hpp @@ -2,18 +2,18 @@ #include #include #include -#include +#include -//#define VERBOSE_OPENSSL +//#define VERBOSE_OPENSSL 0 #ifdef VERBOSE_OPENSSL -#define TLS_PRINT(fmt, ...) printf(fmt, ##__VA_ARGS__) +#define TLS_PRINT(fmt, ...) printf("TLS_Stream");printf(fmt, ##__VA_ARGS__) #else #define TLS_PRINT(fmt, ...) /* fmt */ #endif namespace openssl { - struct TLS_stream : public net::Stream + struct TLS_stream : public net::StreamBuffer { using Stream_ptr = net::Stream_ptr; @@ -25,7 +25,6 @@ namespace openssl void write(const std::string&) override; void write(const void* buf, size_t n) override; void close() override; - void reset_callbacks() override; net::Socket local() const override { return m_transport->local(); @@ -37,35 +36,11 @@ namespace openssl return m_transport->to_string(); } - void on_connect(ConnectCallback cb) override { - m_on_connect = std::move(cb); - } - void on_read(size_t, ReadCallback cb) override { - m_on_read = std::move(cb); - } - void on_data(DataCallback cb) override { - m_on_data = std::move(cb); - } - size_t next_size() override { - // FIXME: implement buffering for read_next - return 0; - } - buffer_t read_next() override { - // FIXME: implement buffering for read_next - return{}; - } - void on_close(CloseCallback cb) override { - m_on_close = std::move(cb); - } - void on_write(WriteCallback cb) override { - m_on_write = std::move(cb); - } - bool is_connected() const noexcept override { return handshake_completed() && m_transport->is_connected(); } bool is_writable() const noexcept override { - return is_connected() && m_transport->is_writable(); + return (not write_congested()) && is_connected() && m_transport->is_writable(); } bool is_readable() const noexcept override { return m_transport->is_readable(); @@ -87,7 +62,12 @@ namespace openssl size_t serialize_to(void*) const override; + void handle_read_congestion() override; + void handle_write_congestion() override; private: + void handle_data(); + int decrypt(const void *data,int size); + int send_decrypted(); void tls_read(buffer_t); int tls_perform_stream_write(); int tls_perform_handshake(); @@ -100,273 +80,12 @@ namespace openssl STATUS_FAIL }; status_t status(int n) const noexcept; - Stream_ptr m_transport = nullptr; SSL* m_ssl = nullptr; BIO* m_bio_rd = nullptr; BIO* m_bio_wr = nullptr; bool m_busy = false; bool m_deferred_close = false; - ConnectCallback m_on_connect = nullptr; - ReadCallback m_on_read = nullptr; - DataCallback m_on_data = nullptr; - WriteCallback m_on_write = nullptr; - CloseCallback m_on_close = nullptr; }; - inline TLS_stream::TLS_stream(SSL_CTX* ctx, Stream_ptr t, bool outgoing) - : m_transport(std::move(t)) - { - ERR_clear_error(); // prevent old errors from mucking things up - this->m_bio_rd = BIO_new(BIO_s_mem()); - this->m_bio_wr = BIO_new(BIO_s_mem()); - assert(ERR_get_error() == 0 && "Initializing BIOs"); - this->m_ssl = SSL_new(ctx); - assert(this->m_ssl != nullptr); - assert(ERR_get_error() == 0 && "Initializing SSL"); - // TLS server-mode - if (outgoing == false) - SSL_set_accept_state(this->m_ssl); - else - SSL_set_connect_state(this->m_ssl); - - SSL_set_bio(this->m_ssl, this->m_bio_rd, this->m_bio_wr); - // always-on callbacks - // FIXME: consider using on_data as the default always-on event. - m_transport->on_read(8192, {this, &TLS_stream::tls_read}); - m_transport->on_close({this, &TLS_stream::close_callback_once}); - - // start TLS handshake process - if (outgoing == true) - { - if (this->tls_perform_handshake() < 0) return; - } - } - inline TLS_stream::TLS_stream(Stream_ptr t, SSL* ssl, BIO* rd, BIO* wr) - : m_transport(std::move(t)), m_ssl(ssl), m_bio_rd(rd), m_bio_wr(wr) - { - // always-on callbacks - m_transport->on_read(8192, {this, &TLS_stream::tls_read}); - m_transport->on_close({this, &TLS_stream::close_callback_once}); - } - inline TLS_stream::~TLS_stream() - { - assert(m_busy == false && "Cannot delete stream while in its call stack"); - SSL_free(this->m_ssl); - } - - inline void TLS_stream::write(buffer_t buffer) - { - if (UNLIKELY(this->is_connected() == false)) { - TLS_PRINT("TLS_stream::write() called on closed stream\n"); - return; - } - - int n = SSL_write(this->m_ssl, buffer->data(), buffer->size()); - auto status = this->status(n); - if (status == STATUS_FAIL) { - this->close(); - return; - } - - do { - n = tls_perform_stream_write(); - } while (n > 0); - } - inline void TLS_stream::write(const std::string& str) - { - write(net::Stream::construct_buffer(str.data(), str.data() + str.size())); - } - inline void TLS_stream::write(const void* data, const size_t len) - { - auto* buf = static_cast (data); - write(net::Stream::construct_buffer(buf, buf + len)); - } - - inline void TLS_stream::tls_read(buffer_t buffer) - { - ERR_clear_error(); - uint8_t* buf = buffer->data(); - int len = buffer->size(); - - while (len > 0) - { - int n = BIO_write(this->m_bio_rd, buf, len); - if (UNLIKELY(n < 0)) { - this->close(); - return; - } - buf += n; - len -= n; - - // if we aren't finished initializing session - if (UNLIKELY(!handshake_completed())) - { - int num = SSL_do_handshake(this->m_ssl); - auto status = this->status(num); - - // OpenSSL wants to write - if (status == STATUS_WANT_IO) - { - tls_perform_stream_write(); - } - else if (status == STATUS_FAIL) - { - if (num < 0) { - TLS_PRINT("TLS_stream::SSL_do_handshake() returned %d\n", num); - #ifdef VERBOSE_OPENSSL - ERR_print_errors_fp(stdout); - #endif - } - this->close(); - return; - } - // nothing more to do if still not finished - if (handshake_completed() == false) return; - // handshake success - if (m_on_connect) m_on_connect(*this); - } - - // read decrypted data - do { - char temp[8192]; - n = SSL_read(this->m_ssl, temp, sizeof(temp)); - if (n > 0) { - auto buf = net::Stream::construct_buffer(temp, temp + n); - if (m_on_read) { - this->m_busy = true; - m_on_read(std::move(buf)); - this->m_busy = false; - } - } - } while (n > 0); - // this goes here? - if (UNLIKELY(this->is_closing() || this->is_closed())) { - TLS_PRINT("TLS_stream::SSL_read closed during read\n"); - return; - } - if (this->m_deferred_close) { - this->close(); return; - } - - auto status = this->status(n); - // did peer request stream renegotiation? - if (status == STATUS_WANT_IO) - { - do { - n = tls_perform_stream_write(); - } while (n > 0); - } - else if (status == STATUS_FAIL) - { - this->close(); - return; - } - // check deferred closing - if (this->m_deferred_close) { - this->close(); return; - } - - } // while it < end - } // tls_read() - - inline int TLS_stream::tls_perform_stream_write() - { - ERR_clear_error(); - int pending = BIO_ctrl_pending(this->m_bio_wr); - //printf("pending: %d\n", pending); - if (pending > 0) - { - auto buffer = net::Stream::construct_buffer(pending); - int n = BIO_read(this->m_bio_wr, buffer->data(), buffer->size()); - assert(n == pending); - m_transport->write(buffer); - if (m_on_write) { - this->m_busy = true; - m_on_write(n); - this->m_busy = false; - } - return n; - } - else { - BIO_read(this->m_bio_wr, nullptr, 0); - } - if (!BIO_should_retry(this->m_bio_wr)) - { - this->close(); - return -1; - } - return 0; - } - inline int TLS_stream::tls_perform_handshake() - { - ERR_clear_error(); // prevent old errors from mucking things up - // will return -1:SSL_ERROR_WANT_WRITE - int ret = SSL_do_handshake(this->m_ssl); - int n = this->status(ret); - ERR_print_errors_fp(stderr); - if (n == STATUS_WANT_IO) - { - do { - n = tls_perform_stream_write(); - if (n < 0) { - TLS_PRINT("TLS_stream::tls_perform_handshake() stream write failed\n"); - } - } while (n > 0); - return n; - } - else { - TLS_PRINT("TLS_stream::tls_perform_handshake() returned %d\n", ret); - this->close(); - return -1; - } - } - - inline void TLS_stream::close() - { - //ERR_clear_error(); - if (this->m_busy) { - this->m_deferred_close = true; return; - } - CloseCallback func = std::move(this->m_on_close); - this->reset_callbacks(); - if (m_transport->is_connected()) - m_transport->close(); - if (func) func(); - } - inline void TLS_stream::close_callback_once() - { - if (this->m_busy) { - this->m_deferred_close = true; return; - } - CloseCallback func = std::move(this->m_on_close); - this->reset_callbacks(); - if (func) func(); - } - inline void TLS_stream::reset_callbacks() - { - this->m_on_close = nullptr; - this->m_on_connect = nullptr; - this->m_on_read = nullptr; - this->m_on_write = nullptr; - } - - inline bool TLS_stream::handshake_completed() const noexcept - { - return SSL_is_init_finished(this->m_ssl); - } - inline TLS_stream::status_t TLS_stream::status(int n) const noexcept - { - int error = SSL_get_error(this->m_ssl, n); - switch (error) - { - case SSL_ERROR_NONE: - return STATUS_OK; - case SSL_ERROR_WANT_WRITE: - case SSL_ERROR_WANT_READ: - return STATUS_WANT_IO; - default: - return STATUS_FAIL; - } - } } // openssl diff --git a/api/net/stream_buffer.hpp b/api/net/stream_buffer.hpp new file mode 100644 index 0000000000..8aaf2a428f --- /dev/null +++ b/api/net/stream_buffer.hpp @@ -0,0 +1,214 @@ +#ifndef STREAMBUFFERR_HPP +#define STREAMBUFFERR_HPP +#include +#include +#include + +namespace net { + class StreamBuffer : public net::Stream + { + public: + StreamBuffer(Timers::duration_t timeout=std::chrono::microseconds(10)) + : timer({this,&StreamBuffer::congested}),congestion_timeout(timeout) {} + using buffer_t = os::mem::buf_ptr; + using Ready_queue = std::deque; + virtual ~StreamBuffer() { + timer.stop(); + } + + void on_connect(ConnectCallback cb) override { + m_on_connect = std::move(cb); + } + + void on_read(size_t, ReadCallback cb) override { + m_on_read = std::move(cb); + } + void on_data(DataCallback cb) override { + m_on_data = std::move(cb); + } + size_t next_size() override; + + buffer_t read_next() override; + + void on_close(CloseCallback cb) override { + m_on_close = std::move(cb); + } + void on_write(WriteCallback cb) override { + m_on_write = std::move(cb); + } + + void signal_data(); + + bool read_congested() const noexcept + { return m_read_congested; } + + bool write_congested() const noexcept + { return m_write_congested; } + + /** + * @brief Construct a shared read vector used by streams + * If allocation failed congestion flag is set + * + * @param construction parameters + * + * @return nullptr on failure, shared_ptr to buffer on success + */ + template + buffer_t construct_read_buffer(Args&&... args) + { + return construct_buffer_with_flag(m_read_congested,std::forward (args)...); + } + + /** + * @brief Construct a shared write vector used by streams + * If allocation failed congestion flag is set + * + * @param construction parameters + * + * @return nullptr on failure, shared_ptr to buffer on success + */ + template + buffer_t construct_write_buffer(Args&&... args) + { + return construct_buffer_with_flag(m_write_congested,std::forward (args)...); + } + + virtual void handle_read_congestion() = 0; + virtual void handle_write_congestion() = 0; + protected: + void closed() + { if (m_on_close) m_on_close(); } + void connected() + { if (m_on_connect) m_on_connect(*this); } + void stream_on_write(int n) + { if (m_on_write) m_on_write(n); } + void enqueue_data(buffer_t data) + { m_send_buffers.push_back(data); } + + void congested(); + + CloseCallback getCloseCallback() { return std::move(this->m_on_close); } + + void reset_callbacks() override + { + //remove queue and reset congestion flags and busy flag ?? + this->m_on_close = nullptr; + this->m_on_connect = nullptr; + this->m_on_read = nullptr; + this->m_on_write = nullptr; + this->m_on_data = nullptr; + } + Timer timer; + + private: + Timer::duration_t congestion_timeout; + bool m_write_congested= false; + bool m_read_congested = false; + + ConnectCallback m_on_connect = nullptr; + ReadCallback m_on_read = nullptr; + DataCallback m_on_data = nullptr; + WriteCallback m_on_write = nullptr; + CloseCallback m_on_close = nullptr; + Ready_queue m_send_buffers; + + /** + * @brief Construct a shared vector and set congestion flag if allocation fails + * + * @param flag the flag to set true or false on allocation failure + * @param args arguments to constructing the buffer + * @return nullptr on failure , shared pointer to buffer on success + */ + + template + buffer_t construct_buffer_with_flag(bool &flag,Args&&... args) + { + static buffer_t buffer; + try + { + buffer = std::make_shared(std::forward (args)...); + flag = false; + } + catch (std::bad_alloc &e) + { + flag = true; + timer.start(congestion_timeout); + return nullptr; + } + return buffer; + } + + + }; // < class StreamBuffer + + inline size_t StreamBuffer::next_size() + { + if (not m_send_buffers.empty()) { + return m_send_buffers.front()->size(); + } + return 0; + } + + inline StreamBuffer::buffer_t StreamBuffer::read_next() + { + + if (not m_send_buffers.empty()) { + auto buf = m_send_buffers.front(); + m_send_buffers.pop_front(); + return buf; + } + return nullptr; + } + + inline void StreamBuffer::congested() + { + if (m_read_congested) + { + handle_read_congestion(); + } + if (m_write_congested) + { + handle_write_congestion(); + } + //if any of the congestion states are still active make sure the timer is running + if(m_read_congested or m_write_congested) + { + if (!timer.is_running()) + { + timer.start(congestion_timeout); + } + } + else + { + if (timer.is_running()) + { + timer.stop(); + } + } + } + + inline void StreamBuffer::signal_data() + { + if (not m_send_buffers.empty()) + { + if (m_on_data != nullptr){ + //on_data_callback(); + m_on_data(); + if (not m_send_buffers.empty()) { + m_read_congested=true; + timer.start(congestion_timeout); + } + } + else if (m_on_read != nullptr) + { + for (auto buf : m_send_buffers) { + // Pop each time, in case callback leads to another call here. + m_send_buffers.pop_front(); + m_on_read(buf); + if (m_on_read == nullptr) { break; } //if calling m_on_read reset the callbacks exit + } + } + } + } +} // namespace net +#endif // STREAMBUFFERR_HPP diff --git a/api/net/tcp/connection.hpp b/api/net/tcp/connection.hpp index f49dfd4062..a66bc33d90 100644 --- a/api/net/tcp/connection.hpp +++ b/api/net/tcp/connection.hpp @@ -322,7 +322,7 @@ class Connection { * @return True if able to send, False otherwise. */ bool can_send() const noexcept - { return usable_window() and writeq.has_remaining_requests(); } + { return (usable_window() >= SMSS()) and writeq.has_remaining_requests(); } /** * @brief Return the "tuple" (id) of the connection. diff --git a/api/util/detail/alloc_pmr.hpp b/api/util/detail/alloc_pmr.hpp index 68fbb2227d..73f67fb5cf 100644 --- a/api/util/detail/alloc_pmr.hpp +++ b/api/util/detail/alloc_pmr.hpp @@ -33,6 +33,7 @@ namespace os::mem::detail { void* do_allocate(size_t size, size_t align) override { if (UNLIKELY(size + allocated_ > cap_total_)) { + //printf("pmr about to throw bad alloc: sz=%zu alloc=%zu cap=%zu\n", size, allocated_, cap_total_); throw std::bad_alloc(); } @@ -46,6 +47,7 @@ namespace os::mem::detail { void* buf = memalign(align, size); if (buf == nullptr) { + //printf("pmr memalign return nullptr, throw bad alloc\n"); throw std::bad_alloc(); } @@ -152,7 +154,9 @@ namespace os::mem::detail { std::size_t resource_capacity() { if (cap_suballoc_ == 0) + { return cap_total_ / (used_resources_ + os::mem::Pmr_pool::resource_division_offset); + } return cap_suballoc_; } @@ -244,7 +248,9 @@ namespace os::mem { // Pmr_resource implementation // Pmr_resource::Pmr_resource(Pool_ptr p) : pool_{p} {} - std::size_t Pmr_resource::capacity() { return pool_->resource_capacity(); } + std::size_t Pmr_resource::capacity() { + return std::min(pool_->resource_capacity(), pool_->allocatable()); + } std::size_t Pmr_resource::allocatable() { auto cap = capacity(); if (used > cap) @@ -267,10 +273,8 @@ namespace os::mem { } void* buf = pool_->allocate(size, align); - used += size; allocs++; - return buf; } diff --git a/lib/microLB/micro_lb/balancer.cpp b/lib/microLB/micro_lb/balancer.cpp index 317e55809a..9b0f124da6 100644 --- a/lib/microLB/micro_lb/balancer.cpp +++ b/lib/microLB/micro_lb/balancer.cpp @@ -1,8 +1,6 @@ #include "balancer.hpp" #include -#define READQ_PER_CLIENT 4096 -#define READQ_FOR_NODES 8192 #define MAX_OUTGOING_ATTEMPTS 100 // checking if nodes are dead or not #define ACTIVE_INITIAL_PERIOD 8s @@ -13,7 +11,7 @@ #define LB_VERBOSE 0 #if LB_VERBOSE -#define LBOUT(fmt, ...) printf(fmt, ##__VA_ARGS__) +#define LBOUT(fmt, ...) printf("MICROLB: "); printf(fmt, ##__VA_ARGS__) #else #define LBOUT(fmt, ...) /** **/ #endif @@ -75,7 +73,7 @@ namespace microLB if (client.conn->is_connected()) { // NOTE: explicitly want to copy buffers net::Stream_ptr rval = - nodes.assign(std::move(client.conn), client.readq); + nodes.assign(std::move(client.conn)); if (rval == nullptr) { // done with this queue item queue.pop_front(); @@ -94,7 +92,7 @@ namespace microLB } void Balancer::handle_connections() { - LBOUT("Handle_connections. %i waiting \n", queue.size()); + LBOUT("Handle_connections. %lu waiting \n", queue.size()); // stop any rethrow timer since this is a de-facto retry if (this->throw_retry_timer != Timers::UNUSED_ID) { Timers::stop(this->throw_retry_timer); @@ -143,20 +141,11 @@ namespace microLB // Release connection if it closes before it's assigned to a node. this->conn->on_close([this](){ + printf("Waiting issuing close\n"); if (this->conn != nullptr) this->conn->reset_callbacks(); this->conn = nullptr; }); - - // queue incoming data from clients not yet - // assigned to a node - this->conn->on_read(READQ_PER_CLIENT, - [this] (auto buf) { - // prevent buffer bloat attack - this->total += buf->size(); - LBOUT("*** Queued %lu bytes\n", buf->size()); - readq.push_back(buf); - }); } void Nodes::create_connections(int total) @@ -189,7 +178,7 @@ namespace microLB } } } - net::Stream_ptr Nodes::assign(net::Stream_ptr conn, queue_vector_t& readq) + net::Stream_ptr Nodes::assign(net::Stream_ptr conn) { for (size_t i = 0; i < nodes.size(); i++) { @@ -202,11 +191,7 @@ namespace microLB assert(outgoing->is_connected()); LBOUT("Assigning client to node %d (%s)\n", algo_iterator, outgoing->to_string().c_str()); - // flush readq to outgoing before creating session - for (auto buffer : readq) { - LBOUT("*** Flushing %lu bytes\n", buffer->size()); - outgoing->write(buffer); - } + //Should we some way hold track of the session object ? auto& session = this->create_session( std::move(conn), std::move(outgoing)); @@ -266,16 +251,39 @@ namespace microLB assert(session.is_alive()); return session; } + + void Nodes::destroy_sessions() + { + for (auto& idx: closed_sessions) + { + auto &session=get_session(idx); + + // free session destroying potential unique ptr objects + session.incoming = nullptr; + auto out_tcp = dynamic_cast(session.outgoing->bottom_transport())->tcp(); + session.outgoing = nullptr; + // if we don't have anything to write to the backend, abort it. + if(not out_tcp->sendq_size()) + out_tcp->abort(); + free_sessions.push_back(session.self); + LBOUT("Session %d destroyed (total = %d)\n", session.self, session_cnt); + } + closed_sessions.clear(); + } void Nodes::close_session(int idx) { auto& session = get_session(idx); // remove connections session.incoming->reset_callbacks(); - session.incoming = nullptr; session.outgoing->reset_callbacks(); - session.outgoing = nullptr; - // free session - free_sessions.push_back(session.self); + closed_sessions.push_back(session.self); + + if (!cleanup_timer.is_running()) + { + cleanup_timer.start(std::chrono::milliseconds(10),[this](){ + this->destroy_sessions(); + }); + } session_cnt--; LBOUT("Session %d closed (total = %d)\n", session.self, session_cnt); } @@ -355,13 +363,24 @@ namespace microLB } void Node::connect() { - auto outgoing = this->stack.tcp().connect(this->addr); + net::tcp::Connection_ptr outgoing; + try + { + outgoing = this->stack.tcp().connect(this->addr); + } + catch([[maybe_unused]]const net::TCP_error& err) + { + LBOUT("Got exception: %s\n", err.what()); + this->restart_active_check(); + return; + } // connecting to node atm. this->connecting++; // retry timer when connect takes too long int fail_timer = Timers::oneshot(CONNECT_TIMEOUT, [this, outgoing] (int) { + printf("Fail timer\n"); // close connection outgoing->abort(); // no longer connecting @@ -403,8 +422,14 @@ namespace microLB auto conn = std::move(pool.back()); assert(conn != nullptr); pool.pop_back(); - if (conn->is_connected()) return conn; - else conn->close(); + if (conn->is_connected()) { + return conn; + } + else + { + printf("CLOSING SINCE conn->connected is false\n"); + conn->close(); + } } return nullptr; } @@ -415,19 +440,25 @@ namespace microLB : parent(n), self(idx), incoming(std::move(inc)), outgoing(std::move(out)) { - incoming->on_read(READQ_PER_CLIENT, - [this] (auto buf) { - assert(this->is_alive()); - this->outgoing->write(buf); + + incoming->on_data([this]() { + assert(this->is_alive()); + while((this->incoming->next_size() > 0) and this->outgoing->is_writable()) + { + this->outgoing->write(this->incoming->read_next()); + } }); incoming->on_close( [&nodes = n, idx] () { nodes.close_session(idx); }); - outgoing->on_read(READQ_FOR_NODES, - [this] (auto buf) { - assert(this->is_alive()); - this->incoming->write(buf); + + outgoing->on_data([this]() { + assert(this->is_alive()); + while((this->outgoing->next_size() > 0) and this->incoming->is_writable()) + { + this->incoming->write(this->outgoing->read_next()); + } }); outgoing->on_close( [&nodes = n, idx] () { diff --git a/lib/microLB/micro_lb/balancer.hpp b/lib/microLB/micro_lb/balancer.hpp index 8fb095b320..f05ffd8027 100644 --- a/lib/microLB/micro_lb/balancer.hpp +++ b/lib/microLB/micro_lb/balancer.hpp @@ -1,12 +1,12 @@ #pragma once #include #include +#include namespace microLB { typedef net::Inet netstack_t; typedef net::tcp::Connection_ptr tcp_ptr; - typedef std::vector queue_vector_t; typedef delegate pool_signal_t; struct Waiting { @@ -15,7 +15,6 @@ namespace microLB void serialize(liu::Storage&); net::Stream_ptr conn; - queue_vector_t readq; int total = 0; }; @@ -37,7 +36,7 @@ namespace microLB auto address() const noexcept { return this->addr; } int connection_attempts() const noexcept { return this->connecting; } int pool_size() const noexcept { return pool.size(); } - bool is_active() const noexcept { return active; }; + bool is_active() const noexcept { return active; } bool active_check() const noexcept { return do_active_check; } void restart_active_check(); @@ -77,9 +76,10 @@ namespace microLB void add_node(Args&&... args); void create_connections(int total); // returns the connection back if the operation fails - net::Stream_ptr assign(net::Stream_ptr, queue_vector_t&); + net::Stream_ptr assign(net::Stream_ptr); Session& create_session(net::Stream_ptr inc, net::Stream_ptr out); void close_session(int); + void destroy_sessions(); Session& get_session(int); void serialize(liu::Storage&); @@ -92,8 +92,10 @@ namespace microLB int conn_iterator = 0; int algo_iterator = 0; const bool do_active_check; + Timer cleanup_timer; std::deque sessions; std::deque free_sessions; + std::deque closed_sessions; }; struct Balancer { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a74cc1c9d2..cc1eddb3ba 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,6 +22,7 @@ include_directories(${OPENSSL_DIR}/include) if(${ARCH} STREQUAL "x86_64") set(OPENSSL_MODULES "net/openssl/init.cpp" "net/openssl/client.cpp" "net/openssl/server.cpp" + "net/openssl/tls_stream.cpp" "net/https/openssl_server.cpp" "net/http/client.cpp") set(OPENSSL_LIBS openssl_ssl openssl_crypto) endif() diff --git a/src/net/conntrack.cpp b/src/net/conntrack.cpp index 3621f3ef1b..42cf65f343 100644 --- a/src/net/conntrack.cpp +++ b/src/net/conntrack.cpp @@ -352,23 +352,22 @@ int Conntrack::deserialize_from(void* addr) const auto size = *reinterpret_cast(buffer); buffer += sizeof(size_t); - + size_t dupes = 0; for(auto i = size; i > 0; i--) { // create the entry auto entry = std::make_shared(); buffer += entry->deserialize_from(buffer); - entries.emplace(std::piecewise_construct, - std::forward_as_tuple(entry->first, entry->proto), - std::forward_as_tuple(entry)); - - entries.emplace(std::piecewise_construct, - std::forward_as_tuple(entry->second, entry->proto), - std::forward_as_tuple(entry)); + bool insert = false; + insert = entries.insert_or_assign({entry->first, entry->proto}, entry).second; + if(not insert) + dupes++; + insert = entries.insert_or_assign({entry->second, entry->proto}, entry).second; + if(not insert) + dupes++; } - - Ensures(entries.size() - prev_size == size * 2); + Ensures(entries.size() - (prev_size-dupes) == size * 2); return buffer - reinterpret_cast(addr); } diff --git a/src/net/openssl/tls_stream.cpp b/src/net/openssl/tls_stream.cpp new file mode 100644 index 0000000000..16722a15f2 --- /dev/null +++ b/src/net/openssl/tls_stream.cpp @@ -0,0 +1,365 @@ +#include + +using namespace openssl; + +TLS_stream::TLS_stream(SSL_CTX* ctx, Stream_ptr t, bool outgoing) + : m_transport(std::move(t)) +{ + ERR_clear_error(); // prevent old errors from mucking things up + this->m_bio_rd = BIO_new(BIO_s_mem()); + this->m_bio_wr = BIO_new(BIO_s_mem()); + assert(ERR_get_error() == 0 && "Initializing BIOs"); + this->m_ssl = SSL_new(ctx); + assert(this->m_ssl != nullptr); + assert(ERR_get_error() == 0 && "Initializing SSL"); + // TLS server-mode + if (outgoing == false) + SSL_set_accept_state(this->m_ssl); + else + SSL_set_connect_state(this->m_ssl); + + SSL_set_bio(this->m_ssl, this->m_bio_rd, this->m_bio_wr); + + // always-on callbacks + m_transport->on_data({this,&TLS_stream::handle_data}); + m_transport->on_close({this, &TLS_stream::close_callback_once}); + + // start TLS handshake process + if (outgoing == true) + { + if (this->tls_perform_handshake() < 0) return; + } +} +TLS_stream::TLS_stream(Stream_ptr t, SSL* ssl, BIO* rd, BIO* wr) + : m_transport(std::move(t)), m_ssl(ssl), m_bio_rd(rd), m_bio_wr(wr) +{ + // always-on callbacks + m_transport->on_data({this, &TLS_stream::handle_data}); + m_transport->on_close({this, &TLS_stream::close_callback_once}); +} +TLS_stream::~TLS_stream() +{ + assert(m_busy == false && "Cannot delete stream while in its call stack"); + SSL_free(this->m_ssl); +} + +void TLS_stream::write(buffer_t buffer) +{ + + if (UNLIKELY(this->is_connected() == false)) { + TLS_PRINT("::write() called on closed stream\n"); + return; + } + int n = SSL_write(this->m_ssl, buffer->data(), buffer->size()); + auto status = this->status(n); + if (status == STATUS_FAIL) { + TLS_PRINT("::write() Fail status %d\n",n); + this->close(); + return; + } + + do { + n = tls_perform_stream_write(); + } while (n > 0); +} + +void TLS_stream::write(const std::string& str) +{ + //TODO handle failed alloc + write(net::StreamBuffer::construct_write_buffer(str.data(),str.data()+str.size())); +} + +void TLS_stream::write(const void* data, const size_t len) +{ + //TODO handle failed alloc + auto* buf = static_cast (data); + write(net::StreamBuffer::construct_write_buffer(buf, buf + len)); +} + +int TLS_stream::decrypt(const void *indata, int size) +{ + int n = BIO_write(this->m_bio_rd, indata, size); + if (UNLIKELY(n < 0)) { + //TODO can we handle this more gracefully? + TLS_PRINT("BIO_write failed\n"); + this->close(); + return 0; + } + + // if we aren't finished initializing session + if (UNLIKELY(!handshake_completed())) + { + int num = SSL_do_handshake(this->m_ssl); + auto status = this->status(num); + + // OpenSSL wants to write + if (status == STATUS_WANT_IO) + { + tls_perform_stream_write(); + } + else if (status == STATUS_FAIL) + { + if (num < 0) { + TLS_PRINT("TLS_stream::SSL_do_handshake() returned %d\n", num); + #ifdef VERBOSE_OPENSSL + ERR_print_errors_fp(stdout); + #endif + } + this->close(); + return 0; + } + // nothing more to do if still not finished + if (handshake_completed() == false) return 0; + // handshake success + this->m_busy=true; + connected(); + this->m_busy=false; + + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); + return 0; + } + } + return n; +} + +int TLS_stream::send_decrypted() +{ + int n; + // read decrypted data + do { + //TODO "increase the size or constructor based ??") + auto buffer=StreamBuffer::construct_read_buffer(8192); + if (!buffer) return 0; + n = SSL_read(this->m_ssl,buffer->data(),buffer->size()); + if (n > 0) { + buffer->resize(n); + enqueue_data(buffer); + } + } while (n > 0); + return n; +} + +void TLS_stream::handle_read_congestion() +{ + //Ordering could be different + send_decrypted(); //decrypt any incomplete + this->m_busy=true; + signal_data(); //send any pending + this->m_busy=false; + + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); + return; + } +} + +void TLS_stream::handle_write_congestion() +{ + //this should resolve the potential malloc congestion + //might be missing some TLS signalling but without malloc we cant do that either + while(tls_perform_stream_write() > 0); +} +void TLS_stream::handle_data() +{ + while ( m_transport->next_size() > 0) + { + if (UNLIKELY(read_congested())){ + break; + } + tls_read(m_transport->read_next()); + //bail + if (m_transport == nullptr) + { + printf("m_transport \n"); + break; + } + } +} + +void TLS_stream::tls_read(buffer_t buffer) +{ + if (buffer == nullptr ) { + return; + } + ERR_clear_error(); + uint8_t* buf_ptr = buffer->data(); + int len = buffer->size(); + + while (len > 0) + { + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close"); + this->close(); + return; + } + + int decrypted_bytes=decrypt(buf_ptr,len); + if (UNLIKELY(decrypted_bytes==0)) return; + buf_ptr += decrypted_bytes; + len -= decrypted_bytes; + + //enqueues decrypted data + int ret=send_decrypted(); + + // this goes here? + if (UNLIKELY(this->is_closing() || this->is_closed())) { + TLS_PRINT("TLS_stream::SSL_read closed during read\n"); + return; + } + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close"); + this->close(); + return; + } + + auto status = this->status(ret); + // did peer request stream renegotiation? + if (status == STATUS_WANT_IO) + { + TLS_PRINT("::read() STATUS_WANT_IO\n"); + int ret; + do { + ret = tls_perform_stream_write(); + } while (ret > 0); + } + else if (status == STATUS_FAIL) + { + TLS_PRINT("::read() close on STATUS_FAIL after tls_perform_stream_write\n"); + this->close(); + return; + } + + } // while it < end + + //forward data + this->m_busy=true; + signal_data(); + this->m_busy=false; + + // check deferred closing + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); return; + } +} // tls_read() + +int TLS_stream::tls_perform_stream_write() +{ + ERR_clear_error(); + int pending = BIO_ctrl_pending(this->m_bio_wr); + if (pending > 0) + { + TLS_PRINT("::tls_perform_stream_write() pending=%d bytes\n",pending); + auto buffer = net::StreamBuffer::construct_write_buffer(pending); + if (buffer == nullptr) { + return 0; + } + int n = BIO_read(this->m_bio_wr, buffer->data(), buffer->size()); + assert(n == pending); + //What if we cant write.. + if (m_transport->is_writable()) + { + m_transport->write(buffer); + + this->m_busy = true; + stream_on_write(n); + this->m_busy = false; + + if (this->m_deferred_close) { + TLS_PRINT("::read() close on m_deferred_close after tls_perform_stream_write\n"); + this->close(); return 0; + } + } + + if (UNLIKELY((pending = BIO_ctrl_pending(this->m_bio_wr)) > 0)) + { + return pending; + } + return 0; + } + + BIO_read(this->m_bio_wr, nullptr, 0); + + if (!BIO_should_retry(this->m_bio_wr)) + { + TLS_PRINT("::tls_perform_stream_write() close on !BIO_should_retry\n"); + this->close(); + return -1; + } + return 0; +} + +int TLS_stream::tls_perform_handshake() +{ + ERR_clear_error(); // prevent old errors from mucking things up + // will return -1:SSL_ERROR_WANT_WRITE + int ret = SSL_do_handshake(this->m_ssl); + int n = this->status(ret); + ERR_print_errors_fp(stderr); + if (n == STATUS_WANT_IO) + { + do { + n = tls_perform_stream_write(); + if (n < 0) { + TLS_PRINT("TLS_stream::tls_perform_handshake() stream write failed\n"); + } + } while (n > 0); + return n; + } + else { + TLS_PRINT("TLS_stream::tls_perform_handshake() returned %d\n", ret); + this->close(); + return -1; + } +} + +void TLS_stream::close() +{ + TLS_PRINT("TLS_stream::close()\n"); + //ERR_clear_error(); + if (this->m_busy) { + TLS_PRINT("TLS_stream::close() deferred\n"); + this->m_deferred_close = true; return; + } + CloseCallback func = getCloseCallback(); + this->reset_callbacks(); + if (m_transport->is_connected()) + { + m_transport->close(); + m_transport->reset_callbacks(); // ??? + } + if (func) func(); +} +void TLS_stream::close_callback_once() +{ + TLS_PRINT("TLS_stream::close_callback_once() \n"); + if (this->m_busy) { + TLS_PRINT("TLS_stream::close_callback_once() deferred\n"); + this->m_deferred_close = true; return; + } + CloseCallback func = getCloseCallback(); + this->reset_callbacks(); + if (func) func(); +} + +bool TLS_stream::handshake_completed() const noexcept +{ + return SSL_is_init_finished(this->m_ssl); +} +TLS_stream::status_t TLS_stream::status(int n) const noexcept +{ + int error = SSL_get_error(this->m_ssl, n); + switch (error) + { + case SSL_ERROR_NONE: + return STATUS_OK; + case SSL_ERROR_WANT_WRITE: + case SSL_ERROR_WANT_READ: + return STATUS_WANT_IO; + default: + return STATUS_FAIL; + } +} diff --git a/src/net/tcp/connection.cpp b/src/net/tcp/connection.cpp index 0bc326336e..38735407b4 100644 --- a/src/net/tcp/connection.cpp +++ b/src/net/tcp/connection.cpp @@ -53,8 +53,9 @@ Connection::Connection(TCP& host, Socket local, Socket remote, ConnectCallback c Connection::~Connection() { - //printf(" Deleted %p %s ACTIVE: %u\n", this, + //printf(" Deleted %p %s ACTIVE: %zu\n", this, // to_string().c_str(), host_.active_connections()); + rtx_clear(); } @@ -434,6 +435,8 @@ bool Connection::handle_ack(const Packet_view& in) if(is_win_update(in, true_win)) { + //if(cb.SND.WND < SMSS()*2) + // printf("Win update: %u => %u\n", cb.SND.WND, true_win); cb.SND.WND = true_win; cb.SND.WL1 = in.seq(); cb.SND.WL2 = in.ack(); @@ -613,7 +616,7 @@ void Connection::on_dup_ack(const Packet_view& in) // 3 dup acks else if(dup_acks_ == 3) { - debug(" Dup ACK == 3 - UNA=%u recover=%u\n", cb.SND.UNA, cb.recover); + //printf(" Dup ACK == 3 - UNA=%u recover=%u\n", cb.SND.UNA, cb.recover); if(cb.SND.UNA - 1 > cb.recover) goto fast_rtx; @@ -997,8 +1000,8 @@ void Connection::retransmit() { // TODO: Finish to send window zero probe, but only on rtx timeout - debug2(" With data (wq.sz=%u) buf.unacked=%u\n", - writeq.size(), buf->size(), buf->size() - writeq.acked()); + //printf(" With data (wq.sz=%zu) buf.size=%zu buf.unacked=%zu SND.WND=%u CWND=%u\n", + // writeq.size(), buf->size(), buf->size() - writeq.acked(), cb.SND.WND, cb.cwnd); fill_packet(*packet, buf->data() + writeq.acked(), buf->size() - writeq.acked()); packet->set_flag(PSH); } @@ -1070,8 +1073,8 @@ void Connection::rtx_clear() { begins (i.e., after the three-way handshake completes). */ void Connection::rtx_timeout() { - debug(" Timed out (RTO %lld ms). FS: %u\n", - rttm.rto_ms().count(), flight_size()); + //printf(" Timed out (RTO %lld ms). FS: %u usable=%u\n", + // rttm.rto_ms().count(), flight_size(), usable_window()); signal_rtx_timeout(); // experimental @@ -1421,12 +1424,12 @@ void Connection::reduce_ssthresh() { fs = (fs >= two_seg) ? fs - two_seg : 0; cb.ssthresh = std::max( (fs / 2), two_seg ); - debug2(" Slow start threshold reduced: %u\n", - cb.ssthresh); + //printf(" Slow start threshold reduced: %u\n", + // cb.ssthresh); } void Connection::fast_retransmit() { - debug(" Fast retransmit initiated.\n"); + //printf(" Fast retransmit initiated.\n"); // reduce sshtresh reduce_ssthresh(); // retransmit segment starting SND.UNA @@ -1441,5 +1444,5 @@ void Connection::finish_fast_recovery() { fast_recovery_ = false; //cb.cwnd = std::min(cb.ssthresh, std::max(flight_size(), (uint32_t)SMSS()) + SMSS()); cb.cwnd = cb.ssthresh; - debug(" Finished Fast Recovery - Cwnd: %u\n", cb.cwnd); + //printf(" Finished Fast Recovery - Cwnd: %u\n", cb.cwnd); } diff --git a/src/net/tcp/tcp.cpp b/src/net/tcp/tcp.cpp index f8f518a66c..a9392761a9 100644 --- a/src/net/tcp/tcp.cpp +++ b/src/net/tcp/tcp.cpp @@ -492,21 +492,27 @@ bool TCP::unbind(const Socket& socket) return false; } -bool TCP::add_connection(tcp::Connection_ptr conn) { +bool TCP::add_connection(tcp::Connection_ptr conn) +{ + const size_t alloc_thres = max_bufsize() * Read_request::buffer_limit; // Stat increment number of incoming connections (*incoming_connections_)++; debug(" Connection added %s \n", conn->to_string().c_str()); - conn->bufalloc = mempool_.get_resource(); + auto resource = mempool_.get_resource(); // Reject connection if we can't allocate memory - if (conn->bufalloc == nullptr - or conn->bufalloc->allocatable() < max_bufsize() * Read_request::buffer_limit){ + if(UNLIKELY(resource == nullptr or resource->allocatable() < alloc_thres)) + { conn->_on_cleanup_ = nullptr; conn->abort(); return false; } + conn->bufalloc = std::move(resource); + + //printf("New inc conn %s allocatable=%zu\n", conn->to_string().c_str(), conn->bufalloc->allocatable()); + Expects(conn->bufalloc != nullptr); conn->_on_cleanup({this, &TCP::close_connection}); return connections_.emplace(conn->tuple(), conn).second; @@ -514,6 +520,15 @@ bool TCP::add_connection(tcp::Connection_ptr conn) { Connection_ptr TCP::create_connection(Socket local, Socket remote, ConnectCallback cb) { + const size_t alloc_thres = max_bufsize() * Read_request::buffer_limit; + + auto resource = mempool_.get_resource(); + // Don't create connection if we can't allocate memory + if(UNLIKELY(resource == nullptr or resource->allocatable() < alloc_thres)) + { + throw TCP_error{"Unable to create new connection: Not enough allocatable memory"}; + } + // Stat increment number of outgoing connections (*outgoing_connections_)++; @@ -523,7 +538,10 @@ Connection_ptr TCP::create_connection(Socket local, Socket remote, ConnectCallba ) ).first->second; conn->_on_cleanup({this, &TCP::close_connection}); - conn->bufalloc = mempool_.get_resource(); + conn->bufalloc = std::move(resource); + + //printf("New out conn %s allocatable=%zu\n", conn->to_string().c_str(), conn->bufalloc->allocatable()); + Expects(conn->bufalloc != nullptr); return conn; } diff --git a/test/net/integration/microLB/server.js b/test/net/integration/microLB/server.js index b6ea1bd9fb..bd9e5e589f 100644 --- a/test/net/integration/microLB/server.js +++ b/test/net/integration/microLB/server.js @@ -1,10 +1,14 @@ var http = require('http'); +var url = require('url') -var dataString = function() { - var len = 1024*1024 * 50; +var dataString = function(len) { return '#'.repeat(len); } +function randomData(len) { + return Array.from({length:len}, () => Math.floor(Math.random() * 40)); +} + var stringToColour = function(str) { var hash = 0; for (var i = 0; i < str.length; i++) { @@ -18,13 +22,56 @@ var stringToColour = function(str) { return colour; } -//We need a function which handles requests and send response -function handleRequest(request, response){ +function handleDigest(path, request, response) { response.setTimeout(500); var addr = request.connection.localPort; response.end(addr.toString() + dataString()); } +function handleFile(path,request, response) { + response.setTimeout(500); + var addr = request.connection.localPort; + var size = parseInt(path.replace("/",""),10); + + if (size == 0) {  + size=1024*64; + } + response.end(addr.toString() + dataString(size)); +} + +function defaultHandler(path,request,response) { + response.setTimeout(500); + var addr = request.connection.localPort; + response.end(addr.toString() + dataString(1024*1024*50)); +} + +var routes = new Map([ + ['/digest' , handleDigest], + ['/file' , handleFile] + ]); + +function findHandler(path) +{ + for (const [key,value] of routes.entries()) { + if (path.startsWith(key)) + { + return { pattern: key, func: value}; + } + } + return { pattern :'',func : defaultHandler}; +} + +function handleRequest(request, response){ + var parts = url.parse(request.url); + + var route = findHandler(parts.pathname); + if (route.func) + { + var path = parts.pathname.replace(route.pattern,''); + route.func(path,request,response); + } +} + http.createServer(handleRequest).listen(6001, '10.0.0.1'); http.createServer(handleRequest).listen(6002, '10.0.0.1'); http.createServer(handleRequest).listen(6003, '10.0.0.1');