Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved handling of errors while sending TCP responses. #309

Merged
merged 15 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 14 additions & 34 deletions src/net/server/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::base::{Message, StreamTarget};
use crate::net::server::buf::BufSource;
use crate::net::server::message::Request;
use crate::net::server::metrics::ServerMetrics;
use crate::net::server::service::{Service, ServiceError, ServiceFeedback};
use crate::net::server::service::{Service, ServiceFeedback};
use crate::net::server::util::to_pcap_text;
use crate::utils::config::DefMinMax;

Expand Down Expand Up @@ -216,7 +216,7 @@ impl Clone for Config {
}
}

//------------ Connection -----------------------------------------------
//------------ Connection ----------------------------------------------------

/// A handler for a single stream connection between client and server.
pub struct Connection<Stream, Buf, Svc>
Expand Down Expand Up @@ -446,9 +446,6 @@ where
self.flush_write_queue().await;
break 'outer;
}
ConnectionEvent::ServiceError(err) => {
error!("Service error: {}", err);
}
}
}
}
Expand Down Expand Up @@ -536,10 +533,7 @@ where
}

/// Stop queueing new responses and process those already in the queue.
async fn flush_write_queue(&mut self)
// where
// Target: Composer,
{
async fn flush_write_queue(&mut self) {
debug!("Flushing connection write queue.");
// Stop accepting new response messages (should we check for in-flight
// messages that haven't generated a response yet but should be
Expand All @@ -564,10 +558,7 @@ where
async fn process_queued_result(
&mut self,
response: Option<AdditionalBuilder<StreamTarget<Svc::Target>>>,
) -> Result<(), ConnectionEvent>
// where
// Target: Composer,
{
) -> Result<(), ConnectionEvent> {
// If we failed to read the results of requests processed by the
// service because the queue holding those results is empty and can no
// longer be read from, then there is no point continuing to read from
Expand All @@ -583,19 +574,14 @@ where
"Writing queued response with id {} to stream",
response.header().id()
);
self.write_response_to_stream(response.finish()).await;

Ok(())
self.write_response_to_stream(response.finish()).await
}

/// Write a response back to the caller over the network stream.
async fn write_response_to_stream(
&mut self,
msg: StreamTarget<Svc::Target>,
)
// where
// Target: AsRef<[u8]>,
{
) -> Result<(), ConnectionEvent> {
if enabled!(Level::TRACE) {
let bytes = msg.as_dgram_slice();
let pcap_text = to_pcap_text(bytes, bytes.len());
Expand All @@ -613,10 +599,11 @@ where
"Write timed out (>{:?})",
self.config.load().response_write_timeout
);
// TODO: Push it to the back of the queue to retry it?
return Err(ConnectionEvent::DisconnectWithoutFlush);
}
Ok(Err(err)) => {
error!("Write error: {err}");
return Err(ConnectionEvent::DisconnectWithoutFlush);
}
Ok(Ok(_)) => {
self.metrics.inc_num_sent_responses();
Expand All @@ -628,6 +615,8 @@ where
if self.result_q_tx.capacity() == self.result_q_tx.max_capacity() {
self.idle_timer.response_queue_emptied();
}

Ok(())
}

/// Implemnt DNS rules regarding timing out of idle connections.
Expand Down Expand Up @@ -674,9 +663,7 @@ where
tracing::warn!(
"Failed while parsing request message: {err}"
);
return Err(ConnectionEvent::ServiceError(
ServiceError::FormatError,
));
return Err(ConnectionEvent::DisconnectWithoutFlush);
}

// https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
Expand Down Expand Up @@ -775,8 +762,8 @@ where
}

Err(TrySendError::Closed(_)) => {
error!("Unable to queue message for sending: server is shutting down.");
break;
error!("Unable to queue message for sending: connection is shutting down.");
return;
}

Err(TrySendError::Full(
Expand All @@ -790,7 +777,7 @@ where
unused_response;
} else {
error!("Unable to queue message for sending: queue is full.");
break;
return;
}
}
}
Expand Down Expand Up @@ -993,10 +980,6 @@ enum ConnectionEvent {
/// to send those responses. Of course, the DNS server MAY cache those
/// responses."
DisconnectWithFlush,

/// A [`Service`] specific error occurred while the service was processing
/// a request message.
ServiceError(ServiceError),
}

