Skip to content

Commit

Permalink
provide SeekableStream impl for tokio::fs::File (#1364)
Browse files Browse the repository at this point in the history
Addresses #1219
  • Loading branch information
demoray committed Sep 12, 2023
1 parent c0a0834 commit 69709be
Show file tree
Hide file tree
Showing 6 changed files with 271 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sdk/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ url = "2.2"
uuid = { version = "1.0" }
pin-project = "1.0"
paste = "1.0"
tokio = { version = "1.0", optional = true }

# Add dependency to getrandom to enable WASM support
[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand All @@ -43,7 +44,7 @@ rustc_version = "0.4"

[dev-dependencies]
env_logger = "0.10"
tokio = { version = "1", features = ["default"] }
tokio = { version = "1.0", features = ["default"] }
thiserror = "1.0"

[features]
Expand All @@ -54,3 +55,4 @@ enable_reqwest_rustls = ["reqwest/rustls-tls"]
test_e2e = []
azurite_workaround = []
xml = ["quick-xml"]
tokio-fs = ["tokio/fs", "tokio/io-util"]
3 changes: 3 additions & 0 deletions sdk/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ use uuid::Uuid;
#[cfg(feature = "xml")]
pub mod xml;

// NOTE(review): this gates on the implicit `tokio` feature Cargo derives from
// the optional `tokio` dependency, while the public feature declared in
// Cargo.toml is `tokio-fs` (which enables `tokio/fs` and hence the implicit
// feature). This works today but would break if the manifest switched to
// `dep:tokio` syntax — confirm the gate matches the intended feature name.
#[cfg(feature = "tokio")]
pub mod tokio;

pub mod base64;
pub use bytes_stream::*;
pub use constants::*;
Expand Down
169 changes: 169 additions & 0 deletions sdk/core/src/tokio/fs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use crate::{
request::Body,
seekable_stream::{SeekableStream, DEFAULT_BUFFER_SIZE},
setters,
};
use futures::{task::Poll, Future};
use std::{cmp::min, io::SeekFrom, pin::Pin, sync::Arc, task::Context};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncSeekExt, Take},
sync::Mutex,
};

/// Builder for a [`FileStream`] over a `tokio::fs::File`.
#[derive(Debug)]
pub struct FileStreamBuilder {
    handle: File,
    /// Offset into the file to start reading from
    offset: Option<u64>,
    /// How much to buffer in memory during streaming reads
    buffer_size: Option<usize>,
    /// Amount of data to read from the file
    block_size: Option<u64>,
}

impl FileStreamBuilder {
    /// Creates a new builder wrapping an already-open file handle.
    pub fn new(handle: File) -> Self {
        Self {
            handle,
            offset: None,
            buffer_size: None,
            block_size: None,
        }
    }

    setters! {
        // #[doc = "Offset into the file to start reading from"]
        offset: u64 => Some(offset),
        // #[doc = "Amount of data to read from the file"]
        block_size: u64 => Some(block_size),
        // #[doc = "Amount of data to buffer in memory during streaming reads"]
        buffer_size: usize => Some(buffer_size),
    }

    /// Consumes the builder and produces a [`FileStream`].
    ///
    /// Reads the file length from metadata, seeks to the configured offset
    /// (if any), and limits reads to `block_size` bytes — defaulting to the
    /// remainder of the file when no block size was configured.
    pub async fn build(mut self) -> crate::Result<FileStream> {
        let stream_size = self.handle.metadata().await?.len();

        let buffer_size = self.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);

        let offset = if let Some(offset) = self.offset {
            self.handle.seek(SeekFrom::Start(offset)).await?;
            offset
        } else {
            0
        };

        // `saturating_sub` guards against an offset past end-of-file (seeking
        // beyond EOF is legal), which would otherwise underflow and panic in
        // debug builds.
        let block_size = self
            .block_size
            .unwrap_or_else(|| stream_size.saturating_sub(offset));

        let handle = Arc::new(Mutex::new(self.handle.take(block_size)));

        Ok(FileStream {
            handle,
            buffer_size,
            block_size,
            stream_size,
            offset,
        })
    }
}

