Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provide SeekableStream impl for tokio::fs::File #1364

Merged
merged 7 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdk/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ url = "2.2"
uuid = { version = "1.0" }
pin-project = "1.0"
paste = "1.0"
tokio = { version = "1.0", optional = true }

# Add dependency to getrandom to enable WASM support
[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand All @@ -43,7 +44,7 @@ rustc_version = "0.4"

[dev-dependencies]
env_logger = "0.10"
tokio = { version = "1", features = ["default"] }
tokio = { version = "1.0", features = ["default"] }
thiserror = "1.0"

[features]
Expand All @@ -54,3 +55,4 @@ enable_reqwest_rustls = ["reqwest/rustls-tls"]
test_e2e = []
azurite_workaround = []
xml = ["quick-xml"]
tokio-fs = ["tokio/fs", "tokio/io-util"]
3 changes: 3 additions & 0 deletions sdk/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ use uuid::Uuid;
#[cfg(feature = "xml")]
pub mod xml;

// NOTE(review): the feature declared in Cargo.toml is `tokio-fs`, while this
// gate is `feature = "tokio"`. It works only because `tokio-fs = ["tokio/fs", ...]`
// activates the implicit feature created by the optional `tokio` dependency —
// confirm this is intentional rather than gating directly on `tokio-fs`.
#[cfg(feature = "tokio")]
pub mod tokio;

pub mod base64;
pub use bytes_stream::*;
pub use constants::*;
Expand Down
165 changes: 165 additions & 0 deletions sdk/core/src/tokio/fs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
use crate::{
request::Body,
seekable_stream::{SeekableStream, DEFAULT_BUFFER_SIZE},
setters,
};
use futures::{task::Poll, Future};
use std::{cmp::min, io::SeekFrom, pin::Pin, sync::Arc, task::Context};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncSeekExt, Take},
sync::Mutex,
};

/// Builder for a [`FileStream`] over a `tokio::fs::File`.
#[derive(Debug)]
pub struct FileStreamBuilder {
    handle: File,
    /// Offset into the file to start reading from
    offset: Option<u64>,
    /// How much to buffer in memory during streaming reads
    buffer_size: Option<usize>,
    /// Amount of data to read from the file
    block_size: Option<u64>,
}

impl FileStreamBuilder {
    /// Creates a builder wrapping an already-opened file handle.
    pub fn new(handle: File) -> Self {
        Self {
            handle,
            offset: None,
            buffer_size: None,
            block_size: None,
        }
    }

    setters! {
        // #[doc = "Offset into the file to start reading from"]
        offset: u64 => Some(offset),
        // #[doc = "Amount of data to read from the file"]
        block_size: u64 => Some(block_size),
        // #[doc = "Amount of data to buffer in memory during streaming reads"]
        buffer_size: usize => Some(buffer_size),
    }

    /// Consumes the builder, producing a [`FileStream`].
    ///
    /// Reads the file's length from metadata, seeks to `offset` if one was
    /// given, and limits reads to `block_size` bytes (defaulting to the rest
    /// of the file) via `Take`.
    pub async fn build(mut self) -> crate::Result<FileStream> {
        // total file length, captured once at build time
        let stream_size = self.handle.metadata().await?.len();

        let buffer_size = self.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE);

        let offset = if let Some(offset) = self.offset {
            self.handle.seek(SeekFrom::Start(offset)).await?;
            offset
        } else {
            0
        };

        // NOTE(review): if a caller passes an offset larger than the file,
        // `stream_size - offset` underflows (panic in debug builds) — confirm
        // whether this should instead return an error.
        let block_size = if let Some(block_size) = self.block_size {
            block_size
        } else {
            stream_size - offset
        };

        // shared so FileStream clones all read/seek the same underlying file
        let handle = Arc::new(Mutex::new(self.handle.take(block_size)));

        Ok(FileStream {
            handle,
            buffer_size,
            block_size,
            stream_size,
            offset,
        })
    }
}

/// A seekable, block-limited stream over a `tokio::fs::File`.
///
/// `Clone` is shallow: clones share the same `Take<File>` through the
/// `Arc<Mutex<..>>`, so reads from any clone advance the same file cursor.
#[derive(Debug, Clone)]
#[pin_project::pin_project]
pub struct FileStream {
    #[pin]
    handle: Arc<Mutex<Take<File>>>,
    /// Total length of the file, captured when the stream was built
    pub stream_size: u64,
    /// Maximum number of bytes served before `next_block`/`reset` is needed
    pub block_size: u64,
    // NOTE(review): only referenced by the commented-out `buffer_size` trait
    // method below — confirm it is still needed.
    buffer_size: usize,
    /// Offset of the current block within the file
    pub offset: u64,
}

impl FileStream {
    /// Reads up to `slice.len()` bytes from the file, bounded by the
    /// remaining `block_size` limit on the inner `Take<File>`.
    async fn read(&mut self, slice: &mut [u8]) -> std::io::Result<usize> {
        // clone the Arc so the owned guard can be held across the await
        let mut handle = self.handle.clone().lock_owned().await;
        handle.read(slice).await
    }

