Skip to content

Commit

Permalink
Refactor future::AsyncWaitGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
al8n committed Feb 17, 2024
1 parent f111bc0 commit e5bebe8
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 155 deletions.
9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ homepage = "https://github.com/al8n/wg"
repository = "https://github.com/al8n/wg.git"
documentation = "https://docs.rs/wg/"
readme = "README.md"
version = "0.7.2"
version = "0.7.3"
license = "MIT OR Apache-2.0"
keywords = ["waitgroup", "async", "sync", "notify", "wake"]
categories = ["asynchronous", "concurrency", "data-structures"]
Expand All @@ -18,15 +18,14 @@ full = ["triomphe", "parking_lot"]
triomphe = ["dep:triomphe"]
parking_lot = ["dep:parking_lot"]

future = ["event-listener", "event-listener-strategy", "pin-project-lite"]
future = ["event-listener", "pin-project-lite"]

tokio = ["dep:tokio", "futures-core", "pin-project-lite"]

[dependencies]
parking_lot = { version = "0.12", optional = true }
triomphe = { version = "0.1", optional = true }
event-listener = { version = "5", optional = true }
event-listener-strategy = { version = "0.5", optional = true }
pin-project-lite = { version = "0.2", optional = true }

tokio = { version = "1", default-features = false, optional = true, features = ["sync", "rt"] }
Expand All @@ -50,3 +49,7 @@ name = "future"
path = "tests/future.rs"
required-features = ["future"]

[[test]]
name = "sync"
path = "tests/sync.rs"

104 changes: 46 additions & 58 deletions src/future.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use super::*;
use event_listener::{Event, EventListener};
use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy};

use std::{
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
task::Poll,
task::{Context, Poll},
};