/// A stream over a `tokio::fs::File`, cloneable because `Request` (the
/// primary consumer) must be `Clone`; clones share the same file handle
/// and position via the `Arc<Mutex<..>>`.
#[derive(Debug, Clone)]
#[pin_project::pin_project]
pub struct FileStream {
    // NOTE(review): `#[pin]` on an `Arc` field has no effect (`Arc` is
    // `Unpin`) — confirm the projection is intentional.
    #[pin]
    handle: Arc<Mutex<Take<File>>>,
    /// Total size of the underlying file in bytes (from metadata at build time)
    pub stream_size: u64,
    /// Maximum number of bytes read per block (the `Take` limit)
    pub block_size: u64,
    // How much to buffer in memory during streaming reads
    buffer_size: usize,
    /// Offset into the file where the current block starts
    pub offset: u64,
}

impl FileStream {
    /// Attempts to read from the underlying file handle.
    ///
    /// This first acquires a lock on the handle, then reads from it; the
    /// lock is released upon completion. The lock is necessary because
    /// `Request` (the primary consumer of `FileStream`) must be `Clone`,
    /// so the handle is shared between clones.
    async fn read(&mut self, slice: &mut [u8]) -> std::io::Result<usize> {
        // `lock` (instead of `clone` + `lock_owned`) avoids a needless `Arc`
        // refcount bump; the guard only needs to live for this call.
        let mut handle = self.handle.lock().await;
        handle.read(slice).await
    }

    /// Resets the number of bytes that will be read from this instance to
    /// `block_size`, recording the current file position as the new `offset`.
    ///
    /// This is useful if you want to read the stream in multiple blocks.
    pub async fn next_block(&mut self) -> crate::Result<()> {
        log::info!("setting limit to {}", self.block_size);
        let mut handle = self.handle.lock().await;
        {
            let inner = handle.get_mut();
            // The next block starts wherever the previous reads left off.
            self.offset = inner.stream_position().await?;
        }
        handle.set_limit(self.block_size);
        Ok(())
    }
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl SeekableStream for FileStream {
    /// Seek to the specified offset into the file and reset the number of
    /// bytes to read to `block_size`.
    ///
    /// This is useful upon encountering an error: the stream is rewound to
    /// the last recorded offset so the block can be re-read.
    async fn reset(&mut self) -> crate::Result<()> {
        log::info!(
            "resetting stream to offset {} and limit to {}",
            self.offset,
            self.block_size
        );
        // `lock` (instead of `clone` + `lock_owned`) avoids a needless `Arc`
        // refcount bump; the guard only needs to live for this call.
        let mut handle = self.handle.lock().await;
        {
            let inner = handle.get_mut();
            inner.seek(SeekFrom::Start(self.offset)).await?;
        }
        handle.set_limit(self.block_size);
        Ok(())
    }

    /// Remaining length of the stream: the smaller of the bytes left in the
    /// file (from `offset`) and the configured `block_size`.
    fn len(&self) -> usize {
        log::info!(
            "stream len: {} - {} ... {}",
            self.stream_size,
            self.offset,
            self.block_size
        );
        // `saturating_sub` guards against `offset` sitting past end-of-file,
        // which would otherwise underflow and panic in debug builds.
        min(self.stream_size.saturating_sub(self.offset), self.block_size) as usize
    }

    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}

impl futures::io::AsyncRead for FileStream {
    // NOTE(review): this constructs a brand-new `read` future on every call
    // to `poll_read`. If that future returns `Pending` (e.g. while waiting
    // on the mutex) it is dropped, and with it any waker it registered —
    // progress then depends on the caller polling again for other reasons.
    // Confirm this is acceptable, or store the in-flight future across polls.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        slice: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        std::pin::pin!(self.read(slice)).poll(cx)
    }
}

impl From<&FileStream> for Body {
fn from(stream: &FileStream) -> Self {
Body::SeekableStream(Box::new(stream.clone()))
}
}

