diff --git a/lock_api/src/remutex.rs b/lock_api/src/remutex.rs index a2f81852..a7146b1a 100644 --- a/lock_api/src/remutex.rs +++ b/lock_api/src/remutex.rs @@ -183,8 +183,10 @@ impl RawReentrantMutex { if self.lock_count.get() == 1 { let id = self.owner.load(Ordering::Relaxed); self.owner.store(0, Ordering::Relaxed); + self.lock_count.set(0); self.mutex.bump(); self.owner.store(id, Ordering::Relaxed); + self.lock_count.set(1); } } } diff --git a/src/remutex.rs b/src/remutex.rs index 10379230..a925369e 100644 --- a/src/remutex.rs +++ b/src/remutex.rs @@ -71,9 +71,11 @@ pub type MappedReentrantMutexGuard<'a, T> = #[cfg(test)] mod tests { use crate::ReentrantMutex; + use crate::ReentrantMutexGuard; use std::cell::RefCell; use std::sync::Arc; use std::thread; + use std::sync::mpsc::channel; #[cfg(feature = "serde")] use bincode::{deserialize, serialize}; @@ -134,6 +136,26 @@ mod tests { assert_eq!(format!("{:?}", mutex), "ReentrantMutex { data: [0, 10] }"); } + #[test] + fn test_reentrant_mutex_bump() { + let mutex = Arc::new(ReentrantMutex::new(())); + let mutex2 = mutex.clone(); + + let mut guard = mutex.lock(); + + let (tx, rx) = channel(); + + thread::spawn(move || { + let _guard = mutex2.lock(); + tx.send(()).unwrap(); + }); + + // Call `bump()` repeatedly until the spawned thread starts up and requests the lock + while rx.try_recv().is_err() { + ReentrantMutexGuard::bump(&mut guard); + } + } + #[cfg(feature = "serde")] #[test] fn test_serde() {