#[derive(Debug)]
Expand Down Expand Up @@ -163,7 +162,7 @@ impl AsyncWaitGroup {
/// });
/// # })
/// ```
pub fn done(&self) {
pub fn done(self) {
if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
self.inner.event.notify(usize::MAX);
}
Expand Down Expand Up @@ -197,7 +196,11 @@ impl AsyncWaitGroup {
/// # })
/// ```
pub fn wait(&self) -> WaitGroupFuture<'_> {
WaitGroupFuture::_new(WaitGroupFutureInner::new(&self.inner))
WaitGroupFuture {
inner: self,
notified: self.inner.event.listen(),
_pin: std::marker::PhantomPinned,
}
}

/// Wait blocks until the [`AsyncWaitGroup`] counter is zero. This method is
Expand All @@ -222,79 +225,64 @@ impl AsyncWaitGroup {
/// t_wg.done()
/// });
///
/// let spawner = |fut| {
/// spawn(fut);
/// };
///
/// // wait other thread completes
/// wg.block_wait();
/// wg.block_wait(spawner);
/// # })
/// ```
pub fn block_wait(&self) {
WaitGroupFutureInner::new(&self.inner).wait();
pub fn block_wait<S>(&self, spawner: S)
where
S: FnOnce(Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>),
{
let this = self.clone();
let (tx, rx) = std::sync::mpsc::channel();
spawner(Box::pin(async move {
this.wait().await;
let _ = tx.send(());
}));

let _ = rx.recv();
}
}

easy_wrapper! {
pin_project_lite::pin_project! {
/// A future returned by [`AsyncWaitGroup::wait()`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[cfg_attr(docsrs, doc(cfg(feature = "future")))]
pub struct WaitGroupFuture<'a>(WaitGroupFutureInner<'a> => ());

#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub(crate) wait();
}

pin_project_lite::pin_project! {
/// A future that used to wait for the [`AsyncWaitGroup`] counter is zero.
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[project(!Unpin)]
#[derive(Debug)]
struct WaitGroupFutureInner<'a> {
inner: &'a Arc<AsyncInner>,
listener: Option<EventListener>,
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub struct WaitGroupFuture<'a> {
inner: &'a AsyncWaitGroup,
#[pin]
notified: EventListener,
#[pin]
_pin: std::marker::PhantomPinned,
}
}

impl<'a> WaitGroupFutureInner<'a> {
fn new(inner: &'a Arc<AsyncInner>) -> Self {
Self {
inner,
listener: None,
_pin: std::marker::PhantomPinned,
}
}
}

impl EventListenerFuture for WaitGroupFutureInner<'_> {
impl<'a> std::future::Future for WaitGroupFuture<'a> {
type Output = ();

fn poll_with_strategy<'a, S: Strategy<'a>>(
self: Pin<&mut Self>,
strategy: &mut S,
context: &mut S::Context,
) -> Poll<Self::Output> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.inner.inner.counter.load(Ordering::SeqCst) == 0 {
return Poll::Ready(());
}

let this = self.project();
loop {
if this.inner.counter.load(Ordering::SeqCst) == 0 {
return Poll::Ready(());
match this.notified.poll(cx) {
Poll::Pending => {
cx.waker().wake_by_ref();
Poll::Pending
}

if this.listener.is_some() {
// Poll using the given strategy
match S::poll(strategy, &mut *this.listener, context) {
Poll::Ready(_) => {
// Event received, check the condition again.
if this.inner.counter.load(Ordering::SeqCst) == 0 {
return Poll::Ready(());
}

// Event received but condition not met, reset listener.
*this.listener = None;
}
Poll::Pending => return Poll::Pending,
Poll::Ready(_) => {
if this.inner.inner.counter.load(Ordering::SeqCst) == 0 {
Poll::Ready(())
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
} else {
*this.listener = Some(this.inner.event.listen());
}
}
}
Expand Down
91 changes: 1 addition & 90 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ impl WaitGroup {
/// });
///
/// ```
pub fn done(&self) {
pub fn done(self) {
let mut val = self.inner.count.lock_me();

*val = if val.eq(&1) {
Expand Down Expand Up @@ -277,92 +277,3 @@ impl WaitGroup {
}
}
}

#[cfg(test)]
mod test {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

#[test]
fn test_sync_wait_group_reuse() {
let wg = WaitGroup::new();
let ctr = Arc::new(AtomicUsize::new(0));
for _ in 0..6 {
let wg = wg.add(1);
let ctrx = ctr.clone();
std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(5));
ctrx.fetch_add(1, Ordering::Relaxed);
wg.done();
});
}

wg.wait();
assert_eq!(ctr.load(Ordering::Relaxed), 6);

let worker = wg.add(1);
let ctrx = ctr.clone();
std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(5));
ctrx.fetch_add(1, Ordering::Relaxed);
worker.done();
});
wg.wait();
assert_eq!(ctr.load(Ordering::Relaxed), 7);
}

#[test]
fn test_sync_wait_group_nested() {
let wg = WaitGroup::new();
let ctr = Arc::new(AtomicUsize::new(0));
for _ in 0..5 {
let worker = wg.add(1);
let ctrx = ctr.clone();
std::thread::spawn(move || {
let nested_worker = worker.add(1);
let ctrxx = ctrx.clone();
std::thread::spawn(move || {
ctrxx.fetch_add(1, Ordering::Relaxed);
nested_worker.done();
});
ctrx.fetch_add(1, Ordering::Relaxed);
worker.done();
});
}

wg.wait();
assert_eq!(ctr.load(Ordering::Relaxed), 10);
}

#[test]
fn test_sync_wait_group_from() {
std::thread::scope(|s| {
let wg = WaitGroup::from(5);
for _ in 0..5 {
let t = wg.clone();
s.spawn(move || {
t.done();
});
}
wg.wait();
});
}

#[test]
fn test_clone_and_fmt() {
let swg = WaitGroup::new();
let swg1 = swg.clone();
swg1.add(3);
assert_eq!(format!("{:?}", swg), format!("{:?}", swg1));
}

#[test]
fn test_waitings() {
let wg = WaitGroup::new();
wg.add(1);
wg.add(1);
assert_eq!(wg.waitings(), 2);
}
}
18 changes: 16 additions & 2 deletions src/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl AsyncWaitGroup {
/// });
/// }
/// ```
pub fn done(&self) {
pub fn done(self) {
if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
self.inner.notify.notify_waiters();
}
Expand Down Expand Up @@ -261,6 +261,20 @@ impl<'a> Future for WaitGroupFuture<'a> {
return Poll::Ready(());
}

self.project().notified.poll(cx)
let this = self.project();
match this.notified.poll(cx) {
Poll::Pending => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Ready(_) => {
if this.inner.inner.counter.load(Ordering::SeqCst) == 0 {
Poll::Ready(())
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
}
6 changes: 4 additions & 2 deletions tests/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ fn test_async_block_wait() {
// do some time consuming task
t_wg.done();
});

let spawner = |fut| {
async_std::task::spawn(fut);
};
// wait other thread completes
wg.block_wait();
wg.block_wait(spawner);

assert_eq!(wg.waitings(), 0);
}
Expand Down
Loading

0 comments on commit e5bebe8

Please sign in to comment.