impl From<FileStream> for Body {
fn from(stream: FileStream) -> Self {
Body::SeekableStream(Box::new(stream))
}
}
1 change: 1 addition & 0 deletions sdk/core/src/tokio/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod fs;
4 changes: 3 additions & 1 deletion sdk/storage_blobs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ uuid = { version = "1.0", features = ["v4"] }
url = "2.2"

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "io-util"] }
env_logger = "0.10"
azure_identity = { path = "../identity", default-features = false }
reqwest = "0.11"
mock_transport = { path = "../../eng/test/mock_transport" }
md5 = "0.7"
async-trait = "0.1"
clap = { version = "4.0", features = ["derive", "env"] }
azure_core = { path = "../core", version = "0.15", features = ["tokio-fs"] }

[features]
default = ["enable_reqwest"]
Expand Down
92 changes: 92 additions & 0 deletions sdk/storage_blobs/examples/stream_blob_02.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use azure_core::{
error::{ErrorKind, ResultExt},
tokio::fs::FileStreamBuilder,
};
use azure_storage::prelude::*;
use azure_storage_blobs::prelude::*;
use clap::Parser;
use std::path::PathBuf;
use tokio::fs::File;

#[derive(Debug, Parser)]
struct Args {
    /// Name of the container to upload
    container_name: String,
    /// Blob name
    blob_name: String,
    /// File path to upload
    file_path: PathBuf,

    /// Offset to start uploading from
    #[clap(long)]
    offset: Option<u64>,

    /// how much to buffer in memory during streaming reads
    #[clap(long)]
    buffer_size: Option<usize>,

    // Size of each uploaded block; when set, the file is uploaded as a block
    // list rather than a single block blob. (Plain `//` comment on purpose:
    // a `///` doc comment would change the generated `--help` output.)
    #[clap(long)]
    block_size: Option<u64>,

    /// storage account name
    #[clap(env = "STORAGE_ACCOUNT")]
    account: String,

    /// storage account access key
    #[clap(env = "STORAGE_ACCESS_KEY")]
    access_key: String,
}

/// Uploads a local file to a block blob — as a single block, or (when
/// `--block-size` is given) as a committed list of fixed-size blocks — then
/// reads the blob back and prints its contents.
#[tokio::main]
async fn main() -> azure_core::Result<()> {
    env_logger::init();
    let args = Args::parse();

    // `account` is still needed below for the service client, so it is
    // cloned; `access_key` is not used again and can simply be moved.
    let storage_credentials = StorageCredentials::Key(args.account.clone(), args.access_key);
    let blob_client = BlobServiceClient::new(&args.account, storage_credentials)
        .container_client(&args.container_name)
        .blob_client(&args.blob_name);

    let file = File::open(&args.file_path).await?;

    let mut builder = FileStreamBuilder::new(file);

    if let Some(buffer_size) = args.buffer_size {
        builder = builder.buffer_size(buffer_size);
    }

    if let Some(offset) = args.offset {
        builder = builder.offset(offset);
    }

    if let Some(block_size) = args.block_size {
        builder = builder.block_size(block_size);
    }

    let mut handle = builder.build().await?;

    if let Some(block_size) = args.block_size {
        let mut block_list = BlockList::default();
        for offset in (handle.offset..handle.stream_size).step_by(block_size as usize) {
            log::info!("trying to upload at offset {offset} - {block_size}");
            // Azure requires every block id within a blob to have the same
            // length; 16 hex digits covers the full `u64` offset range,
            // whereas `{:08X}` would grow wider past 4 GiB and break that
            // invariant.
            let block_id = format!("{:016X}", offset);
            blob_client.put_block(block_id.clone(), &handle).await?;
            log::info!("uploaded block {block_id}");
            block_list
                .blocks
                .push(BlobBlockType::new_uncommitted(block_id));
            handle.next_block().await?;
        }
        blob_client.put_block_list(block_list).await?;
    } else {
        // upload as one large block
        blob_client.put_block_blob(handle).await?;
    }

    let blob = blob_client.get_content().await?;
    let s = String::from_utf8(blob).map_kind(ErrorKind::DataConversion)?;
    println!("retrieved contents == {s:?}");

    Ok(())
}

0 comments on commit 69709be

Please sign in to comment.