Skip to content

Commit

Permalink
Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods
Browse files Browse the repository at this point in the history
Fixes #1028
  • Loading branch information
qwwdfsad committed Mar 14, 2019
1 parent 91c49e1 commit 1985155
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ public final class kotlinx/coroutines/ThreadContextElement$DefaultImpls {
public final class kotlinx/coroutines/ThreadContextElementKt {
public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
public static final fun ensurePresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun isPresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class kotlinx/coroutines/ThreadPoolDispatcherKt {
Expand Down
7 changes: 6 additions & 1 deletion docs/coroutine-context-and-dispatchers.md
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ fun main() = runBlocking<Unit> {
threadLocal.set("main")
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
yield()
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
}
Expand Down Expand Up @@ -664,6 +664,10 @@ Post-main, current thread: Thread[main @coroutine#1,5,main], thread local value:
<!--- TEST FLEXIBLE_THREAD -->
Note how easily one may forget the corresponding context element and then still safely access thread local.
To avoid such situations, it is recommended to use [ensurePresent] method
and fail-fast on improper usages.
`ThreadLocal` has first-class support and can be used with any primitive `kotlinx.coroutines` provides.
It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller
(as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension.
Expand Down Expand Up @@ -701,5 +705,6 @@ that should be implemented.
[MainScope()]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-main-scope.html
[Dispatchers.Main]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-dispatchers/-main.html
[asContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/as-context-element.html
[ensurePresent]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/ensure-present.html
[ThreadContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-thread-context-element/index.html
<!--- END -->
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import org.hamcrest.core.*
import org.junit.*
import org.junit.Assert.*
import org.junit.Test
import java.io.*
import java.util.concurrent.*
import kotlin.test.assertFailsWith

class ListenableFutureTest : TestBase() {
@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import java.util.function.*
import kotlin.concurrent.*
import kotlin.coroutines.*
import kotlin.reflect.*
import kotlin.test.assertFailsWith

class FutureTest : TestBase() {
@Before
Expand Down
36 changes: 36 additions & 0 deletions kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,39 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
*/
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
ThreadLocalElement(value, this)

/**
* Return `true` when current thread local is present in the coroutine context, `false` otherwise.
* Thread local can be present in the context only if it was added via [asContextElement] to the context.
*
* Example of usage:
* ```
* suspend fun processRequest() {
* if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
* // Do some heavy-weight tracing
* }
* // Process request regularly
* }
* ```
*/
public suspend fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] != null

/**
* Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
* It is a good practice to validate that thread local is present in the context, especially in large code-bases,
* to avoid stale thread-local values and to have a strict invariants.
*
* E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
* ```
* public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
* ensurePresent()
* return get()
* }
*
* // Usage
* withContext(...) {
* val value = threadLocal.getSafely() // Fail-fast in case of improper context
* }
* ```
*/
public suspend fun ThreadLocal<*>.ensurePresent(): Unit = check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" }
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
}

// top-level data class for a nicer out-of-the-box toString representation and class name
private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>

internal class ThreadLocalElement<T>(
private val value: T,
Expand Down
6 changes: 6 additions & 0 deletions kotlinx-coroutines-core/jvm/test/TestBase.kt
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,10 @@ public actual open class TestBase actual constructor() {
if (exCount < unhandled.size)
error("Too few unhandled exceptions $exCount, expected ${unhandled.size}")
}

protected inline fun <reified T: Throwable> assertFailsWith(block: () -> Unit): T {
val result = runCatching(block)
assertTrue(result.exceptionOrNull() is T, "Expected ${T::class}, but had $result")
return result.exceptionOrNull()!! as T
}
}
18 changes: 18 additions & 0 deletions kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package kotlinx.coroutines

import org.junit.*
import org.junit.Test
import java.lang.IllegalStateException
import kotlin.test.*

@Suppress("RedundantAsync")
Expand All @@ -22,25 +23,33 @@ class ThreadLocalTest : TestBase() {
@Test
fun testThreadLocal() = runTest {
assertNull(stringThreadLocal.get())
assertFalse(stringThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
assertEquals("value", stringThreadLocal.get())
assertTrue(stringThreadLocal.isPresent())
withContext(executor) {
assertTrue(stringThreadLocal.isPresent())
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
assertEquals("value", stringThreadLocal.get())
}
assertTrue(stringThreadLocal.isPresent())
assertEquals("value", stringThreadLocal.get())
}

assertNull(stringThreadLocal.get())
deferred.await()
assertNull(stringThreadLocal.get())
assertFalse(stringThreadLocal.isPresent())
}

@Test
fun testThreadLocalInitialValue() = runTest {
intThreadLocal.set(42)
assertFalse(intThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
assertEquals(239, intThreadLocal.get())
withContext(executor) {
intThreadLocal.ensurePresent()
assertEquals(239, intThreadLocal.get())
}
assertEquals(239, intThreadLocal.get())
Expand All @@ -63,6 +72,8 @@ class ThreadLocalTest : TestBase() {
withContext(executor) {
assertEquals(239, intThreadLocal.get())
assertEquals("pew", stringThreadLocal.get())
intThreadLocal.ensurePresent()
stringThreadLocal.ensurePresent()
}

assertEquals(239, intThreadLocal.get())
Expand Down Expand Up @@ -129,6 +140,7 @@ class ThreadLocalTest : TestBase() {
}

deferred.await()
assertFalse(stringThreadLocal.isPresent())
assertEquals("main", stringThreadLocal.get())
}

Expand Down Expand Up @@ -212,4 +224,10 @@ class ThreadLocalTest : TestBase() {
assertNotSame(mainThread, Thread.currentThread())
}.await()
}

@Test
fun testMissingThreadLocal() = runTest {
assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fun main() = runBlocking<Unit> {
threadLocal.set("main")
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
yield()
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
}
Expand Down

0 comments on commit 1985155

Please sign in to comment.