diff --git a/CHANGELOG.md b/CHANGELOG.md index 6353de98a54..a5fcea0d505 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added * Have `PyModule` generate an index of its members (`__all__` list). + * Allow `slf: PyRef` for pyclass(#419) ### Changed diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index cc46f2d49a2..82fb295673c 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -27,6 +27,7 @@ pub enum FnType { FnCall, FnClass, FnStatic, + PySelf(syn::TypePath), } #[derive(Clone, PartialEq, Debug)] @@ -51,11 +52,10 @@ impl<'a> FnSpec<'a> { sig: &'a syn::MethodSig, meth_attrs: &'a mut Vec, ) -> syn::Result> { - let (fn_type, fn_attrs) = parse_attributes(meth_attrs)?; + let (mut fn_type, fn_attrs) = parse_attributes(meth_attrs)?; let mut has_self = false; let mut arguments = Vec::new(); - for input in sig.decl.inputs.iter() { match input { syn::FnArg::SelfRef(_) => { @@ -119,6 +119,17 @@ impl<'a> FnSpec<'a> { let ty = get_return_info(&sig.decl.output); + if fn_type == FnType::Fn && !has_self { + if arguments.len() == 0 { + panic!("Static method needs #[staticmethod] attribute"); + } + let tp = match arguments.remove(0).ty { + syn::Type::Path(p) => replace_self(p), + _ => panic!("Invalid type as self"), + }; + fn_type = FnType::PySelf(tp); + } + Ok(FnSpec { tp: fn_type, attrs: fn_attrs, @@ -380,3 +391,25 @@ fn parse_attributes(attrs: &mut Vec) -> syn::Result<(FnType, Vec None => Ok((FnType::Fn, spec)), } } + +// Replace A with A<_> +fn replace_self(path: &syn::TypePath) -> syn::TypePath { + fn infer(span: proc_macro2::Span) -> syn::GenericArgument { + syn::GenericArgument::Type(syn::Type::Infer(syn::TypeInfer { + underscore_token: syn::token::Underscore { spans: [span] }, + })) + } + let mut res = path.to_owned(); + for seg in &mut res.path.segments { + if let syn::PathArguments::AngleBracketed(ref mut g) = seg.arguments { + for arg in &mut g.args { + if let syn::GenericArgument::Type(syn::Type::Path(p)) = arg { + if p.path.segments.len() == 1 && p.path.segments[0].ident == "Self" { + *arg = infer(p.path.segments[0].ident.span()); + } + } + } + } + } + res +} diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index a4688aba921..d68cc39dec9 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -1,5 +1,4 @@ // Copyright (c) 2017-present PyO3 Project and Contributors - use crate::method::{FnArg, FnSpec, FnType}; use crate::utils; use proc_macro2::{Span, TokenStream}; @@ -18,6 +17,12 @@ pub fn gen_py_method( match spec.tp { FnType::Fn => impl_py_method_def(name, doc, &spec, &impl_wrap(cls, name, &spec, true)), + FnType::PySelf(ref self_ty) => impl_py_method_def( + name, + doc, + &spec, + &impl_wrap_pyslf(cls, name, &spec, self_ty, true), + ), FnType::FnNew => impl_py_method_def_new(name, doc, &impl_wrap_new(cls, name, &spec)), FnType::FnInit => impl_py_method_def_init(name, doc, &impl_wrap_init(cls, name, &spec)), FnType::FnCall => impl_py_method_def_call(name, doc, &impl_wrap(cls, name, &spec, false)), @@ -48,7 +53,33 @@ pub fn impl_wrap( noargs: bool, ) -> TokenStream { let body = impl_call(cls, name, &spec); + let slf = impl_self("e! { &mut #cls }); + impl_wrap_common(cls, name, spec, noargs, slf, body) +} + +pub fn impl_wrap_pyslf( + cls: &syn::Type, + name: &syn::Ident, + spec: &FnSpec<'_>, + self_ty: &syn::TypePath, + noargs: bool, +) -> TokenStream { + let names = get_arg_names(spec); + let body = quote! { + #cls::#name(_slf, #(#names),*) + }; + let slf = impl_self(self_ty); + impl_wrap_common(cls, name, spec, noargs, slf, body) +} +fn impl_wrap_common( + cls: &syn::Type, + name: &syn::Ident, + spec: &FnSpec<'_>, + noargs: bool, + slf: TokenStream, + body: TokenStream, +) -> TokenStream { if spec.args.is_empty() && noargs { quote! { unsafe extern "C" fn __wrap( @@ -59,8 +90,7 @@ pub fn impl_wrap( stringify!(#cls), ".", stringify!(#name), "()"); let _pool = pyo3::GILPool::new(); let _py = pyo3::Python::assume_gil_acquired(); - let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf); - + #slf let _result = { pyo3::derive_utils::IntoPyResult::into_py_result(#body) }; @@ -82,7 +112,7 @@ pub fn impl_wrap( stringify!(#cls), ".", stringify!(#name), "()"); let _pool = pyo3::GILPool::new(); let _py = pyo3::Python::assume_gil_acquired(); - let _slf = _py.mut_from_borrowed_ptr::<#cls>(_slf); + #slf let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); @@ -346,6 +376,12 @@ fn impl_call(_cls: &syn::Type, fname: &syn::Ident, spec: &FnSpec<'_>) -> TokenSt quote! { _slf.#fname(#(#names),*) } } +fn impl_self(self_ty: &T) -> TokenStream { + quote! { + let _slf: #self_ty = pyo3::FromPyPointer::from_borrowed_ptr(_py, _slf); + } +} + /// Converts a bool to "true" or "false" fn bool_to_ident(condition: bool) -> syn::Ident { if condition { diff --git a/src/conversion.rs b/src/conversion.rs index 72b7442e518..21c80669bd5 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -1,15 +1,13 @@ // Copyright (c) 2017-present PyO3 Project and Contributors //! Conversions between various states of rust and python types and their wrappers. - -use crate::err::{PyDowncastError, PyResult}; -use crate::ffi; +use crate::err::{self, PyDowncastError, PyResult}; use crate::object::PyObject; use crate::type_object::PyTypeInfo; use crate::types::PyAny; use crate::types::PyTuple; -use crate::Py; -use crate::Python; +use crate::{ffi, gil, Py, Python}; +use std::ptr::NonNull; /// This trait represents that, **we can do zero-cost conversion from the object to FFI pointer**. /// @@ -432,6 +430,66 @@ impl FromPy<()> for Py { } } +/// Raw level conversion between `*mut ffi::PyObject` and PyO3 types. +pub unsafe trait FromPyPointer<'p>: Sized { + unsafe fn from_owned_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option; + unsafe fn from_owned_ptr_or_panic(py: Python<'p>, ptr: *mut ffi::PyObject) -> Self { + match Self::from_owned_ptr_or_opt(py, ptr) { + Some(s) => s, + None => err::panic_after_error(), + } + } + unsafe fn from_owned_ptr(py: Python<'p>, ptr: *mut ffi::PyObject) -> Self { + Self::from_owned_ptr_or_panic(py, ptr) + } + unsafe fn from_owned_ptr_or_err(py: Python<'p>, ptr: *mut ffi::PyObject) -> PyResult { + match Self::from_owned_ptr_or_opt(py, ptr) { + Some(s) => Ok(s), + None => Err(err::PyErr::fetch(py)), + } + } + unsafe fn from_borrowed_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option; + unsafe fn from_borrowed_ptr_or_panic(py: Python<'p>, ptr: *mut ffi::PyObject) -> Self { + match Self::from_borrowed_ptr_or_opt(py, ptr) { + Some(s) => s, + None => err::panic_after_error(), + } + } + unsafe fn from_borrowed_ptr(py: Python<'p>, ptr: *mut ffi::PyObject) -> Self { + Self::from_borrowed_ptr_or_panic(py, ptr) + } + unsafe fn from_borrowed_ptr_or_err(py: Python<'p>, ptr: *mut ffi::PyObject) -> PyResult { + match Self::from_borrowed_ptr_or_opt(py, ptr) { + Some(s) => Ok(s), + None => Err(err::PyErr::fetch(py)), + } + } +} + +unsafe impl<'p, T> FromPyPointer<'p> for &'p T +where + T: PyTypeInfo, +{ + unsafe fn from_owned_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + NonNull::new(ptr).map(|p| py.unchecked_downcast(gil::register_owned(py, p))) + } + unsafe fn from_borrowed_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + NonNull::new(ptr).map(|p| py.unchecked_downcast(gil::register_borrowed(py, p))) + } +} + +unsafe impl<'p, T> FromPyPointer<'p> for &'p mut T +where + T: PyTypeInfo, +{ + unsafe fn from_owned_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + NonNull::new(ptr).map(|p| py.unchecked_mut_downcast(gil::register_owned(py, p))) + } + unsafe fn from_borrowed_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + NonNull::new(ptr).map(|p| py.unchecked_mut_downcast(gil::register_borrowed(py, p))) + } +} + #[cfg(test)] mod test { use crate::types::PyList; diff --git a/src/instance.rs b/src/instance.rs index 90507955b25..a2cebc4bc25 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -8,10 +8,9 @@ use crate::objectprotocol::ObjectProtocol; use crate::type_object::PyTypeCreate; use crate::type_object::{PyTypeInfo, PyTypeObject}; use crate::types::PyAny; -use crate::AsPyPointer; -use crate::IntoPyPointer; -use crate::Python; -use crate::{FromPyObject, IntoPyObject, ToPyObject}; +use crate::{ + AsPyPointer, FromPyObject, FromPyPointer, IntoPyObject, IntoPyPointer, Python, ToPyObject, +}; use std::marker::PhantomData; use std::mem; use std::ops::{Deref, DerefMut}; @@ -74,15 +73,14 @@ impl<'a, T: PyTypeInfo> PyRef<'a, T> { } } -impl<'a, T> PyRef<'a, T> +impl<'p, T> PyRef<'p, T> where T: PyTypeInfo + PyTypeObject + PyTypeCreate, { - pub fn new(py: Python, value: T) -> PyResult> { + pub fn new(py: Python<'p>, value: T) -> PyResult> { let obj = T::create(py)?; obj.init(value); - let ref_ = unsafe { py.from_owned_ptr(obj.into_ptr()) }; - Ok(PyRef::from_ref(ref_)) + unsafe { Self::from_owned_ptr_or_err(py, obj.into_ptr()) } } } @@ -105,6 +103,18 @@ impl<'a, T: PyTypeInfo> Deref for PyRef<'a, T> { } } +unsafe impl<'p, T> FromPyPointer<'p> for PyRef<'p, T> +where + T: PyTypeInfo, +{ + unsafe fn from_owned_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + FromPyPointer::from_owned_ptr_or_opt(py, ptr).map(Self::from_ref) + } + unsafe fn from_borrowed_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + FromPyPointer::from_borrowed_ptr_or_opt(py, ptr).map(Self::from_ref) + } +} + /// Mutable version of [`PyRef`](struct.PyRef.html). /// # Example /// ``` @@ -137,15 +147,14 @@ impl<'a, T: PyTypeInfo> PyRefMut<'a, T> { } } -impl<'a, T> PyRefMut<'a, T> +impl<'p, T> PyRefMut<'p, T> where T: PyTypeInfo + PyTypeObject + PyTypeCreate, { - pub fn new(py: Python, value: T) -> PyResult> { + pub fn new(py: Python<'p>, value: T) -> PyResult> { let obj = T::create(py)?; obj.init(value); - let ref_ = unsafe { py.mut_from_owned_ptr(obj.into_ptr()) }; - Ok(PyRefMut::from_mut(ref_)) + unsafe { Self::from_owned_ptr_or_err(py, obj.into_ptr()) } } } @@ -174,6 +183,18 @@ impl<'a, T: PyTypeInfo> DerefMut for PyRefMut<'a, T> { } } +unsafe impl<'p, T> FromPyPointer<'p> for PyRefMut<'p, T> +where + T: PyTypeInfo, +{ + unsafe fn from_owned_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + FromPyPointer::from_owned_ptr_or_opt(py, ptr).map(Self::from_mut) + } + unsafe fn from_borrowed_ptr_or_opt(py: Python<'p>, ptr: *mut ffi::PyObject) -> Option { + FromPyPointer::from_borrowed_ptr_or_opt(py, ptr).map(Self::from_mut) + } +} + /// Trait implements object reference extraction from python managed pointer. pub trait AsPyRef: Sized { /// Return reference to object. diff --git a/src/lib.rs b/src/lib.rs index 6be9a6cbcae..75f56bd0dd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,8 +119,8 @@ pub use crate::class::*; pub use crate::conversion::{ - AsPyPointer, FromPy, FromPyObject, IntoPy, IntoPyObject, IntoPyPointer, PyTryFrom, PyTryInto, - ToBorrowedObject, ToPyObject, + AsPyPointer, FromPy, FromPyObject, FromPyPointer, IntoPy, IntoPyObject, IntoPyPointer, + PyTryFrom, PyTryInto, ToBorrowedObject, ToPyObject, }; pub use crate::err::{PyDowncastError, PyErr, PyErrArguments, PyErrValue, PyResult}; pub use crate::gil::{init_once, GILGuard, GILPool}; diff --git a/src/python.rs b/src/python.rs index bd57ff3f33f..5e08092d631 100644 --- a/src/python.rs +++ b/src/python.rs @@ -10,7 +10,7 @@ use crate::object::PyObject; use crate::type_object::{PyTypeInfo, PyTypeObject}; use crate::types::{PyAny, PyDict, PyModule, PyType}; use crate::AsPyPointer; -use crate::{IntoPyPointer, PyTryFrom}; +use crate::{FromPyPointer, IntoPyPointer, PyTryFrom}; use std::ffi::CString; use std::marker::PhantomData; use std::os::raw::c_int; @@ -193,7 +193,7 @@ impl<'p> Python<'p> { } impl<'p> Python<'p> { - unsafe fn unchecked_downcast(self, ob: &PyAny) -> &'p T { + pub(crate) unsafe fn unchecked_downcast(self, ob: &PyAny) -> &'p T { if T::OFFSET == 0 { &*(ob as *const _ as *const T) } else { @@ -203,7 +203,7 @@ impl<'p> Python<'p> { } #[allow(clippy::cast_ref_to_mut)] // FIXME - unsafe fn unchecked_mut_downcast(self, ob: &PyAny) -> &'p mut T { + pub(crate) unsafe fn unchecked_mut_downcast(self, ob: &PyAny) -> &'p mut T { if T::OFFSET == 0 { &mut *(ob as *const _ as *mut T) } else { @@ -240,18 +240,11 @@ impl<'p> Python<'p> { /// Register `ffi::PyObject` pointer in release pool, /// and do unchecked downcast to specific type. - pub unsafe fn from_owned_ptr(self, ptr: *mut ffi::PyObject) -> &'p T where T: PyTypeInfo, { - match NonNull::new(ptr) { - Some(p) => { - let p = gil::register_owned(self, p); - self.unchecked_downcast(p) - } - None => crate::err::panic_after_error(), - } + FromPyPointer::from_owned_ptr(self, ptr) } /// Register `ffi::PyObject` pointer in release pool, @@ -260,13 +253,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - match NonNull::new(ptr) { - Some(p) => { - let p = gil::register_owned(self, p); - self.unchecked_mut_downcast(p) - } - None => crate::err::panic_after_error(), - } + FromPyPointer::from_owned_ptr(self, ptr) } /// Register owned `ffi::PyObject` pointer in release pool. @@ -276,13 +263,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - match NonNull::new(ptr) { - Some(p) => { - let p = gil::register_owned(self, p); - Ok(self.unchecked_downcast(p)) - } - None => Err(PyErr::fetch(self)), - } + FromPyPointer::from_owned_ptr_or_err(self, ptr) } /// Register owned `ffi::PyObject` pointer in release pool. @@ -292,10 +273,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - NonNull::new(ptr).map(|p| { - let p = gil::register_owned(self, p); - self.unchecked_downcast(p) - }) + FromPyPointer::from_owned_ptr_or_opt(self, ptr) } /// Register borrowed `ffi::PyObject` pointer in release pool. @@ -305,13 +283,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - match NonNull::new(ptr) { - Some(p) => { - let p = gil::register_borrowed(self, p); - self.unchecked_downcast(p) - } - None => crate::err::panic_after_error(), - } + FromPyPointer::from_borrowed_ptr(self, ptr) } /// Register borrowed `ffi::PyObject` pointer in release pool. @@ -321,13 +293,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - match NonNull::new(ptr) { - Some(p) => { - let p = gil::register_borrowed(self, p); - self.unchecked_mut_downcast(p) - } - None => crate::err::panic_after_error(), - } + FromPyPointer::from_borrowed_ptr(self, ptr) } /// Register borrowed `ffi::PyObject` pointer in release pool. @@ -337,13 +303,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - match NonNull::new(ptr) { - Some(p) => { - let p = gil::register_borrowed(self, p); - Ok(self.unchecked_downcast(p)) - } - None => Err(PyErr::fetch(self)), - } + FromPyPointer::from_borrowed_ptr_or_err(self, ptr) } /// Register borrowed `ffi::PyObject` pointer in release pool. @@ -353,10 +313,7 @@ impl<'p> Python<'p> { where T: PyTypeInfo, { - NonNull::new(ptr).map(|p| { - let p = gil::register_borrowed(self, p); - self.unchecked_downcast(p) - }) + FromPyPointer::from_borrowed_ptr_or_opt(self, ptr) } #[doc(hidden)] diff --git a/tests/test_pyself.rs b/tests/test_pyself.rs new file mode 100644 index 00000000000..ee2006817fc --- /dev/null +++ b/tests/test_pyself.rs @@ -0,0 +1,109 @@ +//! Test slf: PyRef/PyMutRef(especially, slf.into::) works +use pyo3; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyString}; +use pyo3::PyIterProtocol; +use std::collections::HashMap; + +#[macro_use] +mod common; + +/// Assumes it's a file reader or so. +/// Inspired by https://github.com/jothan/cordoba, thanks. +#[pyclass] +#[derive(Clone)] +struct Reader { + inner: HashMap, +} + +#[pymethods] +impl Reader { + fn get_iter(slf: PyRef, keys: Py) -> PyResult { + Ok(Iter { + reader: slf.into(), + keys, + idx: 0, + }) + } + fn get_iter_and_reset( + mut slf: PyRefMut, + keys: Py, + py: Python, + ) -> PyResult { + let reader = Py::new(py, slf.clone())?; + slf.inner.clear(); + Ok(Iter { + reader, + keys, + idx: 0, + }) + } +} + +#[pyclass] +struct Iter { + reader: Py, + keys: Py, + idx: usize, +} + +#[pyproto] +impl PyIterProtocol for Iter { + fn __iter__(slf: PyRefMut) -> PyResult { + let py = unsafe { Python::assume_gil_acquired() }; + Ok(slf.to_object(py)) + } + fn __next__(mut slf: PyRefMut) -> PyResult> { + let py = unsafe { Python::assume_gil_acquired() }; + match slf.keys.as_ref(py).as_bytes().get(slf.idx) { + Some(&b) => { + let res = slf + .reader + .as_ref(py) + .inner + .get(&b) + .map(|s| PyString::new(py, s).into()); + slf.idx += 1; + Ok(res) + } + None => Ok(None), + } + } +} + +#[test] +fn test_nested_iter() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let reader = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]; + let reader = Reader { + inner: reader.iter().map(|(k, v)| (*k, v.to_string())).collect(), + } + .into_object(py); + py_assert!( + py, + reader, + "list(reader.get_iter(bytes([3, 5, 2]))) == ['c', 'e', 'b']" + ); +} + +#[test] +fn test_nested_iter_reset() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let reader = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]; + let reader = PyRef::new( + py, + Reader { + inner: reader.iter().map(|(k, v)| (*k, v.to_string())).collect(), + }, + ) + .unwrap(); + let obj = reader.into_object(py); + py_assert!( + py, + obj, + "list(obj.get_iter_and_reset(bytes([3, 5, 2]))) == ['c', 'e', 'b']" + ); + assert!(reader.inner.is_empty()); +}