Skip to content

Commit

Permalink
Properly nest ThreadContextElement
Browse files Browse the repository at this point in the history
    * Restore the context in the reverse order of update, so they are properly nested into each other
    * Also, do a minor cleanup

Fixes #2195
  • Loading branch information
qwwdfsad committed Feb 2, 2021
1 parent 727c38f commit 3fa8ee6
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 21 deletions.
4 changes: 2 additions & 2 deletions kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines
Expand Down Expand Up @@ -134,7 +134,7 @@ internal abstract class DispatchedTask<in T>(
* Fatal exception handling can be intercepted with [CoroutineExceptionHandler] element in the context of
* a failed coroutine, but such exceptions should be reported anyway.
*/
internal fun handleFatalException(exception: Throwable?, finallyException: Throwable?) {
public fun handleFatalException(exception: Throwable?, finallyException: Throwable?) {
if (exception === null && finallyException === null) return
if (exception !== null && finallyException !== null) {
exception.addSuppressedThrowable(finallyException)
Expand Down
36 changes: 17 additions & 19 deletions kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal
Expand All @@ -11,13 +11,22 @@ import kotlin.coroutines.*
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")

// Used when there are >= 2 active elements in the context
private class ThreadState(val context: CoroutineContext, n: Int) {
private var a = arrayOfNulls<Any>(n)
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
private val values = arrayOfNulls<Any>(n)
private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
private var i = 0

fun append(value: Any?) { a[i++] = value }
fun take() = a[i++]
fun start() { i = 0 }
fun append(element: ThreadContextElement<*>, value: Any?) {
values[i] = value
elements[i++] = element as ThreadContextElement<Any?>
}

fun restore(context: CoroutineContext) {
for (i in elements.indices.reversed()) {
elements[i]?.restoreThreadContext(context, values[i])
}
}
}

// Counts ThreadContextElements in the context
Expand All @@ -42,17 +51,7 @@ private val findOne =
private val updateState =
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
if (element is ThreadContextElement<*>) {
state.append(element.updateThreadContext(state.context))
}
return state
}

// Restores state for all ThreadContextElements in the context from the given ThreadState
private val restoreState =
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
@Suppress("UNCHECKED_CAST")
if (element is ThreadContextElement<*>) {
(element as ThreadContextElement<Any?>).restoreThreadContext(state.context, state.take())
state.append(element, element.updateThreadContext(state.context))
}
return state
}
Expand Down Expand Up @@ -86,8 +85,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
oldState is ThreadState -> {
// slow path with multiple stored ThreadContextElements
oldState.start()
context.fold(oldState, restoreState)
oldState.restore(context)
}
else -> {
// fast path for one ThreadContextElement, but need to find it
Expand Down
65 changes: 65 additions & 0 deletions kotlinx-coroutines-core/jvm/test/ThreadContextOrderTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlinx.coroutines.internal.*
import org.junit.Test
import kotlin.coroutines.*
import kotlin.test.*

class ThreadContextOrderTest : TestBase() {
/*
* The test verifies that two thread context elements are correctly nested:
* The restoration order is the reverse of update order.
*/
private val transactionalContext = ThreadLocal<String>()
private val loggingContext = ThreadLocal<String>()

private val transactionalElement = object : ThreadContextElement<String> {
override val key = ThreadLocalKey(transactionalContext)

override fun updateThreadContext(context: CoroutineContext): String {
assertEquals("test", loggingContext.get())
val previous = transactionalContext.get()
transactionalContext.set("tr coroutine")
return previous
}

override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
assertEquals("test", loggingContext.get())
assertEquals("tr coroutine", transactionalContext.get())
transactionalContext.set(oldState)
}
}

private val loggingElement = object : ThreadContextElement<String> {
override val key = ThreadLocalKey(loggingContext)

override fun updateThreadContext(context: CoroutineContext): String {
val previous = loggingContext.get()
loggingContext.set("log coroutine")
return previous
}

override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
assertEquals("log coroutine", loggingContext.get())
assertEquals("tr coroutine", transactionalContext.get())
loggingContext.set(oldState)
}
}

@Test
fun testCorrectOrder() = runTest {
transactionalContext.set("test")
loggingContext.set("test")
launch(transactionalElement + loggingElement) {
assertEquals("log coroutine", loggingContext.get())
assertEquals("tr coroutine", transactionalContext.get())
}
assertEquals("test", loggingContext.get())
assertEquals("test", transactionalContext.get())

}
}

0 comments on commit 3fa8ee6

Please sign in to comment.