Skip to content

Commit

Permalink
runtime: use Arc::increment_strong_count instead of mem::forget
Browse files Browse the repository at this point in the history
  • Loading branch information
Darksonn committed Jul 16, 2023
1 parent 05feb2b commit d2a1d5c
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 28 deletions.
18 changes: 5 additions & 13 deletions tokio/src/runtime/park.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ impl UnparkThread {
use crate::loom::thread::AccessError;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
use std::rc::Rc;
use std::task::{RawWaker, RawWakerVTable, Waker};

Expand Down Expand Up @@ -317,16 +316,12 @@ unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
}

unsafe fn clone(raw: *const ()) -> RawWaker {
let unparker = Inner::from_raw(raw);

// Increment the ref count
mem::forget(unparker.clone());

unparker_to_raw_waker(unparker)
Arc::increment_strong_count(raw as *const Inner);
unparker_to_raw_waker(Inner::from_raw(raw))
}

unsafe fn drop_waker(raw: *const ()) {
let _ = Inner::from_raw(raw);
drop(Inner::from_raw(raw));
}

unsafe fn wake(raw: *const ()) {
Expand All @@ -335,11 +330,8 @@ unsafe fn wake(raw: *const ()) {
}

unsafe fn wake_by_ref(raw: *const ()) {
let unparker = Inner::from_raw(raw);
unparker.unpark();

// We don't actually own a reference to the unparker
mem::forget(unparker);
let raw = raw as *const Inner;
(*raw).unpark();
}

#[cfg(loom)]
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/runtime/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ cfg_loom! {

// Make sure debug assertions are enabled
#[cfg(not(debug_assertions))]
compiler_error!("these tests require debug assertions to be enabled");
compile_error!("these tests require debug assertions to be enabled");
}

cfg_not_loom! {
Expand Down
9 changes: 4 additions & 5 deletions tokio/src/sync/tests/notify.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::sync::Notify;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::sync::Arc;
use std::task::{Context, RawWaker, RawWakerVTable, Waker};

Expand All @@ -12,16 +11,16 @@ fn notify_clones_waker_before_lock() {
const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w);

unsafe fn clone_w(data: *const ()) -> RawWaker {
let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify));
let ptr = data as *const Notify;
Arc::<Notify>::increment_strong_count(ptr);
// Or some other arbitrary code that shouldn't be executed while the
// Notify wait list is locked.
arc.notify_one();
let _arc_clone: ManuallyDrop<_> = arc.clone();
(*ptr).notify_one();
RawWaker::new(data, VTABLE)
}

unsafe fn drop_w(data: *const ()) {
let _ = Arc::<Notify>::from_raw(data as *const Notify);
drop(Arc::<Notify>::from_raw(data as *const Notify));
}

unsafe fn wake(_data: *const ()) {
Expand Down
10 changes: 1 addition & 9 deletions tokio/src/util/wake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,8 @@ fn waker_vtable<W: Wake>() -> &'static RawWakerVTable {
)
}

unsafe fn inc_ref_count<T: Wake>(data: *const ()) {
// Retain Arc, but don't touch refcount by wrapping in ManuallyDrop
let arc = ManuallyDrop::new(Arc::<T>::from_raw(data as *const T));

// Now increase refcount, but don't drop new refcount either
let _arc_clone: ManuallyDrop<_> = arc.clone();
}

unsafe fn clone_arc_raw<T: Wake>(data: *const ()) -> RawWaker {
inc_ref_count::<T>(data);
Arc::<T>::increment_strong_count(data as *const T);
RawWaker::new(data, waker_vtable::<T>())
}

Expand Down

0 comments on commit d2a1d5c

Please sign in to comment.