Add optimistic channels
This variation on make_buf_channel_pair omits the EOF check. This works
around the fact that S3 drops connections as soon as all data is sent or
received.

Fixes #304 and also fixes the same issue in the non-oneshot cases where
it occurred.
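
As a usage sketch (not part of the commit): the new pair is meant to be
driven with both halves polled concurrently, mirroring the update_oneshot
and get_part_unchunked overrides below. `round_trip` is a hypothetical
helper and assumes this repository's buf_channel and error crates.

    use bytes::Bytes;
    use tokio::try_join;

    use buf_channel::make_buf_channel_pair_optimistic_writer;
    use error::{Error, ResultExt};

    async fn round_trip(data: Bytes) -> Result<Bytes, Error> {
        let (mut tx, rx) = make_buf_channel_pair_optimistic_writer();
        let data_len = data.len();
        let send_fut = async move {
            if !data.is_empty() {
                tx.send(data).await.err_tip(|| "Failed to write data")?;
            }
            tx.send_eof().await.err_tip(|| "Failed to write EOF")?;
            Ok(())
        };
        // If the reader hangs up right after the last byte (as S3 does), the
        // optimistic writer's Drop stays silent instead of injecting a
        // "Writer was dropped before EOF was sent" error.
        let (_, received) = try_join!(send_fut, rx.collect_all_with_size_hint(data_len))?;
        Ok(received)
    }
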
aaronmondal committed Oct 25, 2023
1 parent f15146d commit d8bf8c0
Showing 3 changed files with 151 additions and 55 deletions.
97 changes: 70 additions & 27 deletions cas/store/s3_store.rs
@@ -20,9 +20,10 @@ use std::sync::{Arc, Mutex};
use std::time::Duration;

use async_trait::async_trait;
+use bytes::Bytes;
use futures::future::{try_join_all, FutureExt};
use futures::stream::{unfold, FuturesUnordered};
-use futures::TryStreamExt;
+use futures::{join, try_join, TryStreamExt};
use http::status::StatusCode;
use lazy_static::lazy_static;
use rand::{rngs::OsRng, Rng};
@@ -39,7 +40,7 @@ use tokio::sync::Semaphore;
use tokio::time::sleep;
use tokio_util::io::ReaderStream;

-use buf_channel::{DropCloserReadHalf, DropCloserWriteHalf};
+use buf_channel::{make_buf_channel_pair_optimistic_writer, DropCloserReadHalf, DropCloserWriteHalf};
use common::{log, DigestInfo, JoinHandleDropGuard};
use error::{error_if, make_err, make_input_err, Code, Error, ResultExt};
use retry::{ExponentialBackoff, Retrier, RetryResult};
@@ -110,10 +111,7 @@ where
// HTTP-level errors. Sometimes can retry.
Err(RusotoError::Unknown(e)) => match e.status {
StatusCode::NOT_FOUND => RetryResult::Err(make_err!(Code::NotFound, "{}", e.status.to_string())),
-StatusCode::INTERNAL_SERVER_ERROR => {
-RetryResult::Retry(make_err!(Code::Unavailable, "{}", e.status.to_string()))
-}
-StatusCode::SERVICE_UNAVAILABLE => {
+StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
RetryResult::Retry(make_err!(Code::Unavailable, "{}", e.status.to_string()))
}
StatusCode::CONFLICT => RetryResult::Retry(make_err!(Code::Unavailable, "{}", e.status.to_string())),
@@ -153,7 +151,7 @@ impl S3Store {
S3Client::new_with(dispatcher, credentials_provider, region)
};
let jitter_amt = config.retry.jitter;
-S3Store::new_with_client_and_jitter(
+Self::new_with_client_and_jitter(
config,
s3_client,
Box::new(move |delay: Duration| {
@@ -172,12 +170,12 @@ impl S3Store {
s3_client: S3Client,
jitter_fn: Box<dyn Fn(Duration) -> Duration + Send + Sync>,
) -> Result<Self, Error> {
-Ok(S3Store {
+Ok(Self {
s3_client: Arc::new(s3_client),
bucket: config.bucket.to_string(),
-key_prefix: config.key_prefix.as_ref().unwrap_or(&"".to_string()).to_owned(),
+key_prefix: config.key_prefix.as_ref().unwrap_or(&String::new()).clone(),
jitter_fn,
-retry: config.retry.to_owned(),
+retry: config.retry.clone(),
retrier: Retrier::new(Box::new(|duration| Box::pin(sleep(duration)))),
})
}
@@ -196,8 +194,8 @@ impl S3Store {
retry_config,
unfold((), move |state| async move {
let head_req = HeadObjectRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
..Default::default()
};

@@ -240,6 +238,52 @@ impl S3Store {

#[async_trait]
impl StoreTrait for S3Store {
+/// Brute-force override to ignore connection drops from S3.
+async fn get_part_unchunked(
+self: Pin<&Self>,
+digest: DigestInfo,
+offset: usize,
+length: Option<usize>,
+size_hint: Option<usize>,
+) -> Result<Bytes, Error> {
+// TODO(blaise.bruer) This is extremely inefficient, since we have exactly
+// what we need here. Maybe we could instead make a version of the stream
+// that can take objects already fully in memory instead?
+let (tx, rx) = make_buf_channel_pair_optimistic_writer();
+
+let (data_res, get_part_res) = join!(
+rx.collect_all_with_size_hint(length.unwrap_or_else(|| size_hint.unwrap_or(0))),
+self.get_part(digest, tx, offset, length),
+);
+get_part_res
+.err_tip(|| "Failed to get_part in get_part_unchunked")
+.merge(data_res.err_tip(|| "Failed to read stream to completion in get_part_unchunked"))
+}
+
+/// Brute-force override to ignore connection drops from S3.
+async fn update_oneshot(self: Pin<&Self>, digest: DigestInfo, data: Bytes) -> Result<(), Error> {
+// TODO(blaise.bruer) This is extremely inefficient, since we have exactly
+// what we need here. Maybe we could instead make a version of the stream
+// that can take objects already fully in memory instead?
+let (mut tx, rx) = make_buf_channel_pair_optimistic_writer();
+
+let data_len = data.len();
+let send_fut = async move {
+// Only send if we are not EOF.
+if !data.is_empty() {
+tx.send(data)
+.await
+.err_tip(|| "Failed to write data in update_oneshot")?;
+}
+tx.send_eof()
+.await
+.err_tip(|| "Failed to write EOF in update_oneshot")?;
+Ok(())
+};
+try_join!(send_fut, self.update(digest, rx, UploadSizeInfo::ExactSize(data_len)))?;
+Ok(())
+}

async fn has_with_results(
self: Pin<&Self>,
digests: &[DigestInfo],
@@ -267,8 +311,7 @@ impl StoreTrait for S3Store {
let s3_path = &self.make_s3_path(&digest);

let max_size = match upload_size {
-UploadSizeInfo::ExactSize(sz) => sz,
-UploadSizeInfo::MaxSize(sz) => sz,
+UploadSizeInfo::ExactSize(sz) | UploadSizeInfo::MaxSize(sz) => sz,
};
// NOTE(blaise.bruer) It might be more optimal to use a different heuristic here, but for
// simplicity we use a hard codded value. Anything going down this if-statement will have
@@ -296,8 +339,8 @@ impl StoreTrait for S3Store {
};

let put_object_request = PutObjectRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
content_length,
body,
..Default::default()
@@ -317,8 +360,8 @@ impl StoreTrait for S3Store {
let response = self
.s3_client
.create_multipart_upload(CreateMultipartUploadRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
..Default::default()
})
.await
@@ -349,8 +392,8 @@ impl StoreTrait for S3Store {
let body = Some(ByteStream::new(ReaderStream::new(Cursor::new(write_buf))));

let request = UploadPartRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
content_length: Some(write_buf_len),
body,
part_number,
@@ -383,8 +426,8 @@ impl StoreTrait for S3Store {
let completed_parts = try_join_all(completed_part_futs).await?;
self.s3_client
.complete_multipart_upload(CompleteMultipartUploadRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
upload_id: upload_id.clone(),
multipart_upload: Some(CompletedMultipartUpload {
parts: Some(completed_parts),
@@ -400,8 +443,8 @@ impl StoreTrait for S3Store {
let abort_result = self
.s3_client
.abort_multipart_upload(AbortMultipartUploadRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
upload_id: upload_id.clone(),
..Default::default()
})
@@ -436,12 +479,12 @@ impl StoreTrait for S3Store {
let result = self
.s3_client
.get_object(GetObjectRequest {
-bucket: self.bucket.to_owned(),
-key: s3_path.to_owned(),
+bucket: self.bucket.clone(),
+key: s3_path.clone(),
range: Some(format!(
"bytes={}-{}",
offset + writer.get_bytes_written() as usize,
-end_read_byte.map_or_else(|| "".to_string(), |v| v.to_string())
+end_read_byte.map_or_else(String::new, |v| v.to_string())
)),
..Default::default()
})
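
A note on the join!/try_join! pattern in the two overrides above: the buffer
channel is bounded (it holds at most two chunks, per util/buf_channel.rs
below), so the writing and reading halves must be polled concurrently;
awaiting the producer to completion before draining the consumer would
deadlock on any payload larger than the buffer. A self-contained toy of the
same pattern with a plain tokio channel (hypothetical example, assuming tokio
with the "full" feature; not this repository's API):

    use tokio::join;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        // Bounded to 2 items, like the buf_channel pair.
        let (tx, mut rx) = mpsc::channel::<u32>(2);

        let producer = async move {
            for i in 0..100u32 {
                // Blocks once 2 items are buffered, so the consumer must run
                // concurrently; `producer.await` followed by a separate drain
                // would deadlock here.
                tx.send(i).await.expect("receiver dropped");
            }
        };
        let consumer = async move {
            let mut total = 0u32;
            while let Some(v) = rx.recv().await {
                total += v;
            }
            total
        };
        // Poll both sides at once, exactly like join!/try_join! above.
        let ((), total) = join!(producer, consumer);
        assert_eq!(total, (0..100u32).sum::<u32>());
    }
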
86 changes: 60 additions & 26 deletions util/buf_channel.rs
@@ -27,6 +27,7 @@ use error::{error_if, make_err, Code, Error, ResultExt};
/// utility like managing EOF in a more friendly way, ensure if no EOF is received
/// it will send an error to the receiver channel before shutting down and count
/// the number of bytes sent.
+#[must_use]
pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
// We allow up to 2 items in the buffer at any given time. There is no major
// reason behind this magic number other than thinking it will be nice to give
@@ -39,6 +40,34 @@ pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
tx: Some(tx),
bytes_written: 0,
close_rx,
+disable_eof_check: false,
},
DropCloserReadHalf {
rx,
partial: None,
close_tx: Some(close_tx),
close_after_size: u64::MAX,
},
)
}

+/// Same as `make_buf_channel_pair` but disables the EOF check for the sender.
+/// Some remote receivers drop connections before we can send the EOF check. If
+/// the receiver can be trusted to handle failing streams this is safe to do.
+#[must_use]
+pub fn make_buf_channel_pair_optimistic_writer() -> (DropCloserWriteHalf, DropCloserReadHalf) {
+// We allow up to 2 items in the buffer at any given time. There is no major
+// reason behind this magic number other than thinking it will be nice to give
+// a little time for another thread to wake up and consume data if another
+// thread is pumping large amounts of data into the channel.
+let (tx, rx) = mpsc::channel(2);
+let (close_tx, close_rx) = oneshot::channel();
+(
+DropCloserWriteHalf {
+tx: Some(tx),
+bytes_written: 0,
+close_rx,
+disable_eof_check: true,
+},
+DropCloserReadHalf {
+rx,
@@ -56,6 +85,7 @@ pub struct DropCloserWriteHalf {
/// Receiver channel used to know the error (or success) value of the
/// receiver end's drop status (ie: if the receiver dropped unexpectedly).
close_rx: oneshot::Receiver<Result<(), Error>>,
+disable_eof_check: bool,
}

impl DropCloserWriteHalf {
@@ -98,36 +128,35 @@ impl DropCloserWriteHalf {
S: Stream<Item = Result<Bytes, std::io::Error>> + Send + Unpin,
{
loop {
-match reader.next().await {
-Some(maybe_chunk) => {
-let chunk = maybe_chunk.err_tip(|| "Failed to forward message")?;
-if chunk.is_empty() {
-// Don't send EOF here. We instead rely on None result to be EOF.
-continue;
-}
-self.send(chunk).await?;
+if let Some(maybe_chunk) = reader.next().await {
+let chunk = maybe_chunk.err_tip(|| "Failed to forward message")?;
+if chunk.is_empty() {
+// Don't send EOF here. We instead rely on None result to be EOF.
+continue;
+}
-None => {
-if forward_eof {
-self.send_eof().await?;
-}
-break;
+self.send(chunk).await?;
+} else {
+if forward_eof {
+self.send_eof().await?;
+}
+break;
}
}
Ok(())
}

/// Returns the number of bytes written so far. This does not mean the receiver received
/// all of the bytes written to the stream so far.
-pub fn get_bytes_written(&self) -> u64 {
+#[must_use]
+pub const fn get_bytes_written(&self) -> u64 {
self.bytes_written
}

/// Returns if the pipe was broken. This is good for determining if the reader broke the
/// pipe or the writer broke the pipe, since this will only return true if the pipe was
/// broken by the writer.
-pub fn is_pipe_broken(&self) -> bool {
+#[must_use]
+pub const fn is_pipe_broken(&self) -> bool {
self.tx.is_none()
}
}
@@ -139,16 +168,21 @@ impl Drop for DropCloserWriteHalf {
eprintln!("No tokio runtime active. Tx was dropped but can't send error.");
return; // Cant send error, no runtime.
}
-if let Some(tx) = self.tx.take() {
-// If we do not notify the receiver of the premature close of the stream (ie: without EOF)
-// we could end up with the receiver thinking everything is good and saving this bad data.
-tokio::spawn(async move {
-let _ = tx
-.send(Err(
-make_err!(Code::Internal, "Writer was dropped before EOF was sent",),
-))
-.await; // Nowhere to send failure to write here.
-});
+// Some remote receivers out of our control may close connections before
+// we can send the EOF check. If the remote receiver can be trusted to
+// handle incomplete data on its side we can disable this check.
+if !self.disable_eof_check {
+if let Some(tx) = self.tx.take() {
+// If we do not notify the receiver of the premature close of the stream (ie: without EOF)
+// we could end up with the receiver thinking everything is good and saving this bad data.
+tokio::spawn(async move {
+let _ = tx
+.send(Err(
+make_err!(Code::Internal, "Writer was dropped before EOF was sent",),
+))
+.await; // Nowhere to send failure to write here.
+});
+}
}
}
}
@@ -195,7 +229,7 @@ impl DropCloserReadHalf {
Ok(chunk)
}

-Some(Err(e)) => Err(e),
+Some(Err(e)) => Err(make_err!(Code::Internal, "Received erroneous partial chunk: {e}")),

// None is a safe EOF received.
None => {
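
Reduced to its essentials, the Drop change above follows this pattern (a toy
sketch with hypothetical names and a plain tokio channel, not this
repository's types):

    use tokio::sync::mpsc;

    struct Writer {
        tx: Option<mpsc::Sender<Result<Vec<u8>, String>>>,
        disable_eof_check: bool,
    }

    impl Writer {
        async fn send(&mut self, data: Vec<u8>) -> Result<(), String> {
            let tx = self.tx.as_ref().ok_or("already closed")?;
            tx.send(Ok(data)).await.map_err(|e| e.to_string())
        }

        fn send_eof(&mut self) {
            // Dropping the sender cleanly is the EOF signal in this toy.
            self.tx = None;
        }
    }

    impl Drop for Writer {
        fn drop(&mut self) {
            // (The real code first checks that a tokio runtime is active
            // before spawning; elided here.)
            if !self.disable_eof_check {
                if let Some(tx) = self.tx.take() {
                    // Pessimistic path: poison the stream so a reader cannot
                    // mistake truncated data for a clean EOF.
                    tokio::spawn(async move {
                        let _ = tx.send(Err("dropped before EOF".to_string())).await;
                    });
                }
            }
            // Optimistic path: the sender is silently dropped and the channel
            // closes; a trusted reader treats the close as end-of-stream.
        }
    }
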
23 changes: 21 additions & 2 deletions util/tests/buf_channel_test.rs
@@ -15,7 +15,7 @@
use bytes::Bytes;
use tokio::try_join;

-use buf_channel::make_buf_channel_pair;
+use buf_channel::{make_buf_channel_pair, make_buf_channel_pair_optimistic_writer};
use error::{make_err, Code, Error, ResultExt};

#[cfg(test)]
@@ -233,7 +233,26 @@ mod buf_channel_tests {
assert_eq!(rx.recv().await?, Bytes::from(DATA1));
assert_eq!(
rx.recv().await,
-Err(make_err!(Code::Internal, "Writer was dropped before EOF was sent"))
+Err(make_err!(Code::Internal, "Received erroneous partial chunk: Error {{ code: Internal, messages: [\"Writer was dropped before EOF was sent\"] }}"))
);
Result::<(), Error>::Ok(())
};
try_join!(tx_fut, rx_fut)?;
Ok(())
}

+#[tokio::test]
+async fn rx_doesnt_treat_optimistic_tx_drops_test_as_error() -> Result<(), Error> {
+let (mut tx, mut rx) = make_buf_channel_pair_optimistic_writer();
+let tx_fut = async move {
+tx.send(DATA1.into()).await?;
+Result::<(), Error>::Ok(())
+};
+let rx_fut = async move {
+assert_eq!(rx.recv().await?, Bytes::from(DATA1));
+assert_eq!(
+rx.recv().await,
+Err(make_err!(Code::Internal, "Failed to send closing ok message to write"))
+);
+Result::<(), Error>::Ok(())
+};