From c4edc98ea1ff134cc6be20a33fc35e0435accfae Mon Sep 17 00:00:00 2001 From: Azat Khuzhin Date: Wed, 5 Feb 2025 21:28:37 +0000 Subject: [PATCH] Merge pull request #74749 from azat/intermediate-connections-fix Avoid reusing connections that had been left in the intermediate state --- src/Client/Connection.cpp | 49 ++++++++++++++++++++++----------------- src/Client/Connection.h | 7 +++--- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/Client/Connection.cpp b/src/Client/Connection.cpp index a08f71bfa906..adb8d5417d8c 100644 --- a/src/Client/Connection.cpp +++ b/src/Client/Connection.cpp @@ -37,6 +37,7 @@ #include #include +#include #include #include "config.h" @@ -220,7 +221,7 @@ void Connection::connect(const ConnectionTimeouts & timeouts) connected = true; setDescription(); - sendHello(); + sendHello(timeouts.handshake_timeout); receiveHello(timeouts.handshake_timeout); if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_CHUNKED_PACKETS) @@ -371,7 +372,7 @@ void Connection::disconnect() } -void Connection::sendHello() +void Connection::sendHello([[maybe_unused]] const Poco::Timespan & handshake_timeout) { /** Disallow control characters in user controlled parameters * to mitigate the possibility of SSRF. @@ -424,7 +425,7 @@ void Connection::sendHello() writeStringBinary(String(EncodedUserInfo::SSH_KEY_AUTHENTICAION_MARKER) + user, *out); writeStringBinary(password, *out); - performHandshakeForSSHAuth(); + performHandshakeForSSHAuth(handshake_timeout); } #endif else if (!jwt.empty()) @@ -461,8 +462,10 @@ void Connection::sendAddendum() #if USE_SSH -void Connection::performHandshakeForSSHAuth() +void Connection::performHandshakeForSSHAuth(const Poco::Timespan & handshake_timeout) { + TimeoutSetter timeout_setter(*socket, handshake_timeout, handshake_timeout); + String challenge; { writeVarUInt(Protocol::Client::SSHChallengeRequest, *out); @@ -479,11 +482,7 @@ void Connection::performHandshakeForSSHAuth() else if (packet_type == Protocol::Server::Exception) receiveException()->rethrow(); else - { - /// Close connection, to not stay in unsynchronised state. - disconnect(); - throwUnexpectedPacket(packet_type, "SSHChallenge or Exception"); - } + throwUnexpectedPacket(timeout_setter, packet_type, "SSHChallenge or Exception"); } writeVarUInt(Protocol::Client::SSHChallengeResponse, *out); @@ -569,15 +568,7 @@ void Connection::receiveHello(const Poco::Timespan & handshake_timeout) else if (packet_type == Protocol::Server::Exception) receiveException()->rethrow(); else - { - /// Reset timeout_setter before disconnect, - /// because after disconnect socket will be invalid. - timeout_setter.reset(); - - /// Close connection, to not stay in unsynchronised state. - disconnect(); - throwUnexpectedPacket(packet_type, "Hello or Exception"); - } + throwUnexpectedPacket(timeout_setter, packet_type, "Hello or Exception"); } void Connection::setDefaultDatabase(const String & database) @@ -702,7 +693,7 @@ bool Connection::ping(const ConnectionTimeouts & timeouts) } if (pong != Protocol::Server::Pong) - throwUnexpectedPacket(pong, "Pong"); + throwUnexpectedPacket(timeout_setter, pong, "Pong"); } catch (const Poco::Exception & e) { @@ -741,7 +732,7 @@ TablesStatusResponse Connection::getTablesStatus(const ConnectionTimeouts & time if (response_type == Protocol::Server::Exception) receiveException()->rethrow(); else if (response_type != Protocol::Server::TablesStatusResponse) - throwUnexpectedPacket(response_type, "TablesStatusResponse"); + throwUnexpectedPacket(timeout_setter, response_type, "TablesStatusResponse"); TablesStatusResponse response; response.read(*in, server_revision); @@ -810,6 +801,14 @@ void Connection::sendQuery( query_id = query_id_; + /// Avoid reusing connections that had been left in the intermediate state + /// (i.e. not all packets had been sent). + bool completed = false; + SCOPE_EXIT({ + if (!completed) + disconnect(); + }); + writeVarUInt(Protocol::Client::Query, *out); writeStringBinary(query_id, *out); @@ -910,6 +909,8 @@ void Connection::sendQuery( sendData(Block(), "", false); out->next(); } + + completed = true; } @@ -1436,8 +1437,14 @@ InitialAllRangesAnnouncement Connection::receiveInitialParallelReadAnnouncement( } -void Connection::throwUnexpectedPacket(UInt64 packet_type, const char * expected) const +void Connection::throwUnexpectedPacket(TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected) { + /// Reset timeout_setter before disconnect, because after disconnect socket will be invalid. + timeout_setter.reset(); + + /// Close connection, to avoid leaving it in an unsynchronised state. + disconnect(); + throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER, "Unexpected packet from server {} (expected {}, got {})", getDescription(), expected, String(Protocol::Server::toString(packet_type))); diff --git a/src/Client/Connection.h b/src/Client/Connection.h index 29939e5d5541..f9f86b49e285 100644 --- a/src/Client/Connection.h +++ b/src/Client/Connection.h @@ -26,6 +26,7 @@ namespace DB { struct Settings; +struct TimeoutSetter; class Connection; struct ConnectionParameters; @@ -275,10 +276,10 @@ class Connection : public IServerConnection AsyncCallback async_callback = {}; void connect(const ConnectionTimeouts & timeouts); - void sendHello(); + void sendHello(const Poco::Timespan & handshake_timeout); #if USE_SSH - void performHandshakeForSSHAuth(); + void performHandshakeForSSHAuth(const Poco::Timespan & handshake_timeout); #endif void sendAddendum(); @@ -306,7 +307,7 @@ class Connection : public IServerConnection void initBlockLogsInput(); void initBlockProfileEventsInput(); - [[noreturn]] void throwUnexpectedPacket(UInt64 packet_type, const char * expected) const; + [[noreturn]] void throwUnexpectedPacket(TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected); }; template