diff --git a/Cargo.toml b/Cargo.toml index 0092826..cf7f009 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,8 @@ categories = ["filesystem"] [dependencies] rand = { version = "0.8" } +async-std = { version = "1.12", optional = true } + [target.'cfg(unix)'.dependencies] nix = { version = "0.28", features = ["fs", "user"] } diff --git a/src/future/mod.rs b/src/future/mod.rs new file mode 100644 index 0000000..b665da5 --- /dev/null +++ b/src/future/mod.rs @@ -0,0 +1,207 @@ +use crate::imp; +use crate::OpenOptions; +use async_std::fs::File; +use async_std::io::IoSlice; +use async_std::io::IoSliceMut; +use async_std::io::Read; +use async_std::io::Seek; +use async_std::io::SeekFrom; +use async_std::io::Write; +use async_std::sync::Arc; +use async_std::task::block_on; +use async_std::task::spawn_blocking; +use async_std::task::Context; +use async_std::task::Poll; +use std::io::Result; +use std::mem::ManuallyDrop; +use std::ops::Deref; +use std::path::Path; +use std::pin::pin; +use std::pin::Pin; +use std::ptr; + +#[cfg(test)] +mod tests; + +#[derive(Clone, Debug)] +pub struct AtomicWriteFile { + temporary_file: Arc>, + finalized: bool, +} + +impl AtomicWriteFile { + pub async fn open>(path: P) -> Result { + let path = path.as_ref().to_path_buf(); + let file = spawn_blocking(move || OpenOptions::new().open(path)).await?; + + // Take the `temporary_file` out of the blocking `AtomicWriteFile`, so that we can convert + // it to an async file. This requires unsafe code because `AtomicWriteFile` has a + // destructor, and we want to avoid running it now + let file = ManuallyDrop::new(file); + // SAFETY: we're taking ownership of the `temporary_file`, and disposing of `file` without + // running its destructor + let temporary_file = unsafe { ptr::read(&(*file).temporary_file) }; + + Ok(Self { + temporary_file: Arc::new(temporary_file.into()), + finalized: false, + }) + } + + #[inline] + pub fn as_file(&self) -> &File { + &self.temporary_file.file + } + + pub async fn commit(mut self) -> Result<()> { + self._commit().await + } + + async fn _commit(&mut self) -> Result<()> { + if self.finalized { + return Ok(()); + } + self.finalized = true; + self.sync_all().await?; + let temporary_file = Arc::clone(&self.temporary_file); + spawn_blocking(move || temporary_file.rename_file()).await + } + + pub async fn discard(mut self) -> Result<()> { + self._discard().await + } + + async fn _discard(&mut self) -> Result<()> { + if self.finalized { + return Ok(()); + } + self.finalized = true; + let temporary_file = Arc::clone(&self.temporary_file); + spawn_blocking(move || temporary_file.remove_file()).await + } +} + +impl Drop for AtomicWriteFile { + #[inline] + fn drop(&mut self) { + if !self.finalized { + // Ignore all errors + let _ = block_on(self._discard()); + } + } +} + +impl Deref for AtomicWriteFile { + type Target = File; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_file() + } +} + +impl Read for AtomicWriteFile { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(&(*self.temporary_file).file).poll_read(cx, buf) + } + + #[inline] + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + pin!(&(*self.temporary_file).file).poll_read_vectored(cx, bufs) + } +} + +impl Read for &AtomicWriteFile { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(&(*self.temporary_file).file).poll_read(cx, buf) + } + + #[inline] + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + pin!(&(*self.temporary_file).file).poll_read_vectored(cx, bufs) + } +} + +impl Write for AtomicWriteFile { + #[inline] + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + pin!(&(*self.temporary_file).file).poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&(*self.temporary_file).file).poll_flush(cx) + } + + #[inline] + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&(*self.temporary_file).file).poll_close(cx) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + pin!(&(*self.temporary_file).file).poll_write_vectored(cx, bufs) + } +} + +impl Write for &AtomicWriteFile { + #[inline] + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + pin!(&(*self.temporary_file).file).poll_write(cx, buf) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&(*self.temporary_file).file).poll_flush(cx) + } + + #[inline] + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + pin!(&(*self.temporary_file).file).poll_close(cx) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + pin!(&(*self.temporary_file).file).poll_write_vectored(cx, bufs) + } +} + +impl Seek for AtomicWriteFile { + #[inline] + fn poll_seek(self: Pin<&mut Self>, cx: &mut Context<'_>, pos: SeekFrom) -> Poll> { + pin!(&(*self.temporary_file).file).poll_seek(cx, pos) + } +} + +impl Seek for &AtomicWriteFile { + #[inline] + fn poll_seek(self: Pin<&mut Self>, cx: &mut Context<'_>, pos: SeekFrom) -> Poll> { + pin!(&(*self.temporary_file).file).poll_seek(cx, pos) + } +} diff --git a/src/future/tests.rs b/src/future/tests.rs new file mode 100644 index 0000000..fabdca0 --- /dev/null +++ b/src/future/tests.rs @@ -0,0 +1,65 @@ +use crate::future::AtomicWriteFile; +use crate::tests::test_file; +use crate::tests::verify_no_leftovers; +use async_std::fs; +use async_std::io::WriteExt; +use async_std::task::block_on; +use std::io::Result; + +#[test] +fn create_new() -> Result<()> { + block_on(async { + let path = test_file("async-new"); + assert!(!path.exists()); + + let mut file = AtomicWriteFile::open(&path).await?; + assert!(!path.exists()); + + file.write_all(b"hello ").await?; + assert!(!path.exists()); + file.flush().await?; + assert!(!path.exists()); + file.write_all(b"world\n").await?; + assert!(!path.exists()); + file.flush().await?; + assert!(!path.exists()); + + file.commit().await?; + + assert!(path.exists()); + assert_eq!(fs::read(&path).await?, b"hello world\n"); + + verify_no_leftovers(path); + + Ok(()) + }) +} + +#[test] +fn overwrite_existing() -> Result<()> { + block_on(async { + let path = test_file("async-existing"); + fs::write(&path, b"initial contents\n").await?; + assert_eq!(fs::read(&path).await?, b"initial contents\n"); + + let mut file = AtomicWriteFile::open(&path).await?; + assert_eq!(fs::read(&path).await?, b"initial contents\n"); + + file.write_all(b"hello ").await?; + assert_eq!(fs::read(&path).await?, b"initial contents\n"); + file.flush().await?; + assert_eq!(fs::read(&path).await?, b"initial contents\n"); + file.write_all(b"world\n").await?; + assert_eq!(fs::read(&path).await?, b"initial contents\n"); + file.flush().await?; + assert_eq!(fs::read(&path).await?, b"initial contents\n"); + + file.commit().await?; + + assert_eq!(fs::read(&path).await?, b"hello world\n"); + + verify_no_leftovers(path); + + Ok(()) + }) +} diff --git a/src/imp/generic.rs b/src/imp/generic.rs index 64f5f56..7f873ef 100644 --- a/src/imp/generic.rs +++ b/src/imp/generic.rs @@ -33,10 +33,10 @@ impl Default for OpenOptions { } #[derive(Debug)] -pub(crate) struct TemporaryFile { +pub(crate) struct TemporaryFile { pub(crate) temp_path: PathBuf, pub(crate) dest_path: PathBuf, - pub(crate) file: File, + pub(crate) file: F, } impl TemporaryFile { @@ -69,7 +69,9 @@ impl TemporaryFile { file, }) } +} +impl TemporaryFile { pub(crate) fn rename_file(&self) -> Result<()> { fs::rename(&self.temp_path, &self.dest_path) } @@ -84,6 +86,21 @@ impl TemporaryFile { } } +#[cfg(feature = "async-std")] +impl TemporaryFile { + #[inline] + pub(crate) fn into(self) -> TemporaryFile + where + G: From, + { + TemporaryFile:: { + temp_path: self.temp_path, + dest_path: self.dest_path, + file: self.file.into(), + } + } +} + // An enum without variants, so that it can never be constructed #[derive(Debug)] pub(crate) enum Dir {} diff --git a/src/imp/unix/generic.rs b/src/imp/unix/generic.rs index 8020ce8..b6ec339 100644 --- a/src/imp/unix/generic.rs +++ b/src/imp/unix/generic.rs @@ -11,9 +11,9 @@ use std::io::Result; use std::path::Path; #[derive(Debug)] -pub(crate) struct TemporaryFile { +pub(crate) struct TemporaryFile { pub(crate) dir: Dir, - pub(crate) file: File, + pub(crate) file: F, pub(crate) name: OsString, pub(crate) temporary_name: OsString, } @@ -42,7 +42,9 @@ impl TemporaryFile { temporary_name, }) } +} +impl TemporaryFile { pub(crate) fn rename_file(&self) -> Result<()> { rename_temporary_file(&self.dir, &self.temporary_name, &self.name)?; Ok(()) @@ -58,3 +60,19 @@ impl TemporaryFile { Some(&self.dir) } } + +#[cfg(feature = "async-std")] +impl TemporaryFile { + #[inline] + pub(crate) fn into(self) -> TemporaryFile + where + G: From, + { + TemporaryFile:: { + dir: self.dir, + file: self.file.into(), + name: self.name, + temporary_name: self.temporary_name, + } + } +} diff --git a/src/imp/unix/linux.rs b/src/imp/unix/linux.rs index 1b95285..1df058a 100644 --- a/src/imp/unix/linux.rs +++ b/src/imp/unix/linux.rs @@ -39,7 +39,10 @@ fn create_unnamed_temporary_file(dir: &Dir, opts: &OpenOptions) -> nix::Result nix::Result<()> { +fn rename_unnamed_temporary_file(dir: &Dir, file: &F, name: &OsStr) -> nix::Result<()> +where + F: AsRawFd, +{ let fd = file.as_raw_fd(); let src = OsString::from(format!("/proc/self/fd/{fd}")); let mut random_name = RandomName::new(name); @@ -69,9 +72,9 @@ fn rename_unnamed_temporary_file(dir: &Dir, file: &File, name: &OsStr) -> nix::R } #[derive(Debug)] -pub(crate) struct TemporaryFile { +pub(crate) struct TemporaryFile { pub(crate) dir: Dir, - pub(crate) file: File, + pub(crate) file: F, pub(crate) name: OsString, pub(crate) temporary_name: Option, } @@ -106,6 +109,8 @@ impl TemporaryFile { copy_file_perms(&dir, &name, &file, opts)?; } + let file = file.into(); + Ok(Self { dir, file, @@ -113,8 +118,13 @@ impl TemporaryFile { temporary_name, }) } +} - pub(crate) fn rename_file(&self) -> Result<()> { +impl TemporaryFile { + pub(crate) fn rename_file(&self) -> Result<()> + where + F: AsRawFd, + { match self.temporary_name { None => rename_unnamed_temporary_file(&self.dir, &self.file, &self.name)?, Some(ref temporary_name) => { @@ -137,3 +147,19 @@ impl TemporaryFile { Some(&self.dir) } } + +#[cfg(feature = "async-std")] +impl TemporaryFile { + #[inline] + pub(crate) fn into(self) -> TemporaryFile + where + G: From, + { + TemporaryFile:: { + dir: self.dir, + file: self.file.into(), + name: self.name, + temporary_name: self.temporary_name, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 747cdf0..2f87240 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,9 @@ mod fd; #[cfg(unix)] pub mod unix; +#[cfg(feature = "async-std")] +pub mod future; + #[cfg(test)] mod tests; diff --git a/src/tests.rs b/src/tests.rs index beb7c08..cd2213a 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -10,7 +10,7 @@ use std::panic; use std::path::Path; use std::path::PathBuf; -fn test_directory() -> PathBuf { +pub(crate) fn test_directory() -> PathBuf { let path = option_env!("TEST_DIR").unwrap_or("target/test-files"); println!("using test directory: {path:?}"); fs::create_dir_all(path) @@ -18,7 +18,7 @@ fn test_directory() -> PathBuf { path.into() } -fn test_file>(name: P) -> PathBuf { +pub(crate) fn test_file>(name: P) -> PathBuf { let mut path = test_directory(); path.push(name); match fs::remove_file(&path) { @@ -29,7 +29,7 @@ fn test_file>(name: P) -> PathBuf { path } -fn list_temporary_files>(path: P) -> impl Iterator { +pub(crate) fn list_temporary_files>(path: P) -> impl Iterator { let path = path.as_ref(); let dir_path = path.parent().unwrap(); let file_name = path.file_name().unwrap(); @@ -54,7 +54,7 @@ fn list_temporary_files>(path: P) -> impl Iterator, P2: AsRef>( +pub(crate) fn verify_temporary_file_name, P2: AsRef>( dst_file_name: P1, temp_file_name: P2, ) { @@ -70,7 +70,7 @@ fn verify_temporary_file_name, P2: AsRef>( ); } -fn verify_no_leftovers>(path: P) { +pub(crate) fn verify_no_leftovers>(path: P) { let leftovers = list_temporary_files(path).collect::>(); if !leftovers.is_empty() { panic!("found leftover files: {leftovers:?}");