//--- Display
Expand All @@ -1010,9 +993,6 @@ impl Display for ConnectionEvent {
ConnectionEvent::DisconnectWithFlush => {
write!(f, "Disconnect with flush")
}
ConnectionEvent::ServiceError(err) => {
write!(f, "Service error: {err}")
}
}
}
}
Expand Down
123 changes: 119 additions & 4 deletions src/net/server/tests/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::sleep;
use tokio::time::Instant;
use tracing::trace;
use tracing_subscriber::EnvFilter;

use crate::base::MessageBuilder;
use crate::base::Name;
Expand Down Expand Up @@ -43,20 +44,26 @@ struct MockStream {
/// The rate at which messages should be made available to the server.
new_message_every: Duration,

/// The number of responses pending.
pending_responses: usize,

/// Disconnect while one or more responses are pending?
disconnect_with_pending_responses: bool,
}

impl MockStream {
    /// Create a mock stream pre-loaded with the given request messages.
    ///
    /// `messages_to_read` are served to the reader one at a time, at most
    /// one per `new_message_every` interval. If
    /// `disconnect_with_pending_responses` is true the stream will simulate
    /// a client that aborts the connection before all responses were
    /// received.
    fn new(
        messages_to_read: VecDeque<Vec<u8>>,
        new_message_every: Duration,
        disconnect_with_pending_responses: bool,
    ) -> Self {
        // One response is expected back for every queued request.
        let expected_responses = messages_to_read.len();
        Self {
            last_ready: Mutex::new(None),
            messages_to_read: Mutex::new(messages_to_read),
            new_message_every,
            pending_responses: expected_responses,
            disconnect_with_pending_responses,
        }
    }
}
Expand Down Expand Up @@ -86,10 +93,15 @@ impl AsyncRead for MockStream {
return Poll::Ready(Ok(()));
} else {
// Disconnect once we've sent all of the requests AND received all of the responses.
if self.pending_responses == 0 {
if self.disconnect_with_pending_responses {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"mock connection premature disconnect",
)));
} else if self.pending_responses == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"mock connection disconnect",
"mock connection normal disconnect",
)));
}
}
Expand Down Expand Up @@ -149,6 +161,7 @@ struct MockClientConfig {
pub new_message_every: Duration,
pub messages: VecDeque<Vec<u8>>,
pub client_port: u16,
pub disconnect_with_pending_responses: bool,
}

