Add ability to ignore EOF check for writers (#341)
Work around the fact that S3 drops connections as soon as all data is
sent or received.

Fixes #304 and also resolves the same problem in the non-oneshot cases
where it occurred.
aaronmondal committed Oct 26, 2023
1 parent 62a2c1e commit 979f941
Showing 3 changed files with 86 additions and 52 deletions.
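
The change works by letting the consumer of the buffer channel opt out of the "writer dropped before EOF" error that the write half normally injects on drop. Below is a minimal, self-contained sketch of that pattern using plain std channels and a String error in place of the project's tokio-based DropCloserWriteHalf and Error types; the struct and field names are illustrative, and only set_ignore_eof mirrors the method added in this commit.

use std::sync::mpsc;

struct WriteHalf {
    tx: Option<mpsc::Sender<Result<Vec<u8>, String>>>,
    ignore_eof: bool,
}

impl WriteHalf {
    fn send(&mut self, data: Vec<u8>) {
        if let Some(tx) = &self.tx {
            let _ = tx.send(Ok(data));
        }
    }

    fn send_eof(&mut self) {
        // Clean shutdown: dropping the sender closes the channel without an error.
        self.tx.take();
    }

    fn set_ignore_eof(&mut self) {
        self.ignore_eof = true;
    }
}

impl Drop for WriteHalf {
    fn drop(&mut self) {
        // Normally a drop without send_eof() tells the reader the data may be
        // truncated. The opt-out exists for peers like S3 that hang up as soon
        // as the payload is complete.
        if !self.ignore_eof {
            if let Some(tx) = self.tx.take() {
                let _ = tx.send(Err("Writer was dropped before EOF was sent".to_string()));
            }
        }
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let mut writer = WriteHalf { tx: Some(tx), ignore_eof: false };
    writer.set_ignore_eof();
    writer.send(b"hello".to_vec());
    drop(writer); // S3-style hang-up: no EOF, but no error is injected either.

    assert_eq!(rx.recv().unwrap(), Ok(b"hello".to_vec()));
    assert!(rx.recv().is_err()); // Channel is simply closed.
}
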
51 changes: 26 additions & 25 deletions cas/store/s3_store.rs
@@ -110,10 +110,7 @@ where
// HTTP-level errors. Sometimes can retry.
Err(RusotoError::Unknown(e)) => match e.status {
StatusCode::NOT_FOUND => RetryResult::Err(make_err!(Code::NotFound, "{}", e.status.to_string())),
-StatusCode::INTERNAL_SERVER_ERROR => {
-RetryResult::Retry(make_err!(Code::Unavailable, "{}", e.status.to_string()))
-}
-StatusCode::SERVICE_UNAVAILABLE => {
+StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
RetryResult::Retry(make_err!(Code::Unavailable, "{}", e.status.to_string()))
}
StatusCode::CONFLICT => RetryResult::Retry(make_err!(Code::Unavailable, "{}", e.status.to_string())),
@@ -153,7 +150,7 @@ impl S3Store {
S3Client::new_with(dispatcher, credentials_provider, region)
};
let jitter_amt = config.retry.jitter;
-S3Store::new_with_client_and_jitter(
+Self::new_with_client_and_jitter(
config,
s3_client,
Box::new(move |delay: Duration| {
@@ -172,12 +169,12 @@ impl S3Store {
s3_client: S3Client,
jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>,
) -> Result<Self, Error> {
-Ok(S3Store {
+Ok(Self {
s3_client: Arc::new(s3_client),
bucket: config.bucket.to_string(),
-key_prefix: config.key_prefix.as_ref().unwrap_or(&"".to_string()).to_owned(),
+key_prefix: config.key_prefix.as_ref().unwrap_or(&String::new()).clone(),
jitter_fn,
-retry: config.retry.to_owned(),
+retry: config.retry.clone(),
retrier: Retrier::new(Box::new(|duration| Box::pin(sleep(duration)))),
})
}
@@ -196,8 +193,8 @@ impl S3Store {
retry_config,
unfold((), move |state| async move {
let head_req = HeadObjectRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
..Default::default()
};

@@ -267,8 +264,7 @@ impl StoreTrait for S3Store {
let s3_path = &self.make_s3_path(&digest);

let max_size = match upload_size {
-UploadSizeInfo::ExactSize(sz) => sz,
-UploadSizeInfo::MaxSize(sz) => sz,
+UploadSizeInfo::ExactSize(sz) | UploadSizeInfo::MaxSize(sz) => sz,
};
// NOTE(blaise.bruer) It might be more optimal to use a different heuristic here, but for
// simplicity we use a hard codded value. Anything going down this if-statement will have
@@ -296,8 +292,8 @@ impl StoreTrait for S3Store {
};

let put_object_request = PutObjectRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
content_length,
body,
..Default::default()
@@ -317,8 +313,8 @@ impl StoreTrait for S3Store {
let response = self
.s3_client
.create_multipart_upload(CreateMultipartUploadRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
..Default::default()
})
.await
@@ -349,8 +345,8 @@ impl StoreTrait for S3Store {
let body = Some(ByteStream::new(ReaderStream::new(Cursor::new(write_buf))));

let request = UploadPartRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
content_length: Some(write_buf_len),
body,
part_number,
@@ -383,8 +379,8 @@ impl StoreTrait for S3Store {
let completed_parts = try_join_all(completed_part_futs).await?;
self.s3_client
.complete_multipart_upload(CompleteMultipartUploadRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
upload_id: upload_id.clone(),
multipart_upload: Some(CompletedMultipartUpload {
parts: Some(completed_parts),
@@ -400,8 +396,8 @@ impl StoreTrait for S3Store {
let abort_result = self
.s3_client
.abort_multipart_upload(AbortMultipartUploadRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
upload_id: upload_id.clone(),
..Default::default()
})
@@ -429,19 +425,24 @@ impl StoreTrait for S3Store {
.map(|d| (self.jitter_fn)(d))
.take(self.retry.max_retries); // Remember this is number of retries, so will run max_retries + 1.

+// S3 drops connections when a stream is done. This means that we can't
+// run the EOF error check. It's safe to disable it since S3 can be
+// trusted to handle incomplete data properly.
+writer.set_ignore_eof();
+
self.retrier
.retry(
retry_config,
unfold(writer, move |writer| async move {
let result = self
.s3_client
.get_object(GetObjectRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
range: Some(format!(
"bytes={}-{}",
offset + writer.get_bytes_written() as usize,
-end_read_byte.map_or_else(|| "".to_string(), |v| v.to_string())
+end_read_byte.map_or_else(String::new, |v| v.to_string())
)),
..Default::default()
})
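
Even with the EOF check disabled, get_part() stays resumable across retries because it tracks how much the writer has already forwarded: each attempt re-issues the ranged GetObject starting at offset plus get_bytes_written(), so only the missing tail is fetched again. A small runnable sketch of that range computation follows; the format! call mirrors the one in the hunk above, while the helper name and plain std types are illustrative assumptions.

// Builds the HTTP Range header for a (re)attempted ranged read.
// `bytes_written` is how much the writer has already forwarded downstream.
fn resume_range(offset: usize, bytes_written: u64, end_read_byte: Option<usize>) -> String {
    format!(
        "bytes={}-{}",
        offset + bytes_written as usize,
        end_read_byte.map_or_else(String::new, |v| v.to_string())
    )
}

fn main() {
    // First attempt: nothing has been forwarded yet.
    assert_eq!(resume_range(100, 0, Some(199)), "bytes=100-199");
    // Retry after 40 bytes already reached the writer: only fetch the tail.
    assert_eq!(resume_range(100, 40, Some(199)), "bytes=140-199");
    // Open-ended read: no end byte requested.
    assert_eq!(resume_range(0, 10, None), "bytes=10-");
}
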
65 changes: 39 additions & 26 deletions util/buf_channel.rs
@@ -27,6 +27,7 @@ use error::{error_if, make_err, Code, Error, ResultExt};
/// utility like managing EOF in a more friendly way, ensure if no EOF is received
/// it will send an error to the receiver channel before shutting down and count
/// the number of bytes sent.
+#[must_use]
pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
// We allow up to 2 items in the buffer at any given time. There is no major
// reason behind this magic number other than thinking it will be nice to give
@@ -39,6 +40,7 @@ pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
tx: Some(tx),
bytes_written: 0,
close_rx,
+disable_eof_check: false,
},
DropCloserReadHalf {
rx,
@@ -56,6 +58,7 @@ pub struct DropCloserWriteHalf {
/// Receiver channel used to know the error (or success) value of the
/// receiver end's drop status (ie: if the receiver dropped unexpectedly).
close_rx: oneshot::Receiver<Result<(), Error>>,
+disable_eof_check: bool,
}

impl DropCloserWriteHalf {
@@ -98,38 +101,43 @@ impl DropCloserWriteHalf {
S: Stream<Item = Result<Bytes, std::io::Error>> + Send + Unpin,
{
loop {
-match reader.next().await {
-Some(maybe_chunk) => {
-let chunk = maybe_chunk.err_tip(|| "Failed to forward message")?;
-if chunk.is_empty() {
-// Don't send EOF here. We instead rely on None result to be EOF.
-continue;
-}
-self.send(chunk).await?;
+if let Some(maybe_chunk) = reader.next().await {
+let chunk = maybe_chunk.err_tip(|| "Failed to forward message")?;
+if chunk.is_empty() {
+// Don't send EOF here. We instead rely on None result to be EOF.
+continue;
}
-None => {
-if forward_eof {
-self.send_eof().await?;
-}
-break;
+self.send(chunk).await?;
+} else {
+if forward_eof {
+self.send_eof().await?;
+}
+break;
}
}
Ok(())
}

/// Returns the number of bytes written so far. This does not mean the receiver received
/// all of the bytes written to the stream so far.
-pub fn get_bytes_written(&self) -> u64 {
+#[must_use]
+pub const fn get_bytes_written(&self) -> u64 {
self.bytes_written
}

/// Returns if the pipe was broken. This is good for determining if the reader broke the
/// pipe or the writer broke the pipe, since this will only return true if the pipe was
/// broken by the writer.
-pub fn is_pipe_broken(&self) -> bool {
+#[must_use]
+pub const fn is_pipe_broken(&self) -> bool {
self.tx.is_none()
}

+/// Some remote receivers drop connections before we can send the EOF check.
+/// If the receiver handles failing streams it is safe to disable it.
+pub fn set_ignore_eof(&mut self) {
+self.disable_eof_check = true;
+}
}

impl Drop for DropCloserWriteHalf {
@@ -139,16 +147,21 @@ impl Drop for DropCloserWriteHalf {
eprintln!("No tokio runtime active. Tx was dropped but can't send error.");
return; // Cant send error, no runtime.
}
-if let Some(tx) = self.tx.take() {
-// If we do not notify the receiver of the premature close of the stream (ie: without EOF)
-// we could end up with the receiver thinking everything is good and saving this bad data.
-tokio::spawn(async move {
-let _ = tx
-.send(Err(
-make_err!(Code::Internal, "Writer was dropped before EOF was sent",),
-))
-.await; // Nowhere to send failure to write here.
-});
+// Some remote receivers out of our control may close connections before
+// we can send the EOF check. If the remote receiver can be trusted to
+// handle incomplete data on its side we can disable this check.
+if !self.disable_eof_check {
+if let Some(tx) = self.tx.take() {
+// If we do not notify the receiver of the premature close of the stream (ie: without EOF)
+// we could end up with the receiver thinking everything is good and saving this bad data.
+tokio::spawn(async move {
+let _ = tx
+.send(Err(
+make_err!(Code::Internal, "Writer was dropped before EOF was sent",),
+))
+.await; // Nowhere to send failure to write here.
+});
+}
}
}
}
@@ -195,7 +208,7 @@ impl DropCloserReadHalf {
Ok(chunk)
}

-Some(Err(e)) => Err(e),
+Some(Err(e)) => Err(make_err!(Code::Internal, "Received erroneous partial chunk: {e}")),

// None is a safe EOF received.
None => {
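
The forward() change above is a pure control-flow refactor: the match on reader.next() with Some/None arms becomes an if let ... else inside the loop, with unchanged behavior. A runnable sketch of the same shape over a plain iterator (the real method works on an async Stream, uses err_tip for error context, and optionally forwards an EOF; the names below are illustrative):

fn forward(chunks: Vec<Result<Vec<u8>, String>>) -> Result<Vec<Vec<u8>>, String> {
    let mut reader = chunks.into_iter();
    let mut out = Vec::new();
    loop {
        if let Some(maybe_chunk) = reader.next() {
            let chunk = maybe_chunk?;
            if chunk.is_empty() {
                // Skip empty chunks; EOF is signalled by `None`, never by an empty buffer.
                continue;
            }
            out.push(chunk);
        } else {
            // `None` is the EOF; the real code calls send_eof() here when forward_eof is set.
            break;
        }
    }
    Ok(out)
}

fn main() {
    let chunks = vec![Ok(b"ab".to_vec()), Ok(Vec::new()), Ok(b"cd".to_vec())];
    assert_eq!(forward(chunks).unwrap(), vec![b"ab".to_vec(), b"cd".to_vec()]);
}
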
22 changes: 21 additions & 1 deletion util/tests/buf_channel_test.rs
@@ -233,7 +233,27 @@ mod buf_channel_tests {
assert_eq!(rx.recv().await?, Bytes::from(DATA1));
assert_eq!(
rx.recv().await,
-Err(make_err!(Code::Internal, "Writer was dropped before EOF was sent"))
+Err(make_err!(Code::Internal, "Received erroneous partial chunk: Error {{ code: Internal, messages: [\"Writer was dropped before EOF was sent\"] }}"))
);
Result::<(), Error>::Ok(())
};
+try_join!(tx_fut, rx_fut)?;
+Ok(())
+}
+
+#[tokio::test]
+async fn rx_accepts_tx_drop_test_when_eof_ignored() -> Result<(), Error> {
+let (mut tx, mut rx) = make_buf_channel_pair();
+tx.set_ignore_eof();
+let tx_fut = async move {
+tx.send(DATA1.into()).await?;
+Result::<(), Error>::Ok(())
+};
+let rx_fut = async move {
+assert_eq!(rx.recv().await?, Bytes::from(DATA1));
+assert_eq!(
+rx.recv().await,
+Err(make_err!(Code::Internal, "Failed to send closing ok message to write"))
+);
+Result::<(), Error>::Ok(())
+};
