Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow #[new] to return existing instances #3287

Merged
merged 1 commit into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ impl Nonzero {
}
```

If you want to return an existing object (for example, because your `new`
method caches the values it returns), `new` can return `pyo3::Py<Self>`.

As you can see, the Rust method name is not important here; this way you can
still, use `new()` for a Rust-level constructor.

Expand Down
1 change: 1 addition & 0 deletions newsfragments/3287.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`#[new]` methods may now return `Py<Self>` in order to return existing instances
alex marked this conversation as resolved.
Show resolved Hide resolved
31 changes: 24 additions & 7 deletions src/pyclass_init.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Contains initialization utilities for `#[pyclass]`.
use crate::callback::IntoPyCallbackOutput;
use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef};
use crate::{ffi, PyCell, PyClass, PyErr, PyResult, Python};
use crate::{ffi, IntoPyPointer, Py, PyCell, PyClass, PyErr, PyResult, Python};
use crate::{
ffi::PyTypeObject,
pycell::{
Expand Down Expand Up @@ -134,17 +134,22 @@ impl<T: PyTypeInfo> PyObjectInit<T> for PyNativeTypeInitializer<T> {
/// );
/// });
/// ```
pub struct PyClassInitializer<T: PyClass> {
init: T,
super_init: <T::BaseType as PyClassBaseType>::Initializer,
pub struct PyClassInitializer<T: PyClass>(PyClassInitializerImpl<T>);

enum PyClassInitializerImpl<T: PyClass> {
Existing(Py<T>),
New {
init: T,
super_init: <T::BaseType as PyClassBaseType>::Initializer,
},
}

impl<T: PyClass> PyClassInitializer<T> {
/// Constructs a new initializer from value `T` and base class' initializer.
///
/// It is recommended to use `add_subclass` instead of this method for most usage.
pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self {
Self { init, super_init }
Self(PyClassInitializerImpl::New { init, super_init })
}

/// Constructs a new initializer from an initializer for the base class.
Expand Down Expand Up @@ -242,13 +247,18 @@ impl<T: PyClass> PyObjectInit<T> for PyClassInitializer<T> {
contents: MaybeUninit<PyCellContents<T>>,
}

let obj = self.super_init.into_new_object(py, subtype)?;
let (init, super_init) = match self.0 {
PyClassInitializerImpl::Existing(value) => return Ok(value.into_ptr()),
PyClassInitializerImpl::New { init, super_init } => (init, super_init),
};

let obj = super_init.into_new_object(py, subtype)?;

let cell: *mut PartiallyInitializedPyCell<T> = obj as _;
std::ptr::write(
(*cell).contents.as_mut_ptr(),
PyCellContents {
value: ManuallyDrop::new(UnsafeCell::new(self.init)),
value: ManuallyDrop::new(UnsafeCell::new(init)),
borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(),
thread_checker: T::ThreadChecker::new(),
dict: T::Dict::INIT,
Expand Down Expand Up @@ -284,6 +294,13 @@ where
}
}

impl<T: PyClass> From<Py<T>> for PyClassInitializer<T> {
#[inline]
fn from(value: Py<T>) -> PyClassInitializer<T> {
PyClassInitializer(PyClassInitializerImpl::Existing(value))
}
}

// Implementation used by proc macros to allow anything convertible to PyClassInitializer<T> to be
// the return value of pyclass #[new] method (optionally wrapped in `Result<U, E>`).
impl<T, U> IntoPyCallbackOutput<PyClassInitializer<T>> for U
Expand Down
60 changes: 60 additions & 0 deletions tests/test_class_new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::IntoPyDict;

#[pyclass]
Expand Down Expand Up @@ -204,3 +205,62 @@ fn new_with_custom_error() {
assert_eq!(err.to_string(), "ValueError: custom error");
});
}

#[pyclass]
struct NewExisting {
#[pyo3(get)]
num: usize,
}

#[pymethods]
impl NewExisting {
#[new]
fn new(py: pyo3::Python<'_>, val: usize) -> pyo3::Py<NewExisting> {
static PRE_BUILT: GILOnceCell<[pyo3::Py<NewExisting>; 2]> = GILOnceCell::new();
let existing = PRE_BUILT.get_or_init(py, || {
[
pyo3::PyCell::new(py, NewExisting { num: 0 })
.unwrap()
.into(),
pyo3::PyCell::new(py, NewExisting { num: 1 })
.unwrap()
.into(),
]
});

if val < existing.len() {
return existing[val].clone_ref(py);
}

pyo3::PyCell::new(py, NewExisting { num: val })
.unwrap()
.into()
}
}

#[test]
fn test_new_existing() {
Python::with_gil(|py| {
let typeobj = py.get_type::<NewExisting>();

let obj1 = typeobj.call1((0,)).unwrap();
let obj2 = typeobj.call1((0,)).unwrap();
let obj3 = typeobj.call1((1,)).unwrap();
let obj4 = typeobj.call1((1,)).unwrap();
let obj5 = typeobj.call1((2,)).unwrap();
let obj6 = typeobj.call1((2,)).unwrap();

assert!(obj1.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
assert!(obj2.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
assert!(obj3.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
assert!(obj4.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
assert!(obj5.getattr("num").unwrap().extract::<u32>().unwrap() == 2);
assert!(obj6.getattr("num").unwrap().extract::<u32>().unwrap() == 2);

assert!(obj1.is(obj2));
assert!(obj3.is(obj4));
assert!(!obj1.is(obj3));
assert!(!obj1.is(obj5));
assert!(!obj5.is(obj6));
});
}