Skip to content

Commit

Permalink
refactor: drop futures_util dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 30, 2023
1 parent 0e9ba3a commit c474531
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 29 deletions.
3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ unindent = { version = "0.2.1", optional = true }
# support crate for multiple-pymethods feature
inventory = { version = "0.3.0", optional = true }

# coroutine implementation
futures-util = "0.3"

# crate integrations that can be added using the eponymous features
anyhow = { version = "1.0", optional = true }
chrono = { version = "0.4.25", default-features = false, optional = true }
Expand Down
31 changes: 17 additions & 14 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
//! Python coroutine implementation, used notably when wrapping `async fn`
//! with `#[pyfunction]`/`#[pymethods]`.
use std::task::Waker;
use std::{
any::Any,
future::Future,
panic,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use futures_util::FutureExt;
use pyo3_macros::{pyclass, pymethods};

use crate::{
coroutine::waker::AsyncioWaker,
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
panic::PanicException,
pyclass::IterNextOutput,
types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
Expand All @@ -25,19 +23,18 @@ pub(crate) mod cancel;
mod waker;

use crate::coroutine::cancel::ThrowCallback;
use crate::panic::PanicException;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}

Expand Down Expand Up @@ -68,7 +65,7 @@ impl Coroutine {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
future: Some(Box::pin(wrap)),
waker: None,
}
}
Expand Down Expand Up @@ -98,14 +95,20 @@ impl Coroutine {
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
}
let waker = futures_util::task::waker(self.waker.clone().unwrap());
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
if let Poll::Ready(res) = future_rs.as_mut().poll(&mut Context::from_waker(&waker)) {
self.close();
return match res {
Ok(res) => Ok(IterNextOutput::Return(res?)),
Err(err) => Err(PanicException::from_panic_payload(err)),
};
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
return Ok(IterNextOutput::Return(res?));
}
Err(err) => {
self.close();
return Err(PanicException::from_panic_payload(err));
}
_ => {}
}
// otherwise, initialize the waker `asyncio.Future`
if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
Expand Down
28 changes: 20 additions & 8 deletions src/coroutine/cancel.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use crate::{ffi, Py, PyAny, PyObject};
use futures_util::future::poll_fn;
use futures_util::task::AtomicWaker;
use std::future::Future;
use std::pin::Pin;
use std::ptr;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

#[derive(Debug, Default)]
struct Inner {
exception: AtomicPtr<ffi::PyObject>,
waker: AtomicWaker,
waker: Mutex<Option<Waker>>,
}

/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
Expand Down Expand Up @@ -43,16 +43,17 @@ impl CancelHandle {
if self.is_cancelled() {
return Poll::Ready(take());
}
self.0.waker.register(cx.waker());
let mut guard = self.0.waker.lock().unwrap();
if self.is_cancelled() {
return Poll::Ready(take());
}
*guard = Some(cx.waker().clone());
Poll::Pending
}

/// Retrieve the exception thrown in the associated coroutine.
pub async fn cancelled(&mut self) -> PyObject {
poll_fn(|cx| self.poll_cancelled(cx)).await
Cancelled(self).await
}

#[doc(hidden)]
Expand All @@ -61,6 +62,15 @@ impl CancelHandle {
}
}

struct Cancelled<'a>(&'a mut CancelHandle);

impl Future for Cancelled<'_> {
type Output = PyObject;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_cancelled(cx)
}
}

#[doc(hidden)]
pub struct ThrowCallback(Arc<Inner>);

Expand All @@ -69,6 +79,8 @@ impl ThrowCallback {
let ptr = self.0.exception.swap(exc.into_ptr(), Ordering::Relaxed);
// SAFETY: non-null pointers set in `self.0.exceptions` are valid owned pointers
drop(unsafe { PyObject::from_owned_ptr_or_opt(exc.py(), ptr) });
self.0.waker.wake();
if let Some(waker) = self.0.waker.lock().unwrap().take() {
waker.wake();
}
}
}
12 changes: 8 additions & 4 deletions src/coroutine/waker.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::sync::GILOnceCell;
use crate::types::PyCFunction;
use crate::{intern, wrap_pyfunction, Py, PyAny, PyObject, PyResult, Python};
use futures_util::task::ArcWake;
use pyo3_macros::pyfunction;
use std::sync::Arc;
use std::task::Wake;

/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`.
///
Expand Down Expand Up @@ -31,10 +31,14 @@ impl AsyncioWaker {
}
}

impl ArcWake for AsyncioWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
impl Wake for AsyncioWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref()
}

fn wake_by_ref(self: &Arc<Self>) {
Python::with_gil(|gil| {
if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) {
if let Some(loop_and_future) = self.0.get_or_init(gil, || None) {
loop_and_future
.set_result(gil)
.expect("unexpected error in coroutine waker");
Expand Down

0 comments on commit c474531

Please sign in to comment.