/// A mock TCP connection acceptor with a fixed rate at which (mock) client
Expand Down Expand Up @@ -222,13 +235,15 @@ impl AsyncAccept for MockListener {
new_message_every,
messages,
client_port,
disconnect_with_pending_responses,
}) = streams_to_read.pop_front()
{
last_accept.replace(Instant::now());
return Poll::Ready(Ok((
std::future::ready(Ok(MockStream::new(
messages,
new_message_every,
disconnect_with_pending_responses,
))),
format!("192.168.0.1:{}", client_port)
.parse()
Expand Down Expand Up @@ -362,11 +377,11 @@ fn mk_query() -> StreamTarget<Vec<u8>> {
// time dependent test to run much faster without actual periods of
// waiting to allow time to elapse.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn service_test() {
async fn tcp_service_test() {
// Initialize tracing based logging. Override with env var RUST_LOG, e.g.
// RUST_LOG=trace.
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_env_filter(EnvFilter::from_default_env())
.with_thread_ids(true)
.without_time()
.try_init()
Expand All @@ -383,6 +398,7 @@ async fn service_test() {
mk_query().as_dgram_slice().to_vec(),
]),
client_port: 1,
disconnect_with_pending_responses: false,
};
let slow_client = MockClientConfig {
new_message_every: Duration::from_millis(3000),
Expand All @@ -391,6 +407,7 @@ async fn service_test() {
mk_query().as_dgram_slice().to_vec(),
]),
client_port: 2,
disconnect_with_pending_responses: false,
};
let num_messages =
fast_client.messages.len() + slow_client.messages.len();
Expand Down Expand Up @@ -457,3 +474,101 @@ async fn service_test() {
// Terminate the task that periodically prints the server status
server_status_printer_handle.abort();
}

// Verifies that a client which aborts its TCP connection while responses
// are still pending does not leave the server with dangling state: all
// connections, in-flight requests and pending writes must drain to zero,
// and fewer responses than requests must have been sent (the aborted
// client never received its remaining responses).
//
// Uses tokio's paused clock (start_paused = true) so the sleeps below
// advance simulated time instantly instead of waiting in real time.
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn tcp_client_disconnect_test() {
    // Initialize tracing based logging. Override with env var RUST_LOG, e.g.
    // RUST_LOG=trace.
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .with_thread_ids(true)
        .without_time()
        .try_init()
        .ok();

    let (srv_handle, server_status_printer_handle) = {
        // A client that sends five requests in quick succession and then
        // aborts the connection while responses are still pending
        // (disconnect_with_pending_responses: true).
        let fast_client = MockClientConfig {
            new_message_every: Duration::from_millis(100),
            messages: VecDeque::from([
                mk_query().as_dgram_slice().to_vec(),
                mk_query().as_dgram_slice().to_vec(),
                mk_query().as_dgram_slice().to_vec(),
                mk_query().as_dgram_slice().to_vec(),
                mk_query().as_dgram_slice().to_vec(),
            ]),
            client_port: 1,
            disconnect_with_pending_responses: true,
        };
        // A well-behaved client that sends two requests slowly and only
        // disconnects once all of its responses have been received.
        let slow_client = MockClientConfig {
            new_message_every: Duration::from_millis(3000),
            messages: VecDeque::from([
                mk_query().as_dgram_slice().to_vec(),
                mk_query().as_dgram_slice().to_vec(),
            ]),
            client_port: 2,
            disconnect_with_pending_responses: false,
        };
        let num_messages =
            fast_client.messages.len() + slow_client.messages.len();
        let streams_to_read = VecDeque::from([fast_client, slow_client]);
        let new_client_every = Duration::from_millis(2000);
        let listener = MockListener::new(streams_to_read, new_client_every);
        let ready_flag = listener.get_ready_flag();

        let buf = MockBufSource;
        let my_service = Arc::new(MyService::new());
        let srv =
            Arc::new(StreamServer::new(listener, buf, my_service.clone()));

        // Spawn a background task that periodically dumps server metrics
        // to stderr while the test runs, to aid debugging on failure.
        let metrics = srv.metrics();
        let server_status_printer_handle = tokio::spawn(async move {
            loop {
                sleep(Duration::from_millis(250)).await;
                eprintln!(
                    "Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}",
                    metrics.num_connections(),
                    metrics.num_inflight_requests(),
                    metrics.num_pending_writes(),
                    metrics.num_received_requests(),
                    metrics.num_sent_responses(),
                );
            }
        });

        // Run the server itself on its own task.
        let spawned_srv = srv.clone();
        let srv_handle = tokio::spawn(async move { spawned_srv.run().await });

        eprintln!("Clients sleeping");
        sleep(Duration::from_secs(1)).await;

        // Release the mock listener so it starts accepting the clients.
        eprintln!("Clients connecting");
        ready_flag.store(true, Ordering::Relaxed);

        // Simulate a wait long enough that all simulated clients had time
        // to connect, communicate and disconnect.
        sleep(Duration::from_secs(20)).await;

        // Verify that all simulated clients connected.
        assert_eq!(0, srv.source().streams_remaining());

        // Verify that no requests or responses are in progress still in
        // the server.
        assert_eq!(srv.metrics().num_connections(), 0);
        assert_eq!(srv.metrics().num_inflight_requests(), 0);
        assert_eq!(srv.metrics().num_pending_writes(), 0);
        assert_eq!(srv.metrics().num_received_requests(), num_messages);
        // Strictly fewer responses than requests: the fast client aborted
        // before all of its responses could be written.
        assert!(srv.metrics().num_sent_responses() < num_messages);

        eprintln!("Shutting down");
        srv.shutdown().unwrap();
        eprintln!("Shutdown command sent");

        (srv_handle, server_status_printer_handle)
    };

    eprintln!("Waiting for service to shutdown");
    let _ = srv_handle.await;

    // Terminate the task that periodically prints the server status
    server_status_printer_handle.abort();
}