Skip to content

Commit

Permalink
feat: add coroutine __name__/__qualname__ and not-awaited warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Perez committed Nov 20, 2023
1 parent 49ae17c commit d29b9cd
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 16 deletions.
18 changes: 17 additions & 1 deletion pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,23 @@ impl<'a> FnSpec<'a> {
let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) };
let python_name = &self.python_name;
let qualname = match cls {
Some(cls) => quote! {
Some(_pyo3::impl_::coroutine::method_coroutine_qualname::<#cls>(py, stringify!(#python_name)))
},
None => quote! {
_pyo3::impl_::coroutine::coroutine_qualname(py, py.from_borrowed_ptr_or_opt::<_pyo3::types::PyModule>(_slf), stringify!(#python_name))
},
};
call = quote! {{
let future = #call;
_pyo3::impl_::coroutine::new_coroutine(
Some(_pyo3::impl_::coroutine::coroutine_name(py, stringify!(#python_name))),
#qualname,
async move { _pyo3::impl_::wrap::OkWrap::wrap_no_gil(future.await) }
)
}};
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
Expand Down
47 changes: 44 additions & 3 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use pyo3_macros::{pyclass, pymethods};

use crate::{
coroutine::waker::AsyncioWaker,
exceptions::{PyRuntimeError, PyStopIteration},
exceptions::{PyAttributeError, PyRuntimeError, PyRuntimeWarning, PyStopIteration},
panic::PanicException,
pyclass::IterNextOutput,
types::PyIterator,
types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};

Expand All @@ -30,6 +30,8 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname: Option<Py<PyString>>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -41,7 +43,11 @@ impl Coroutine {
/// (should always be `None` anyway).
///
/// `Coroutine `throw` drop the wrapped future and reraise the exception passed
pub(crate) fn from_future<F, T, E>(future: F) -> Self
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
mut qualname: Option<Py<PyString>>,
future: F,
) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
Expand All @@ -52,7 +58,10 @@ impl Coroutine {
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
qualname = qualname.or_else(|| name.clone());
Self {
name,
qualname,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
waker: None,
}
Expand Down Expand Up @@ -113,6 +122,20 @@ pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResul

#[pymethods(crate = "crate")]
impl Coroutine {
#[getter]
fn __name__(&self) -> PyResult<Py<PyString>> {
self.name
.clone()
.ok_or_else(|| PyAttributeError::new_err("__name__"))
}

#[getter]
fn __qualname__(&self) -> PyResult<Py<PyString>> {
self.qualname
.clone()
.ok_or_else(|| PyAttributeError::new_err("__qualname__"))
}

fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
iter_result(self.poll(py, None)?)
}
Expand All @@ -135,3 +158,21 @@ impl Coroutine {
self.poll(py, None)
}
}

impl Drop for Coroutine {
fn drop(&mut self) {
if self.future.is_some() {
Python::with_gil(|gil| {
let qualname = self
.qualname
.as_ref()
.map_or(Ok("<coroutine>"), |n| n.as_ref(gil).to_str())
.unwrap();
let message = format!("coroutine {qualname} was never awaited");
PyErr::warn(gil, gil.get_type::<PyRuntimeWarning>(), &message, 2)
.expect("warning error");
self.poll(gil, None).expect("coroutine close error");
})
}
}
}
46 changes: 34 additions & 12 deletions src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,41 @@
use crate::coroutine::Coroutine;
use crate::impl_::wrap::OkWrap;
use crate::{IntoPy, PyErr, PyObject, Python};
use std::future::Future;

/// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`.
pub fn wrap_future<F, R, T>(future: F) -> Coroutine
use crate::{
coroutine::Coroutine,
types::{PyModule, PyString},
IntoPy, Py, PyClass, PyErr, PyObject, Python,
};

pub fn new_coroutine<F, T, E>(
name: Option<Py<PyString>>,
qualname: Option<Py<PyString>>,
future: F,
) -> Coroutine
where
F: Future<Output = R> + Send + 'static,
R: OkWrap<T>,
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
PyErr: From<R::Error>,
PyErr: From<E>,
{
let future = async move {
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
future.await.wrap(unsafe { Python::assume_gil_acquired() })
Coroutine::new(name, qualname, future)
}

pub fn coroutine_name(py: Python<'_>, name: &str) -> Py<PyString> {
PyString::new(py, name).into()
}

pub unsafe fn coroutine_qualname(
py: Python<'_>,
module: Option<&PyModule>,
name: &str,
) -> Option<Py<PyString>> {
Some(PyString::new(py, &format!("{}.{name}", module?.name().ok()?)).into())
}

pub fn method_coroutine_qualname<T: PyClass>(py: Python<'_>, name: &str) -> Py<PyString> {
let class = T::NAME;
let qualname = match T::MODULE {
Some(module) => format!("{module}.{class}.{name}"),
None => format!("{class}.{name}"),
};
Coroutine::from_future(future)
PyString::new(py, &qualname).into()
}
7 changes: 7 additions & 0 deletions src/impl_/wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ impl<T> SomeWrap<Option<T>> for Option<T> {
/// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`.
pub trait OkWrap<T> {
type Error;
fn wrap_no_gil(self) -> Result<T, Self::Error>;
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error>;
}

Expand All @@ -30,6 +31,9 @@ where
T: IntoPy<PyObject>,
{
type Error = PyErr;
fn wrap_no_gil(self) -> Result<T, Self::Error> {
Ok(self)
}
fn wrap(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
Ok(self.into_py(py))
}
Expand All @@ -40,6 +44,9 @@ where
T: IntoPy<PyObject>,
{
type Error = E;
fn wrap_no_gil(self) -> Result<T, Self::Error> {
self
}
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error> {
self.map(|o| o.into_py(py))
}
Expand Down
67 changes: 67 additions & 0 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![cfg(feature = "macros")]
use std::ops::Deref;
use std::{task::Poll, thread, time::Duration};

use futures::{channel::oneshot, future::poll_fn};
use pyo3::types::IntoPyDict;
use pyo3::{prelude::*, py_run};

#[path = "../src/tests/common.rs"]
Expand Down Expand Up @@ -29,6 +31,71 @@ fn noop_coroutine() {
})
}

#[test]
fn test_coroutine_qualname() {
#[pyfunction]
async fn my_fn() {}
#[pyclass]
struct MyClass;
#[pymethods]
impl MyClass {
#[new]
fn new() -> Self {
Self
}
// TODO use &self when possible
async fn my_method(_self: Py<Self>) {}
// TODO uncomment when https://github.com/PyO3/pyo3/pull/3587 is merged
// #[classmethod]
// async fn my_classmethod(_cls: Py<PyType>) {}
#[staticmethod]
async fn my_staticmethod() {}
}
#[pyclass(module = "my_module")]
struct MyClassWithModule;
#[pymethods]
impl MyClassWithModule {
#[new]
fn new() -> Self {
Self
}
// TODO use &self when possible
async fn my_method(_self: Py<Self>) {}
// TODO uncomment when https://github.com/PyO3/pyo3/pull/3587 is merged
// #[classmethod]
// async fn my_classmethod(_cls: Py<PyType>) {}
#[staticmethod]
async fn my_staticmethod() {}
}
Python::with_gil(|gil| {
let test = r#"
for coro, name, qualname in [
(my_fn(), "my_fn", "my_fn"),
(my_fn_with_module(), "my_fn", "my_module.my_fn"),
(MyClass().my_method(), "my_method", "MyClass.my_method"),
#(MyClass().my_classmethod(), "my_classmethod", "MyClass.my_classmethod"),
(MyClass.my_staticmethod(), "my_staticmethod", "MyClass.my_staticmethod"),
(MyClassWithModule().my_method(), "my_method", "my_module.MyClassWithModule.my_method"),
#(MyClassWithModule().my_classmethod(), "my_classmethod", "my_module.MyClassWithModule.my_classmethod"),
(MyClassWithModule.my_staticmethod(), "my_staticmethod", "my_module.MyClassWithModule.my_staticmethod"),
]:
assert coro.__name__ == name and coro.__qualname__ == qualname
"#;
let my_module = PyModule::new(gil, "my_module").unwrap();
let locals = [
("my_fn", wrap_pyfunction!(my_fn, gil).unwrap().deref()),
(
"my_fn_with_module",
wrap_pyfunction!(my_fn, my_module).unwrap(),
),
("MyClass", gil.get_type::<MyClass>()),
("MyClassWithModule", gil.get_type::<MyClassWithModule>()),
]
.into_py_dict(gil);
py_run!(gil, *locals, &handle_windows(test));
})
}

#[test]
fn sleep_0_like_coroutine() {
#[pyfunction]
Expand Down

0 comments on commit d29b9cd

Please sign in to comment.