diff --git a/libdd-otel-thread-ctx/src/lib.rs b/libdd-otel-thread-ctx/src/lib.rs index 68306b7552..6e3576faa1 100644 --- a/libdd-otel-thread-ctx/src/lib.rs +++ b/libdd-otel-thread-ctx/src/lib.rs @@ -73,23 +73,24 @@ pub mod linux { sync::atomic::{compiler_fence, AtomicPtr, AtomicU8, Ordering}, }; - extern "C" { - /// Return the address of the current thread's `otel_thread_ctx_v1` local. - /// - /// **CAUTION**: do not use this directly, always go through [get_tls_slot] to read and - /// write it atomically. - fn libdd_get_otel_thread_ctx_v1() -> *mut *mut c_void; - } - - /// Return an atomic view of the TLS slot. The address calculation requires a call to a C shim - /// in order to use the TLSDESC dialect from Rust. The returned address is stable (per thread), - /// so the resulting atomic should be reused whenever possible, to reduce the number of calls - /// to this function. + /// Run `f` with an atomic view of the current thread's TLS slot. + /// + /// The address calculation requires a call to a C shim in order to use the TLSDESC dialect + /// from Rust. The returned address is stable (per thread), so callers should try to do as + /// much work as possible inside a single call to reduce the number of C-shim round-trips. /// /// The slot is read by an async signal handler. Atomic operations should in general use /// [Ordering::Relaxed], but modifications to the record might need additional compiler-only /// fences (see [ThreadContext::update] for an example). - fn get_tls_slot<'a>() -> &'a AtomicPtr { + fn with_tls_slot(f: F) -> R + where + F: FnOnce(&AtomicPtr) -> R, + { + extern "C" { + /// Return the address of the current thread's `otel_thread_ctx_v1` local. + fn libdd_get_otel_thread_ctx_v1() -> *mut *mut c_void; + } + const { assert!( mem::align_of::>() @@ -98,13 +99,12 @@ pub mod linux { } // Safety: the const assertion above ensures the alignment is correct. The TLS slot is - // valid for writes during the lifetime of the program. - // - // We forbid direct usage of `libdd_get_otel_thread_ctx_v1`, which guarantees - // that there's never conflicting non-atomic accesses to the TLS slot. - unsafe { + // valid for the lifetime of the current thread. The `extern "C"` declaration is scoped + // to this function, guaranteeing that all accesses go through the `AtomicPtr` wrapper. + let slot = unsafe { AtomicPtr::from_ptr(libdd_get_otel_thread_ctx_v1().cast::<*mut ThreadContextRecord>()) - } + }; + f(slot) } // We maintain the convention in libdatadog that the `local_root_span_id` attribute key is @@ -395,7 +395,7 @@ pub mod linux { // // We still need a release fence to avoid exposing uninitialized memory to the handler. compiler_fence(Ordering::Release); - Self::swap(get_tls_slot(), self.into_ptr().as_ptr()) + with_tls_slot(|slot| Self::swap(slot, self.into_ptr().as_ptr())) } /// Update the currently attached record in-place. Sets `valid = 0` before the update and @@ -411,36 +411,36 @@ pub mod linux { local_root_span_id: [u8; 8], attrs: &[(u8, &str)], ) { - let slot = get_tls_slot(); - - if let Some(current) = unsafe { slot.load(Ordering::Relaxed).as_mut() } { - current.valid.store(0, Ordering::Relaxed); - compiler_fence(Ordering::SeqCst); - - current.trace_id = trace_id; - current.span_id = span_id; - current.set_attrs(local_root_span_id, attrs); - - compiler_fence(Ordering::SeqCst); - current.valid.store(1, Ordering::Relaxed); - } else { - // No need for `AcqRel`, see [^tls-slot-ordering]. - compiler_fence(Ordering::Release); - // `ThreadContext::new` already initialises `valid = 1`. - let _ = Self::swap( - slot, - ThreadContext::new(trace_id, span_id, local_root_span_id, attrs) - .into_ptr() - .as_ptr(), - ); - } + with_tls_slot(|slot| { + if let Some(current) = unsafe { slot.load(Ordering::Relaxed).as_mut() } { + current.valid.store(0, Ordering::Relaxed); + compiler_fence(Ordering::SeqCst); + + current.trace_id = trace_id; + current.span_id = span_id; + current.set_attrs(local_root_span_id, attrs); + + compiler_fence(Ordering::SeqCst); + current.valid.store(1, Ordering::Relaxed); + } else { + // No need for `AcqRel`, see [^tls-slot-ordering]. + compiler_fence(Ordering::Release); + // `ThreadContext::new` already initialises `valid = 1`. + let _ = Self::swap( + slot, + ThreadContext::new(trace_id, span_id, local_root_span_id, attrs) + .into_ptr() + .as_ptr(), + ); + } + }) } /// Detach the current record from the TLS slot. Writes null to the slot and returns the /// detached record. pub fn detach() -> Option { // We don't need any fence here, see [^tls-slot-ordering]. - Self::swap(get_tls_slot(), ptr::null_mut()) + with_tls_slot(|slot| Self::swap(slot, ptr::null_mut())) } } @@ -463,7 +463,7 @@ pub mod linux { /// Read the TLS pointer for the current thread (the value stored in the TLS slot, not the /// address of the slot itself). fn read_tls_context_ptr() -> *const ThreadContextRecord { - super::get_tls_slot().load(Ordering::Relaxed) + super::with_tls_slot(|slot| slot.load(Ordering::Relaxed)) } #[test]