From 8850d5d3848da75263826c3a96e243a889f9b32c Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 3 Jun 2023 15:55:13 +0100 Subject: [PATCH] support ordering magic methods for `#[pyclass]` --- Cargo.toml | 4 + benches/bench_comparisons.rs | 70 ++++++++++ guide/src/class/protocols.md | 19 ++- newsfragments/3203.added.md | 1 + pyo3-macros-backend/src/pyimpl.rs | 57 ++++---- pyo3-macros-backend/src/pymethod.rs | 25 ++++ pytests/requirements-dev.txt | 1 + pytests/src/comparisons.rs | 111 +++++++++++++++ pytests/src/lib.rs | 3 + pytests/tests/test_comparisons.py | 171 ++++++++++++++++++++++++ src/impl_/pyclass.rs | 124 +++++++++++++++++ src/test_hygiene/pymethods.rs | 14 -- tests/ui/invalid_proto_pymethods.rs | 15 +++ tests/ui/invalid_proto_pymethods.stderr | 27 ++-- tests/ui/pyclass_send.stderr | 30 ++--- 15 files changed, 610 insertions(+), 62 deletions(-) create mode 100644 benches/bench_comparisons.rs create mode 100644 newsfragments/3203.added.md create mode 100644 pytests/src/comparisons.rs create mode 100644 pytests/tests/test_comparisons.py diff --git a/Cargo.toml b/Cargo.toml index b0d84aae0d2..e9b04aa3562 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -122,6 +122,10 @@ harness = false name = "bench_call" harness = false +[[bench]] +name = "bench_comparisons" +harness = false + [[bench]] name = "bench_err" harness = false diff --git a/benches/bench_comparisons.rs b/benches/bench_comparisons.rs new file mode 100644 index 00000000000..bfa4ac63fa4 --- /dev/null +++ b/benches/bench_comparisons.rs @@ -0,0 +1,70 @@ +use criterion::{criterion_group, criterion_main, Bencher, Criterion}; + +use pyo3::{prelude::*, pyclass::CompareOp, Python}; + +#[pyclass] +struct OrderedDunderMethods(i64); + +#[pymethods] +impl OrderedDunderMethods { + fn __lt__(&self, other: &Self) -> bool { + self.0 < other.0 + } + + fn __le__(&self, other: &Self) -> bool { + self.0 <= other.0 + } + + fn __eq__(&self, other: &Self) -> bool { + self.0 == other.0 + } + + fn __ne__(&self, other: &Self) -> bool { + self.0 != other.0 + } + + fn __gt__(&self, other: &Self) -> bool { + self.0 > other.0 + } + + fn __ge__(&self, other: &Self) -> bool { + self.0 >= other.0 + } +} + +#[pyclass] +#[derive(PartialEq, Eq, PartialOrd, Ord)] +struct OrderedRichcmp(i64); + +#[pymethods] +impl OrderedRichcmp { + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.cmp(other)) + } +} + +fn bench_ordered_dunder_methods(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + let obj1 = Py::new(py, OrderedDunderMethods(0)).unwrap().into_ref(py); + let obj2 = Py::new(py, OrderedDunderMethods(1)).unwrap().into_ref(py); + + b.iter(|| obj2.gt(obj1).unwrap()); + }); +} + +fn bench_ordered_richcmp(b: &mut Bencher<'_>) { + Python::with_gil(|py| { + let obj1 = Py::new(py, OrderedRichcmp(0)).unwrap().into_ref(py); + let obj2 = Py::new(py, OrderedRichcmp(1)).unwrap().into_ref(py); + + b.iter(|| obj2.gt(obj1).unwrap()); + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("ordered_dunder_methods", bench_ordered_dunder_methods); + c.bench_function("ordered_richcmp", bench_ordered_richcmp); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/guide/src/class/protocols.md b/guide/src/class/protocols.md index eca9f1f2b4b..891635c0654 100644 --- a/guide/src/class/protocols.md +++ b/guide/src/class/protocols.md @@ -58,11 +58,28 @@ given signatures should be interpreted as follows: ``` + - `__lt__(, object) -> object` + - `__le__(, object) -> object` + - `__eq__(, object) -> object` + - `__ne__(, object) -> object` + - `__gt__(, object) -> object` + - `__ge__(, object) -> object` + + The implementations of Python's "rich comparison" operators `<`, `<=`, `==`, `!=`, `>` and `>=` respectively. + + _Note that implementing any of these methods will cause Python not to generate a default `__hash__` implementation, so consider also implementing `__hash__`._ +
+ Return type + The return type will normally be `bool` or `PyResult`, however any Python object can be returned. +
+ - `__richcmp__(, object, pyo3::basic::CompareOp) -> object` - Overloads Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`). + Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method. The `CompareOp` argument indicates the comparison operation being performed. + _This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._ + _Note that implementing `__richcmp__` will cause Python not to generate a default `__hash__` implementation, so consider implementing `__hash__` when implementing `__richcmp__`._
Return type diff --git a/newsfragments/3203.added.md b/newsfragments/3203.added.md new file mode 100644 index 00000000000..58c0a24db64 --- /dev/null +++ b/newsfragments/3203.added.md @@ -0,0 +1 @@ +Support `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__` and `__ge__` in `#[pymethods]` diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index d0a1b6157cf..0615371c488 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -235,41 +235,50 @@ fn add_shared_proto_slots( mut implemented_proto_fragments: HashSet, ) { macro_rules! try_add_shared_slot { - ($first:literal, $second:literal, $slot:ident) => {{ - let first_implemented = implemented_proto_fragments.remove($first); - let second_implemented = implemented_proto_fragments.remove($second); - if first_implemented || second_implemented { + ($slot:ident, $($fragments:literal),*) => {{ + let mut implemented = false; + $(implemented |= implemented_proto_fragments.remove($fragments));*; + if implemented { proto_impls.push(quote! { _pyo3::impl_::pyclass::$slot!(#ty) }) } }}; } try_add_shared_slot!( + generate_pyclass_getattro_slot, "__getattribute__", - "__getattr__", - generate_pyclass_getattro_slot + "__getattr__" ); - try_add_shared_slot!("__setattr__", "__delattr__", generate_pyclass_setattr_slot); - try_add_shared_slot!("__set__", "__delete__", generate_pyclass_setdescr_slot); - try_add_shared_slot!("__setitem__", "__delitem__", generate_pyclass_setitem_slot); - try_add_shared_slot!("__add__", "__radd__", generate_pyclass_add_slot); - try_add_shared_slot!("__sub__", "__rsub__", generate_pyclass_sub_slot); - try_add_shared_slot!("__mul__", "__rmul__", generate_pyclass_mul_slot); - try_add_shared_slot!("__mod__", "__rmod__", generate_pyclass_mod_slot); - try_add_shared_slot!("__divmod__", "__rdivmod__", generate_pyclass_divmod_slot); - try_add_shared_slot!("__lshift__", "__rlshift__", generate_pyclass_lshift_slot); - try_add_shared_slot!("__rshift__", "__rrshift__", generate_pyclass_rshift_slot); - try_add_shared_slot!("__and__", "__rand__", generate_pyclass_and_slot); - try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot); - try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot); - try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot); - try_add_shared_slot!("__truediv__", "__rtruediv__", generate_pyclass_truediv_slot); + try_add_shared_slot!(generate_pyclass_setattr_slot, "__setattr__", "__delattr__"); + try_add_shared_slot!(generate_pyclass_setdescr_slot, "__set__", "__delete__"); + try_add_shared_slot!(generate_pyclass_setitem_slot, "__setitem__", "__delitem__"); + try_add_shared_slot!(generate_pyclass_add_slot, "__add__", "__radd__"); + try_add_shared_slot!(generate_pyclass_sub_slot, "__sub__", "__rsub__"); + try_add_shared_slot!(generate_pyclass_mul_slot, "__mul__", "__rmul__"); + try_add_shared_slot!(generate_pyclass_mod_slot, "__mod__", "__rmod__"); + try_add_shared_slot!(generate_pyclass_divmod_slot, "__divmod__", "__rdivmod__"); + try_add_shared_slot!(generate_pyclass_lshift_slot, "__lshift__", "__rlshift__"); + try_add_shared_slot!(generate_pyclass_rshift_slot, "__rshift__", "__rrshift__"); + try_add_shared_slot!(generate_pyclass_and_slot, "__and__", "__rand__"); + try_add_shared_slot!(generate_pyclass_or_slot, "__or__", "__ror__"); + try_add_shared_slot!(generate_pyclass_xor_slot, "__xor__", "__rxor__"); + try_add_shared_slot!(generate_pyclass_matmul_slot, "__matmul__", "__rmatmul__"); + try_add_shared_slot!(generate_pyclass_truediv_slot, "__truediv__", "__rtruediv__"); try_add_shared_slot!( + generate_pyclass_floordiv_slot, "__floordiv__", - "__rfloordiv__", - generate_pyclass_floordiv_slot + "__rfloordiv__" + ); + try_add_shared_slot!(generate_pyclass_pow_slot, "__pow__", "__rpow__"); + try_add_shared_slot!( + generate_pyclass_richcompare_slot, + "__lt__", + "__le__", + "__eq__", + "__ne__", + "__gt__", + "__ge__" ); - try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot); // if this assertion trips, a slot fragment has been implemented which has not been added in the // list above diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 9689d863a44..10950a82826 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -135,6 +135,12 @@ impl PyMethodKind { "__ror__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__ROR__)), "__pow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__POW__)), "__rpow__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__RPOW__)), + "__lt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LT__)), + "__le__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__LE__)), + "__eq__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__EQ__)), + "__ne__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__NE__)), + "__gt__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GT__)), + "__ge__" => PyMethodKind::Proto(PyMethodProtoKind::SlotFragment(&__GE__)), // Some tricky protocols which don't fit the pattern of the rest "__call__" => PyMethodKind::Proto(PyMethodProtoKind::Call), "__traverse__" => PyMethodKind::Proto(PyMethodProtoKind::Traverse), @@ -1300,6 +1306,25 @@ const __RPOW__: SlotFragmentDef = SlotFragmentDef::new("__rpow__", &[Ty::Object, .extract_error_mode(ExtractErrorMode::NotImplemented) .ret_ty(Ty::Object); +const __LT__: SlotFragmentDef = SlotFragmentDef::new("__lt__", &[Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); +const __LE__: SlotFragmentDef = SlotFragmentDef::new("__le__", &[Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); +const __EQ__: SlotFragmentDef = SlotFragmentDef::new("__eq__", &[Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); +const __NE__: SlotFragmentDef = SlotFragmentDef::new("__ne__", &[Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); +const __GT__: SlotFragmentDef = SlotFragmentDef::new("__gt__", &[Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); +const __GE__: SlotFragmentDef = SlotFragmentDef::new("__ge__", &[Ty::Object]) + .extract_error_mode(ExtractErrorMode::NotImplemented) + .ret_ty(Ty::Object); + fn extract_proto_arguments( py: &syn::Ident, spec: &FnSpec<'_>, diff --git a/pytests/requirements-dev.txt b/pytests/requirements-dev.txt index d1fa05414e9..9901d8b2048 100644 --- a/pytests/requirements-dev.txt +++ b/pytests/requirements-dev.txt @@ -2,3 +2,4 @@ hypothesis>=3.55 pytest>=6.0 pytest-benchmark>=3.4 psutil>=5.6 +typing_extensions>=4.0.0 diff --git a/pytests/src/comparisons.rs b/pytests/src/comparisons.rs new file mode 100644 index 00000000000..d8c2f5a6a52 --- /dev/null +++ b/pytests/src/comparisons.rs @@ -0,0 +1,111 @@ +use pyo3::prelude::*; +use pyo3::{types::PyModule, Python}; + +#[pyclass] +struct Eq(i64); + +#[pymethods] +impl Eq { + #[new] + fn new(value: i64) -> Self { + Self(value) + } + + fn __eq__(&self, other: &Self) -> bool { + self.0 == other.0 + } + + fn __ne__(&self, other: &Self) -> bool { + self.0 != other.0 + } +} + +#[pyclass] +struct EqDefaultNe(i64); + +#[pymethods] +impl EqDefaultNe { + #[new] + fn new(value: i64) -> Self { + Self(value) + } + + fn __eq__(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +#[pyclass] +struct Ordered(i64); + +#[pymethods] +impl Ordered { + #[new] + fn new(value: i64) -> Self { + Self(value) + } + + fn __lt__(&self, other: &Self) -> bool { + self.0 < other.0 + } + + fn __le__(&self, other: &Self) -> bool { + self.0 <= other.0 + } + + fn __eq__(&self, other: &Self) -> bool { + self.0 == other.0 + } + + fn __ne__(&self, other: &Self) -> bool { + self.0 != other.0 + } + + fn __gt__(&self, other: &Self) -> bool { + self.0 > other.0 + } + + fn __ge__(&self, other: &Self) -> bool { + self.0 >= other.0 + } +} + +#[pyclass] +struct OrderedDefaultNe(i64); + +#[pymethods] +impl OrderedDefaultNe { + #[new] + fn new(value: i64) -> Self { + Self(value) + } + + fn __lt__(&self, other: &Self) -> bool { + self.0 < other.0 + } + + fn __le__(&self, other: &Self) -> bool { + self.0 <= other.0 + } + + fn __eq__(&self, other: &Self) -> bool { + self.0 == other.0 + } + + fn __gt__(&self, other: &Self) -> bool { + self.0 > other.0 + } + + fn __ge__(&self, other: &Self) -> bool { + self.0 >= other.0 + } +} + +#[pymodule] +pub fn comparisons(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index d9d6e3f949a..fd96f72526d 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -3,6 +3,7 @@ use pyo3::types::PyDict; use pyo3::wrap_pymodule; pub mod buf_and_str; +pub mod comparisons; pub mod datetime; pub mod deprecated_pyfunctions; pub mod dict_iter; @@ -19,6 +20,7 @@ pub mod subclassing; fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> { #[cfg(not(Py_LIMITED_API))] m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?; + m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?; #[cfg(not(Py_LIMITED_API))] m.add_wrapped(wrap_pymodule!(datetime::datetime))?; m.add_wrapped(wrap_pymodule!( @@ -40,6 +42,7 @@ fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> { let sys = PyModule::import(py, "sys")?; let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?; sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?; + sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?; sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?; sys_modules.set_item( "pyo3_pytests.deprecated_pyfunctions", diff --git a/pytests/tests/test_comparisons.py b/pytests/tests/test_comparisons.py new file mode 100644 index 00000000000..54bb7aafcbd --- /dev/null +++ b/pytests/tests/test_comparisons.py @@ -0,0 +1,171 @@ +from typing import Type, Union + +import pytest +from pyo3_pytests.comparisons import Eq, EqDefaultNe, Ordered, OrderedDefaultNe +from typing_extensions import Self + + +class PyEq: + def __init__(self, x: int) -> None: + self.x = x + + def __eq__(self, other: Self) -> bool: + return self.x == other.x + + def __ne__(self, other: Self) -> bool: + return self.x != other.x + + +@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python")) +def test_eq(ty: Type[Union[Eq, PyEq]]): + a = ty(0) + b = ty(0) + c = ty(1) + + assert a == b + assert a != c + + assert b == a + assert b != c + + with pytest.raises(TypeError): + assert a <= b + + with pytest.raises(TypeError): + assert a >= b + + with pytest.raises(TypeError): + assert a < c + + with pytest.raises(TypeError): + assert c > a + + +class PyEqDefaultNe: + def __init__(self, x: int) -> None: + self.x = x + + def __eq__(self, other: Self) -> bool: + return self.x == other.x + + +@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python")) +def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]): + a = ty(0) + b = ty(0) + c = ty(1) + + assert a == b + assert a != c + + assert b == a + assert b != c + + with pytest.raises(TypeError): + assert a <= b + + with pytest.raises(TypeError): + assert a >= b + + with pytest.raises(TypeError): + assert a < c + + with pytest.raises(TypeError): + assert c > a + + +class PyOrdered: + def __init__(self, x: int) -> None: + self.x = x + + def __lt__(self, other: Self) -> bool: + return self.x < other.x + + def __le__(self, other: Self) -> bool: + return self.x <= other.x + + def __eq__(self, other: Self) -> bool: + return self.x == other.x + + def __ne__(self, other: Self) -> bool: + return self.x != other.x + + def __gt__(self, other: Self) -> bool: + return self.x >= other.x + + def __ge__(self, other: Self) -> bool: + return self.x >= other.x + + +@pytest.mark.parametrize("ty", (Ordered, PyOrdered), ids=("rust", "python")) +def test_ordered(ty: Type[Union[Ordered, PyOrdered]]): + a = ty(0) + b = ty(0) + c = ty(1) + + assert a == b + assert a <= b + assert a >= b + assert a != c + assert a <= c + + assert b == a + assert b <= a + assert b >= a + assert b != c + assert b <= c + + assert c != a + assert c != b + assert c > a + assert c >= a + assert c > b + assert c >= b + + +class PyOrderedDefaultNe: + def __init__(self, x: int) -> None: + self.x = x + + def __lt__(self, other: Self) -> bool: + return self.x < other.x + + def __le__(self, other: Self) -> bool: + return self.x <= other.x + + def __eq__(self, other: Self) -> bool: + return self.x == other.x + + def __gt__(self, other: Self) -> bool: + return self.x >= other.x + + def __ge__(self, other: Self) -> bool: + return self.x >= other.x + + +@pytest.mark.parametrize( + "ty", (OrderedDefaultNe, PyOrderedDefaultNe), ids=("rust", "python") +) +def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]]): + a = ty(0) + b = ty(0) + c = ty(1) + + assert a == b + assert a <= b + assert a >= b + assert a != c + assert a <= c + + assert b == a + assert b <= a + assert b >= a + assert b != c + assert b <= c + + assert c != a + assert c != b + assert c > a + assert c >= a + assert c > b + assert c >= b diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index 312144a69c0..2ba722d49cb 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -757,6 +757,130 @@ macro_rules! generate_pyclass_pow_slot { } pub use generate_pyclass_pow_slot; +slot_fragment_trait! { + PyClass__lt__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __lt__( + self, + _py: Python<'_>, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + } +} + +slot_fragment_trait! { + PyClass__le__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __le__( + self, + _py: Python<'_>, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + } +} + +slot_fragment_trait! { + PyClass__eq__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __eq__( + self, + _py: Python<'_>, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + } +} + +slot_fragment_trait! { + PyClass__ne__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __ne__( + self, + _py: Python<'_>, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + } +} + +slot_fragment_trait! { + PyClass__gt__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __gt__( + self, + _py: Python<'_>, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + } +} + +slot_fragment_trait! { + PyClass__ge__SlotFragment, + + /// # Safety: _slf and _other must be valid non-null Python objects + #[inline] + unsafe fn __ge__( + self, + _py: Python<'_>, + _slf: *mut ffi::PyObject, + _other: *mut ffi::PyObject, + ) -> PyResult<*mut ffi::PyObject> { + Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented())) + } +} + +#[doc(hidden)] +#[macro_export] +macro_rules! generate_pyclass_richcompare_slot { + ($cls:ty) => {{ + impl $cls { + #[allow(non_snake_case)] + unsafe extern "C" fn __pymethod___richcmp____( + slf: *mut $crate::ffi::PyObject, + other: *mut $crate::ffi::PyObject, + op: ::std::os::raw::c_int, + ) -> *mut $crate::ffi::PyObject { + $crate::impl_::trampoline::richcmpfunc(slf, other, op, |py, slf, other, op| { + use $crate::class::basic::CompareOp; + use $crate::impl_::pyclass::*; + let collector = PyClassImplCollector::<$cls>::new(); + match CompareOp::from_raw(op).expect("invalid compareop") { + CompareOp::Lt => collector.__lt__(py, slf, other), + CompareOp::Le => collector.__le__(py, slf, other), + CompareOp::Eq => collector.__eq__(py, slf, other), + CompareOp::Ne => collector.__ne__(py, slf, other), + CompareOp::Gt => collector.__gt__(py, slf, other), + CompareOp::Ge => collector.__ge__(py, slf, other), + } + }) + } + } + $crate::ffi::PyType_Slot { + slot: $crate::ffi::Py_tp_richcompare, + pfunc: <$cls>::__pymethod___richcmp____ as $crate::ffi::richcmpfunc as _, + } + }}; +} +pub use generate_pyclass_richcompare_slot; + /// Implements a freelist. /// /// Do not implement this trait manually. Instead, use `#[pyclass(freelist = N)]` diff --git a/src/test_hygiene/pymethods.rs b/src/test_hygiene/pymethods.rs index ba0fcdff2b8..8e5bce8eefe 100644 --- a/src/test_hygiene/pymethods.rs +++ b/src/test_hygiene/pymethods.rs @@ -398,13 +398,6 @@ impl Dummy { // Dunder methods invented for protocols - fn __richcmp__( - &self, - other: &Self, - op: crate::class::basic::CompareOp, - ) -> crate::PyResult { - ::std::result::Result::Ok(false) - } // PyGcProtocol // Buffer protocol? } @@ -797,13 +790,6 @@ impl Dummy { // Dunder methods invented for protocols - fn __richcmp__( - &self, - other: &Self, - op: crate::class::basic::CompareOp, - ) -> crate::PyResult { - ::std::result::Result::Ok(false) - } // PyGcProtocol // Buffer protocol? } diff --git a/tests/ui/invalid_proto_pymethods.rs b/tests/ui/invalid_proto_pymethods.rs index d410bc80ea6..d370c4fddb5 100644 --- a/tests/ui/invalid_proto_pymethods.rs +++ b/tests/ui/invalid_proto_pymethods.rs @@ -4,6 +4,7 @@ //! so that the function names can describe the edge case to be rejected. use pyo3::prelude::*; +use pyo3::pyclass::CompareOp; #[pyclass] struct MyClass {} @@ -48,4 +49,18 @@ impl MyClass { } } +#[pyclass] +struct EqAndRichcmp; + +#[pymethods] +impl EqAndRichcmp { + fn __eq__(&self, other: &Self) -> bool { + true + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + true + } +} + fn main() {} diff --git a/tests/ui/invalid_proto_pymethods.stderr b/tests/ui/invalid_proto_pymethods.stderr index 4defab917b9..275a6b93c46 100644 --- a/tests/ui/invalid_proto_pymethods.stderr +++ b/tests/ui/invalid_proto_pymethods.stderr @@ -1,23 +1,34 @@ error: Expected 1 arguments, got 0 - --> tests/ui/invalid_proto_pymethods.rs:18:8 + --> tests/ui/invalid_proto_pymethods.rs:19:8 | -18 | fn truediv_expects_one_argument(&self) -> PyResult<()> { +19 | fn truediv_expects_one_argument(&self) -> PyResult<()> { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Expected 1 arguments, got 0 - --> tests/ui/invalid_proto_pymethods.rs:26:8 + --> tests/ui/invalid_proto_pymethods.rs:27:8 | -26 | fn truediv_expects_one_argument_py(&self, _py: Python<'_>) -> PyResult<()> { +27 | fn truediv_expects_one_argument_py(&self, _py: Python<'_>) -> PyResult<()> { | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: `signature` cannot be used with magic method `__bool__` - --> tests/ui/invalid_proto_pymethods.rs:37:31 + --> tests/ui/invalid_proto_pymethods.rs:38:31 | -37 | #[pyo3(name = "__bool__", signature = ())] +38 | #[pyo3(name = "__bool__", signature = ())] | ^^^^^^^^^ error: `text_signature` cannot be used with magic method `__bool__` - --> tests/ui/invalid_proto_pymethods.rs:45:31 + --> tests/ui/invalid_proto_pymethods.rs:46:31 | -45 | #[pyo3(name = "__bool__", text_signature = "")] +46 | #[pyo3(name = "__bool__", text_signature = "")] | ^^^^^^^^^^^^^^ + +error[E0592]: duplicate definitions with name `__pymethod___richcmp____` + --> tests/ui/invalid_proto_pymethods.rs:55:1 + | +55 | #[pymethods] + | ^^^^^^^^^^^^ + | | + | duplicate definitions for `__pymethod___richcmp____` + | other definition for `__pymethod___richcmp____` + | + = note: this error originates in the macro `_pyo3::impl_::pyclass::generate_pyclass_richcompare_slot` which comes from the expansion of the attribute macro `pymethods` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/ui/pyclass_send.stderr b/tests/ui/pyclass_send.stderr index 287430ac078..ace1775dfab 100644 --- a/tests/ui/pyclass_send.stderr +++ b/tests/ui/pyclass_send.stderr @@ -1,18 +1,18 @@ error[E0277]: `Rc` cannot be sent between threads safely - --> tests/ui/pyclass_send.rs:4:1 - | -4 | #[pyclass] - | ^^^^^^^^^^ `Rc` cannot be sent between threads safely - | - = help: within `NotThreadSafe`, the trait `Send` is not implemented for `Rc` + --> tests/ui/pyclass_send.rs:4:1 + | +4 | #[pyclass] + | ^^^^^^^^^^ `Rc` cannot be sent between threads safely + | + = help: within `NotThreadSafe`, the trait `Send` is not implemented for `Rc` note: required because it appears within the type `NotThreadSafe` - --> tests/ui/pyclass_send.rs:5:8 - | -5 | struct NotThreadSafe { - | ^^^^^^^^^^^^^ + --> tests/ui/pyclass_send.rs:5:8 + | +5 | struct NotThreadSafe { + | ^^^^^^^^^^^^^ note: required by a bound in `ThreadCheckerStub` - --> src/impl_/pyclass.rs - | - | pub struct ThreadCheckerStub(PhantomData); - | ^^^^ required by this bound in `ThreadCheckerStub` - = note: this error originates in the attribute macro `pyclass` (in Nightly builds, run with -Z macro-backtrace for more info) + --> src/impl_/pyclass.rs + | + | pub struct ThreadCheckerStub(PhantomData); + | ^^^^ required by this bound in `ThreadCheckerStub` + = note: this error originates in the attribute macro `pyclass` (in Nightly builds, run with -Z macro-backtrace for more info)