Skip to content

Commit

Permalink
dnsdist: Handle congested DoQ streams
Browse files Browse the repository at this point in the history
If the stream has no capacity left, Quiche will refuse to queue
more data and return `QUICHE_ERR_DONE`. We then have to wait until
the stream becomes writable again to retry sending our response.
  • Loading branch information
rgacogne committed Dec 15, 2023
1 parent 10ed86d commit c6886da
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 56 deletions.
127 changes: 91 additions & 36 deletions pdns/dnsdistdist/doh3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class H3Connection
QuicheConnection d_conn;
QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
};

static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
Expand Down Expand Up @@ -263,7 +264,29 @@ class DOH3CrossProtocolQuery : public CrossProtocolQuery

std::shared_ptr<DOH3TCPCrossQuerySender> DOH3CrossProtocolQuery::s_sender = std::make_shared<DOH3TCPCrossQuerySender>();

static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
/* Try to hand the whole of `response` to quiche as the HTTP/3 body of `streamID`.

   Returns false when quiche answers QUICHE_H3_ERR_DONE or QUICHE_H3_TRANSPORT_ERR_DONE
   before everything was accepted: the bytes already accepted are removed from the front
   of `response`, so that the caller can keep the remainder and retry later.

   Returns true when there is nothing left to retry: either the full content was accepted,
   or quiche returned another negative value, in which case the write side of the stream
   has been shut down with DOQ_INTERNAL_ERROR. */
static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, PacketBuffer& response)
{
  size_t written = 0;
  for (;;) {
    const auto total = response.size();
    if (written >= total) {
      return true;
    }

    // fin is always requested: send_body clears it by itself when it cannot take the entire content, which is what lets us try again.
    const auto sent = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
                                          streamID, &response.at(written), total - written, true);
    const bool stalled = (sent == QUICHE_H3_ERR_DONE) || (sent == QUICHE_H3_TRANSPORT_ERR_DONE);
    if (stalled) {
      // keep only what has not been accepted yet
      response.erase(response.begin(), response.begin() + static_cast<ssize_t>(written));
      return false;
    }
    if (sent < 0) {
      // any other failure: give up on this stream, signalling an internal error
      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
      return true;
    }
    written += sent;
  }
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
{
std::string status = std::to_string(statusCode);
std::string lenStr = std::to_string(len);
Expand All @@ -285,8 +308,13 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
.value_len = lenStr.size(),
},
};
quiche_h3_send_response(conn, quic_conn,
streamID, headers.data(), headers.size(), len == 0);
auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
streamID, headers.data(), headers.size(), len == 0);
if (returnValue != 0) {
/* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
return;
}

if (len == 0) {
return;
Expand All @@ -295,28 +323,27 @@ static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const
size_t pos = 0;
while (pos < len) {
// send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
auto res = quiche_h3_send_body(conn, quic_conn,
auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
streamID, const_cast<uint8_t*>(body) + pos, len - pos, true);
if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
conn.d_streamOutBuffers[streamID] = PacketBuffer(body + pos, body + len);
return;
}
if (res < 0) {
// Shutdown with internal error code
quiche_conn_stream_shutdown(quic_conn, streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
return;
}
pos += res;
}
}

static void h3_send_response(quiche_conn* quic_conn, quiche_h3_conn* conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
h3_send_response(quic_conn, conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
}

static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
{
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, statusCode, body, len);
h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
}

static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
Expand Down Expand Up @@ -616,14 +643,29 @@ static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
}
}

/* Retry the responses that were parked in `conn.d_streamOutBuffers` because their stream
   could not accept the whole content.

   For every pending stream, quiche is asked whether the stream can take the remaining bytes:
   - if it can, the write is retried via tryWriteResponse() and the entry is removed
     when that function reports there is nothing left to retry;
   - if it cannot yet, the entry is kept for a later pass;
   - if quiche reports an error (negative value), the entry is removed: the quiche API
     returns an error here when the stream is no longer usable (for example unknown or
     stopped by the peer), so the response could never be delivered and keeping it would
     only hold memory and get re-polled until the connection is destroyed. */
static void flushStalledResponses(H3Connection& conn)
{
  for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) {
    const auto streamID = streamIt->first;
    auto& response = streamIt->second;

    const auto writable = quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size());
    if (writable < 0) {
      // the stream is in error, this response is undeliverable
      streamIt = conn.d_streamOutBuffers.erase(streamIt);
      continue;
    }

    if (writable == 1 && tryWriteResponse(conn, streamID, response)) {
      // fully sent, or the stream has been shut down: nothing left to keep
      streamIt = conn.d_streamOutBuffers.erase(streamIt);
      continue;
    }

    // still congested (or only partially written), try again on the next pass
    ++streamIt;
  }
}

