From 8ef1d1ea98509b0982ea8aab1b4fafe566be7495 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Fri, 22 Dec 2023 12:01:40 +0100 Subject: [PATCH] Turn calls of __traverse__ into no-ops for unsendable pyclass if on the wrong thread Adds a "threadsafe" variant of `PyCell::try_borrow` which will fail instead of panicking if called on the wrong thread and use it in `call_traverse` to turn GC traversals of unsendable pyclasses into no-ops if on the wrong thread. This can imply leaking the underlying resource if the originator thread has already exited so that the GC will never run there again, but it does avoid hard aborts as we cannot raise an exception from within `call_traverse`. --- newsfragments/3689.changed.md | 1 + src/impl_/pyclass.rs | 11 +++++++++++ src/impl_/pymethods.rs | 2 +- src/pycell.rs | 18 ++++++++++++++++++ 4 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 newsfragments/3689.changed.md diff --git a/newsfragments/3689.changed.md b/newsfragments/3689.changed.md new file mode 100644 index 00000000000..d9eca98def4 --- /dev/null +++ b/newsfragments/3689.changed.md @@ -0,0 +1 @@ +Calls to `__traverse__` become no-ops for unsendable pyclasses if on the wrong thread, thereby avoid hard aborts at the cost of potential leakage. diff --git a/src/impl_/pyclass.rs b/src/impl_/pyclass.rs index 3941dfcb3e7..5ee67dc998d 100644 --- a/src/impl_/pyclass.rs +++ b/src/impl_/pyclass.rs @@ -1013,6 +1013,7 @@ impl PyClassNewTextSignature for &'_ PyClassImplCollector { #[doc(hidden)] pub trait PyClassThreadChecker: Sized { fn ensure(&self); + fn check(&self) -> bool; fn can_drop(&self, py: Python<'_>) -> bool; fn new() -> Self; private_decl! {} @@ -1028,6 +1029,9 @@ pub struct SendablePyClass(PhantomData); impl PyClassThreadChecker for SendablePyClass { fn ensure(&self) {} + fn check(&self) -> bool { + true + } fn can_drop(&self, _py: Python<'_>) -> bool { true } @@ -1053,6 +1057,10 @@ impl ThreadCheckerImpl { ); } + fn check(&self) -> bool { + thread::current().id() == self.0 + } + fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool { if thread::current().id() != self.0 { PyRuntimeError::new_err(format!( @@ -1071,6 +1079,9 @@ impl PyClassThreadChecker for ThreadCheckerImpl { fn ensure(&self) { self.ensure(std::any::type_name::()); } + fn check(&self) -> bool { + self.check() + } fn can_drop(&self, py: Python<'_>) -> bool { self.can_drop(py, std::any::type_name::()) } diff --git a/src/impl_/pymethods.rs b/src/impl_/pymethods.rs index f2d816bba8d..e403aa23c79 100644 --- a/src/impl_/pymethods.rs +++ b/src/impl_/pymethods.rs @@ -269,7 +269,7 @@ where let py = Python::assume_gil_acquired(); let slf = py.from_borrowed_ptr::>(slf); - let borrow = slf.try_borrow(); + let borrow = slf.try_borrow_threadsafe(); let visit = PyVisit::from_raw(visit, arg, py); let retval = if let Ok(borrow) = borrow { diff --git a/src/pycell.rs b/src/pycell.rs index 3bc80a7eb07..bde95ad8313 100644 --- a/src/pycell.rs +++ b/src/pycell.rs @@ -351,6 +351,14 @@ impl PyCell { .map(|_| PyRef { inner: self }) } + /// Variant of [`try_borrow`][Self::try_borrow] which fails instead of panicking if called from the wrong thread + pub(crate) fn try_borrow_threadsafe(&self) -> Result, PyBorrowError> { + self.check_threadsafe()?; + self.borrow_checker() + .try_borrow() + .map(|_| PyRef { inner: self }) + } + /// Mutably borrows the value `T`, returning an error if the value is currently borrowed. /// This borrow lasts as long as the returned `PyRefMut` exists. /// @@ -975,6 +983,7 @@ impl From for PyErr { #[doc(hidden)] pub trait PyCellLayout: PyLayout { fn ensure_threadsafe(&self); + fn check_threadsafe(&self) -> Result<(), PyBorrowError>; /// Implementation of tp_dealloc. /// # Safety /// - slf must be a valid pointer to an instance of a T or a subclass. @@ -988,6 +997,9 @@ where T: PyTypeInfo, { fn ensure_threadsafe(&self) {} + fn check_threadsafe(&self) -> Result<(), PyBorrowError> { + Ok(()) + } unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) { let type_obj = T::type_object_raw(py); // For `#[pyclass]` types which inherit from PyAny, we can just call tp_free @@ -1025,6 +1037,12 @@ where self.contents.thread_checker.ensure(); self.ob_base.ensure_threadsafe(); } + fn check_threadsafe(&self) -> Result<(), PyBorrowError> { + if !self.contents.thread_checker.check() { + return Err(PyBorrowError { _private: () }); + } + self.ob_base.check_threadsafe() + } unsafe fn tp_dealloc(py: Python<'_>, slf: *mut ffi::PyObject) { // Safety: Python only calls tp_dealloc when no references to the object remain. let cell = &mut *(slf as *mut PyCell);