diff --git a/src/stack/mod.rs b/src/stack/mod.rs index c71e33a..684e5fb 100644 --- a/src/stack/mod.rs +++ b/src/stack/mod.rs @@ -365,8 +365,9 @@ impl Stack { let place_ptr = NonNull::new_unchecked(Box::into_raw(place)); let fut = (f)(ctx); - self.tasks - .push(async move { place_ptr.as_ref().get().write(Some(fut.await)) }); + self.tasks.push(async move { + place_ptr.as_ref().get().write(Some(fut.await)); + }); Runner { place: place_ptr, diff --git a/src/tree/mod.rs b/src/tree/mod.rs index 81eb958..c730fd0 100644 --- a/src/tree/mod.rs +++ b/src/tree/mod.rs @@ -85,7 +85,7 @@ impl<'a, 'b, R> Future for StepFuture<'a, 'b, R> { } } - // No futures left in fanout, run on the root stack. + // No futures left in fanout, run on the root stack.l match self.runner.ptr.root.drive_head(cx) { Poll::Ready(_) => { if self.runner.ptr.root.tasks().is_empty() { @@ -97,7 +97,13 @@ impl<'a, 'b, R> Future for StepFuture<'a, 'b, R> { } } Poll::Pending => match self.runner.ptr.root.get_state() { - State::Base => return Poll::Pending, + State::Base => { + if self.runner.ptr.fanout.is_empty() { + return Poll::Pending; + } else { + return Poll::Ready(None); + } + } State::Cancelled => unreachable!("TreeStack dropped while stepping"), State::NewTask | State::Yield => {} }, diff --git a/src/tree/schedular/mod.rs b/src/tree/schedular/mod.rs index 53f45a9..d3e370b 100644 --- a/src/tree/schedular/mod.rs +++ b/src/tree/schedular/mod.rs @@ -21,6 +21,7 @@ use self::queue::NodeHeader; #[derive(Debug, Clone)] pub(crate) struct SchedularVTable { + task_drive: unsafe fn(NonNull>, cx: &mut Context) -> Poll<()>, task_drop: unsafe fn(NonNull>), } @@ -28,12 +29,18 @@ impl SchedularVTable { pub const fn get>() -> SchedularVTable { SchedularVTable { task_drop: Self::drop::, + task_drive: Self::drive::, } } unsafe fn drop>(ptr: NonNull>) { Arc::decrement_strong_count(ptr.cast::>().as_ptr()) } + + unsafe fn drive>(ptr: NonNull>, cx: &mut Context) -> Poll<()> { + let ptr = ptr.cast::>(); + Pin::new_unchecked(&mut (*ptr.as_ref().future.get())).poll(cx) + } } #[repr(C)] @@ -153,8 +160,7 @@ impl Schedular { } unsafe fn drive_task(ptr: NonNull>, ctx: &mut Context) -> Poll<()> { - let future_ptr = NonNull::new_unchecked(ptr.as_ref().future.get()); - (ptr.as_ref().body.vtable.driver)(future_ptr, ctx) + (ptr.as_ref().body.vtable.tree.task_drive)(ptr, ctx) } pub unsafe fn poll(&self, cx: &mut Context) -> Poll<()> {