static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID, std::map<std::string, std::string>& headers, int64_t streamID, quiche_h3_event* event)
{
auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
DEBUGLOG(msg);
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState.nonCompliantQueries;
++frontend.d_errorResponses;
h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
h3_send_response(conn, streamID, 400, msg);
};

// Callback result. Any value other than 0 will interrupt further header processing.
Expand Down Expand Up @@ -684,37 +726,49 @@ static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend,
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState.nonCompliantQueries;
++frontend.d_errorResponses;
h3_send_response(conn.d_conn.get(), conn.d_http3.get(), streamID, 400, msg);
h3_send_response(conn, streamID, 400, msg);
};

if (headers.at(":method") == "POST") {
if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
handleImmediateError("Unsupported content-type");
return;
}
PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
PacketBuffer decoded;
if (headers.at(":method") != "POST") {
handleImmediateError("DATA frame for non-POST method");
return;
}

while (true) {
ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
conn.d_conn.get(), streamID,
buffer.data(), buffer.capacity());
if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
handleImmediateError("Unsupported content-type");
return;
}

if (len <= 0) {
break;
}
decoded.insert(decoded.end(), buffer.begin(), buffer.begin() + len);
}
PacketBuffer buffer(std::numeric_limits<uint16_t>::max());
auto& streamBuffer = conn.d_streamBuffers[streamID];

if (decoded.size() < sizeof(dnsheader)) {
handleImmediateError("DoH3 non-compliant query");
return;
while (true) {
buffer.resize(std::numeric_limits<uint16_t>::max());
ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
conn.d_conn.get(), streamID,
buffer.data(), buffer.capacity());

if (len <= 0) {
break;
}

DEBUGLOG("Dispatching POST query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(decoded), clientState.local, client, serverConnID, streamID);
buffer.resize(static_cast<size_t>(len));
streamBuffer.insert(streamBuffer.end(), buffer.begin(), buffer.end());
}

if (!quiche_conn_stream_finished(conn.d_conn.get(), streamID)) {
return;
}

if (streamBuffer.size() < sizeof(dnsheader)) {
conn.d_streamBuffers.erase(streamID);
handleImmediateError("DoH3 non-compliant query");
return;
}

DEBUGLOG("Dispatching POST query");
doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
conn.d_streamBuffers.erase(streamID);
}

static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, PacketBuffer& serverConnID)
Expand Down Expand Up @@ -892,6 +946,7 @@ void doh3Thread(ClientState* clientState)
conn = frontend->d_server_config->d_connections.erase(conn);
}
else {
flushStalledResponses(conn->second);
++conn;
}
}
Expand Down
60 changes: 40 additions & 20 deletions pdns/dnsdistdist/doq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Connection
ComboAddress d_peer;
QuicheConnection d_conn;
std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
};

static void sendBackDOQUnit(DOQUnitUniquePtr&& unit, const char* description);
Expand Down Expand Up @@ -260,7 +261,26 @@ class DOQCrossProtocolQuery : public CrossProtocolQuery

std::shared_ptr<DOQTCPCrossQuerySender> DOQCrossProtocolQuery::s_sender = std::make_shared<DOQTCPCrossQuerySender>();