    /// Resets the number of bytes that will be read from this instance to the
    /// `block_size`
    ///
    /// This is useful if you want to read the stream in multiple blocks
    pub async fn next_block(&mut self) -> crate::Result<()> {
        log::info!("setting limit to {}", self.block_size);
        let mut handle = self.handle.clone().lock_owned().await;
        {
            let inner = handle.get_mut();
            // remember where this block starts so `reset` can rewind to it
            self.offset = inner.stream_position().await?;
        }
        handle.set_limit(self.block_size);
        Ok(())
    }
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl SeekableStream for FileStream {
/// Seek to the specified offset into the file and reset the number of bytes to read
///
/// This is useful upon encountering an error to reset the stream to the last
async fn reset(&mut self) -> crate::Result<()> {
log::info!(
"resetting stream to offset {} and limit to {}",
self.offset,
self.block_size
);
let mut handle = self.handle.clone().lock_owned().await;
{
let inner = handle.get_mut();
inner.seek(SeekFrom::Start(self.offset)).await?;
}
handle.set_limit(self.block_size);
Ok(())
}

fn len(&self) -> usize {
log::info!(
"stream len: {} - {} ... {}",
self.stream_size,
self.offset,
self.block_size
);
min(self.stream_size - self.offset, self.block_size) as usize
}

/*
fn buffer_size(&self) -> usize {
self.buffer_size
}
*/
}
demoray marked this conversation as resolved.
Show resolved Hide resolved

impl futures::io::AsyncRead for FileStream {
    // Bridge the async `FileStream::read` into the polling `AsyncRead` model.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        slice: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        // NOTE(review): a brand-new `read` future is constructed on every
        // poll, so any partial progress (e.g. a pending mutex acquisition)
        // is dropped between polls — confirm this cannot stall or livelock
        // under contention.
        std::pin::pin!(self.read(slice)).poll(cx)
    }
}

impl From<&FileStream> for Body {
fn from(stream: &FileStream) -> Self {
Body::SeekableStream(Box::new(stream.clone()))
}
}

impl From<FileStream> for Body {
fn from(stream: FileStream) -> Self {
Body::SeekableStream(Box::new(stream))
}
}
1 change: 1 addition & 0 deletions sdk/core/src/tokio/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/// Streaming support for `tokio::fs::File` (see `fs::FileStream`).
pub mod fs;
4 changes: 3 additions & 1 deletion sdk/storage_blobs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ uuid = { version = "1.0", features = ["v4"] }
url = "2.2"

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "io-util"] }
env_logger = "0.10"
azure_identity = { path = "../identity", default-features = false }
reqwest = "0.11"
mock_transport = { path = "../../eng/test/mock_transport" }
md5 = "0.7"
async-trait = "0.1"
clap = { version = "4.0", features = ["derive", "env"] }
azure_core = { path = "../core", version = "0.14", features = ["tokio-fs"] }

[features]
default = ["enable_reqwest"]
Expand Down
92 changes: 92 additions & 0 deletions sdk/storage_blobs/examples/stream_blob_02.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use azure_core::{
error::{ErrorKind, ResultExt},
tokio::fs::FileStreamBuilder,
};
use azure_storage::prelude::*;
use azure_storage_blobs::prelude::*;
use clap::Parser;
use std::path::PathBuf;
use tokio::fs::File;

// Command-line arguments for the streaming-upload example.
// (`///` doc comments here double as clap help text, so review notes use `//`.)
#[derive(Debug, Parser)]
struct Args {
    /// Name of the container to upload
    container_name: String,
    /// Blob name
    blob_name: String,
    /// File path to upload
    file_path: PathBuf,

    /// Offset to start uploading from
    #[clap(long)]
    offset: Option<u64>,

    /// how much to buffer in memory during streaming reads
    #[clap(long)]
    buffer_size: Option<usize>,

    // when set, the file is uploaded block-by-block via a block list;
    // otherwise it is uploaded as a single block blob
    #[clap(long)]
    block_size: Option<u64>,

    /// storage account name
    #[clap(env = "STORAGE_ACCOUNT")]
    account: String,

    /// storage account access key
    #[clap(env = "STORAGE_ACCESS_KEY")]
    access_key: String,
}

#[tokio::main]
async fn main() -> azure_core::Result<()> {
    env_logger::init();
    let args = Args::parse();

    let storage_credentials =
        StorageCredentials::Key(args.account.clone(), args.access_key.clone());
    let blob_client = BlobServiceClient::new(&args.account, storage_credentials)
        .container_client(&args.container_name)
        .blob_client(&args.blob_name);

    let file = File::open(&args.file_path).await?;

    // Apply only the options the user supplied; the builder picks defaults
    // for the rest.
    let mut builder = FileStreamBuilder::new(file);

    if let Some(buffer_size) = args.buffer_size {
        builder = builder.buffer_size(buffer_size);
    }

    if let Some(offset) = args.offset {
        builder = builder.offset(offset);
    }

    if let Some(block_size) = args.block_size {
        builder = builder.block_size(block_size);
    }

    let mut handle = builder.build().await?;

    if let Some(block_size) = args.block_size {
        // Block-by-block upload: put each block, then commit the block list.
        let mut block_list = BlockList::default();
        for offset in (handle.offset..handle.stream_size).step_by(block_size as usize) {
            log::info!("trying to upload at offset {offset} - {block_size}");
            // NOTE(review): fixed-width 8 hex digits keeps all block ids the
            // same length, which Azure requires within one block list — but
            // offsets >= 2^32 would widen the id; confirm the size bound.
            let block_id = format!("{:08X}", offset);
            blob_client.put_block(block_id.clone(), &handle).await?;
            log::info!("uploaded block {block_id}");
            block_list
                .blocks
                .push(BlobBlockType::new_uncommitted(block_id));
            // advance the stream's limit/offset to the next block
            handle.next_block().await?;
        }
        blob_client.put_block_list(block_list).await?;
    } else {
        // upload as one large block
        blob_client.put_block_blob(handle).await?;
    }

    // Read the blob back and print it to verify the round trip
    // (assumes the file content is valid UTF-8 — fine for an example).
    let blob = blob_client.get_content().await?;
    let s = String::from_utf8(blob).map_kind(ErrorKind::DataConversion)?;
    println!("retrieved contents == {s:?}");

    Ok(())
}