Skip to content

Commit

Permalink
Split the thread-id TLS into 2 variables
Browse files Browse the repository at this point in the history
This allows the fast path to avoid a branch which checks if the TLS
destructor for the thread ID has been registered.
  • Loading branch information
Amanieu committed Dec 12, 2022
1 parent 3472eb1 commit 6631b73
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
16 changes: 9 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ impl<T: Send> ThreadLocal<T> {

/// Returns the element for the current thread, if it exists.
pub fn get(&self) -> Option<&T> {
let thread = thread_id::get();
self.get_inner(thread)
thread_id::try_get().and_then(|thread| self.get_inner(thread))
}

/// Returns the element for the current thread, or creates it if it doesn't
Expand All @@ -212,11 +211,13 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::get();
match self.get_inner(thread) {
Some(x) => Ok(x),
None => Ok(self.insert(thread, create()?)),
let thread = thread_id::try_get();
if let Some(thread) = thread {
if let Some(val) = self.get_inner(thread) {
return Ok(val);
}
}
Ok(self.insert(create()?))
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
Expand All @@ -237,7 +238,8 @@ impl<T: Send> ThreadLocal<T> {
}

#[cold]
fn insert(&self, thread: Thread, data: T) -> &T {
fn insert(&self, data: T) -> &T {
let thread = thread_id::get();
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

Expand Down
44 changes: 32 additions & 12 deletions src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

use crate::POINTER_WIDTH;
use once_cell::sync::Lazy;
use std::cell::Cell;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Mutex;
Expand Down Expand Up @@ -73,25 +74,44 @@ impl Thread {
}
}

/// Wrapper around `Thread` that allocates and deallocates the ID.
struct ThreadHolder(Thread);
impl ThreadHolder {
fn new() -> ThreadHolder {
ThreadHolder(Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc()))
}
}
impl Drop for ThreadHolder {
// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
thread_local! { static THREAD: Cell<Option<Thread>> = const { Cell::new(None) }; }
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard;

impl Drop for ThreadGuard {
fn drop(&mut self) {
THREAD_ID_MANAGER.lock().unwrap().free(self.0.id);
let thread = THREAD.with(|thread| thread.get()).unwrap();
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}

thread_local!(static THREAD_HOLDER: ThreadHolder = ThreadHolder::new());
/// Attempts to get the current thread if `get` has previously been
/// called.
#[inline]
pub(crate) fn try_get() -> Option<Thread> {
THREAD.with(|thread| thread.get())
}

/// Get the current thread.
/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
THREAD_HOLDER.with(|holder| holder.0)
THREAD.with(|thread| {
if let Some(thread) = thread.get() {
thread
} else {
debug_assert!(thread.get().is_none());
let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
thread.set(Some(new));
THREAD_GUARD.with(|_| {});
new
}
})
}

#[test]
Expand Down

0 comments on commit 6631b73

Please sign in to comment.