static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, const PacketBuffer& response)
/* Try to queue the whole of `response` on the QUIC stream `streamID`.

   Returns false when quiche answers QUICHE_ERR_DONE before everything was accepted:
   the bytes already accepted are removed from the front of `response`, so that the
   caller can keep the remainder and retry later.

   Returns true when there is nothing left to retry: either the full content was accepted,
   or quiche returned another negative value, in which case the write side of the stream
   has been shut down with DOQ_INTERNAL_ERROR. */
static bool tryWriteResponse(Connection& conn, const uint64_t streamID, PacketBuffer& response)
{
  size_t written = 0;
  for (;;) {
    const auto total = response.size();
    if (written >= total) {
      return true;
    }

    // fin is always requested on the remaining content
    const auto sent = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(written), total - written, true);
    if (sent == QUICHE_ERR_DONE) {
      // keep only what has not been accepted yet
      response.erase(response.begin(), response.begin() + static_cast<ssize_t>(written));
      return false;
    }
    if (sent < 0) {
      // any other failure: give up on this stream, signalling an internal error
      quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
      return true;
    }
    written += sent;
  }
}

static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64_t streamID, PacketBuffer& response)
{
if (response.empty()) {
++frontend.d_errorResponses;
Expand All @@ -270,25 +290,9 @@ static void handleResponse(DOQFrontend& frontend, Connection& conn, const uint64
++frontend.d_validResponses;
auto responseSize = static_cast<uint16_t>(response.size());
const std::array<uint8_t, 2> sizeBytes = {static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)};
size_t pos = 0;
while (pos < sizeBytes.size()) {
auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &sizeBytes.at(pos), sizeBytes.size() - pos, false);
if (res < 0) {
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
return;
}
pos += res;
}

pos = 0;
while (pos < response.size()) {
// stream_send sets fin to false itself when the capacity of the stream is less than the desired writing length
auto res = quiche_conn_stream_send(conn.d_conn.get(), streamID, &response.at(pos), response.size() - pos, true);
if (res < 0) {
quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
return;
}
pos += res;
response.insert(response.begin(), sizeBytes.begin(), sizeBytes.end());
if (!tryWriteResponse(conn, streamID, response)) {
conn.d_streamOutBuffers[streamID] = std::move(response);
}
}

Expand Down Expand Up @@ -560,6 +564,21 @@ static void flushResponses(pdns::channel::Receiver<DOQUnit>& receiver)
}
}

/* Retry the responses that were parked in `conn.d_streamOutBuffers` because their stream
   could not accept the whole content.

   For every pending stream, quiche is asked whether the stream can take the remaining bytes:
   - if it can, the write is retried via tryWriteResponse() and the entry is removed
     when that function reports there is nothing left to retry;
   - if it cannot yet, the entry is kept for a later pass;
   - if quiche reports an error (negative value), the entry is removed: the quiche API
     returns an error here when the stream is no longer usable (for example unknown or
     stopped by the peer), so the response could never be delivered and keeping it would
     only hold memory and get re-polled until the connection is destroyed. */
static void flushStalledResponses(Connection& conn)
{
  for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) {
    // copied rather than referenced, since the entry may be erased below
    const auto streamID = streamIt->first;
    auto& response = streamIt->second;

    const auto writable = quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size());
    if (writable < 0) {
      // the stream is in error, this response is undeliverable
      streamIt = conn.d_streamOutBuffers.erase(streamIt);
      continue;
    }

    if (writable == 1 && tryWriteResponse(conn, streamID, response)) {
      // fully sent, or the stream has been shut down: nothing left to keep
      streamIt = conn.d_streamOutBuffers.erase(streamIt);
      continue;
    }

    // still congested (or only partially written), try again on the next pass
    ++streamIt;
  }
}

// this is the entrypoint from dnsdist.cc
void doqThread(ClientState* clientState)
{
Expand Down Expand Up @@ -721,6 +740,7 @@ void doqThread(ClientState* clientState)
conn = frontend->d_server_config->d_connections.erase(conn);
}
else {
flushStalledResponses(conn->second);
++conn;
}
}
Expand Down

0 comments on commit c6886da

Please sign in to comment.