Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split the thread-id TLS into 2 variables #44

Merged
merged 2 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@ license = "MIT OR Apache-2.0"
repository = "https://github.com/Amanieu/thread_local-rs"
readme = "README.md"
keywords = ["thread_local", "concurrent", "thread"]
edition = "2018"
edition = "2021"

[features]
# this feature provides performance improvements using nightly features
nightly = []

[badges]
travis-ci = { repository = "Amanieu/thread_local-rs" }

[dependencies]
once_cell = "1.5.2"
# this is required to gate `nightly` related code paths
cfg-if = "1.0.0"

# This is actually a dev-dependency, see https://github.com/rust-lang/cargo/issues/1596
criterion = { version = "0.4.0", optional = true }
Expand Down
17 changes: 10 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@

#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]
#![cfg_attr(feature = "nightly", feature(thread_local))]

mod cached;
mod thread_id;
Expand Down Expand Up @@ -189,8 +190,7 @@ impl<T: Send> ThreadLocal<T> {

/// Returns the element for the current thread, if it exists.
pub fn get(&self) -> Option<&T> {
let thread = thread_id::get();
self.get_inner(thread)
thread_id::try_get().and_then(|thread| self.get_inner(thread))
}

/// Returns the element for the current thread, or creates it if it doesn't
Expand All @@ -212,11 +212,13 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::get();
match self.get_inner(thread) {
Some(x) => Ok(x),
None => Ok(self.insert(thread, create()?)),
let thread = thread_id::try_get();
if let Some(thread) = thread {
if let Some(val) = self.get_inner(thread) {
return Ok(val);
}
}
Ok(self.insert(create()?))
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
Expand All @@ -237,7 +239,8 @@ impl<T: Send> ThreadLocal<T> {
}

#[cold]
fn insert(&self, thread: Thread, data: T) -> &T {
fn insert(&self, data: T) -> &T {
let thread = thread_id::get();
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

Expand Down
104 changes: 87 additions & 17 deletions src/thread_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,95 @@ impl Thread {
}
}

/// Wrapper around `Thread` that allocates and deallocates the ID.
struct ThreadHolder(Thread);
impl ThreadHolder {
fn new() -> ThreadHolder {
ThreadHolder(Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc()))
}
}
impl Drop for ThreadHolder {
fn drop(&mut self) {
THREAD_ID_MANAGER.lock().unwrap().free(self.0.id);
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "nightly")] {
// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
#[thread_local]
static mut THREAD: Option<Thread> = None;
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard;

impl Drop for ThreadGuard {
fn drop(&mut self) {
// SAFETY: this is safe because we know that we (the current thread)
// are the only one who can be accessing our `THREAD` and thus
// it's safe for us to access and drop it.
if let Some(thread) = unsafe { THREAD.take() } {
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}
}

/// Attempts to get the current thread if `get` has previously been
/// called.
#[inline]
pub(crate) fn try_get() -> Option<Thread> {
unsafe {
THREAD
}
}

thread_local!(static THREAD_HOLDER: ThreadHolder = ThreadHolder::new());
/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
if let Some(thread) = unsafe { THREAD } {
thread
} else {
let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
unsafe {
THREAD = Some(new);
}
THREAD_GUARD.with(|_| {});
new
}
}
} else {
use std::cell::Cell;

// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
thread_local! { static THREAD: Cell<Option<Thread>> = const { Cell::new(None) }; }
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard }; }

// Guard to ensure the thread ID is released on thread exit.
struct ThreadGuard;

impl Drop for ThreadGuard {
fn drop(&mut self) {
let thread = THREAD.with(|thread| thread.get()).unwrap();
THREAD_ID_MANAGER.lock().unwrap().free(thread.id);
}
}

/// Get the current thread.
#[inline]
pub(crate) fn get() -> Thread {
THREAD_HOLDER.with(|holder| holder.0)
/// Attempts to get the current thread if `get` has previously been
/// called.
#[inline]
pub(crate) fn try_get() -> Option<Thread> {
THREAD.with(|thread| thread.get())
}

/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
THREAD.with(|thread| {
if let Some(thread) = thread.get() {
thread
} else {
let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
thread.set(Some(new));
THREAD_GUARD.with(|_| {});
new
}
})
}
}
}

#[test]
Expand Down