From 99e8050f960f5519685b9cab532eb34dea8b1d6e Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Tue, 21 Aug 2018 19:45:29 +0300 Subject: [PATCH] =?UTF-8?q?Implement=20withTimeoutOrNull=20via=20withTimeo?= =?UTF-8?q?ut=20to=20avoid=20timing=20bugs=20and=20=E2=80=A6=20(#499)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement withTimeoutOrNull via withTimeout to avoid timing bugs and races. Remove deprecated API Fixes #498 --- .../src/Scheduled.kt | 46 ++++++------------- .../test/WithTimeoutOrNullTest.kt | 31 +++++++++++++ 2 files changed, 45 insertions(+), 32 deletions(-) diff --git a/common/kotlinx-coroutines-core-common/src/Scheduled.kt b/common/kotlinx-coroutines-core-common/src/Scheduled.kt index 68cf55eec4..9c5250a583 100644 --- a/common/kotlinx-coroutines-core-common/src/Scheduled.kt +++ b/common/kotlinx-coroutines-core-common/src/Scheduled.kt @@ -69,13 +69,6 @@ private fun setupTimeout( return coroutine.startUndispatchedOrReturn(coroutine, block) } -/** - * @suppress **Deprecated**: for binary compatibility only - */ -@Deprecated("for binary compatibility only", level=DeprecationLevel.HIDDEN) -public suspend fun withTimeout(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS, block: suspend () -> T): T = - withTimeout(time, unit) { block() } - private open class TimeoutCoroutine( @JvmField val time: Long, @JvmField val unit: TimeUnit, @@ -140,32 +133,21 @@ public suspend fun withTimeoutOrNull(time: Int, block: suspend CoroutineScop */ public suspend fun withTimeoutOrNull(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS, block: suspend CoroutineScope.() -> T): T? { if (time <= 0L) return null - return suspendCoroutineUninterceptedOrReturn { uCont -> - setupTimeout(TimeoutOrNullCoroutine(time, unit, uCont), block) - } -} - -/** - * @suppress **Deprecated**: for binary compatibility only - */ -@Deprecated("for binary compatibility only", level=DeprecationLevel.HIDDEN) -public suspend fun withTimeoutOrNull(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS, block: suspend () -> T): T? = - withTimeoutOrNull(time, unit) { block() } -private class TimeoutOrNullCoroutine( - time: Long, - unit: TimeUnit, - uCont: Continuation // unintercepted continuation -) : TimeoutCoroutine(time, unit, uCont) { - @Suppress("UNCHECKED_CAST") - internal override fun onCompletionInternal(state: Any?, mode: Int) { - if (state is CompletedExceptionally) { - val exception = state.cause - if (exception is TimeoutCancellationException && exception.coroutine === this) - uCont.resumeUninterceptedMode(null, mode) else - uCont.resumeUninterceptedWithExceptionMode(exception, mode) - } else - uCont.resumeUninterceptedMode(state as T, mode) + var coroutine: TimeoutCoroutine? = null + try { + return suspendCoroutineUninterceptedOrReturn { uCont -> + TimeoutCoroutine(time, unit, uCont).let { + coroutine = it + setupTimeout(it, block) + } + } + } catch (e: TimeoutCancellationException) { + // Return null iff it's our exception, otherwise propagate it upstream (e.g. in case of nested withTimeouts) + if (e.coroutine === coroutine) { + return null + } + throw e } } diff --git a/common/kotlinx-coroutines-core-common/test/WithTimeoutOrNullTest.kt b/common/kotlinx-coroutines-core-common/test/WithTimeoutOrNullTest.kt index 74d142646f..c4a814123f 100644 --- a/common/kotlinx-coroutines-core-common/test/WithTimeoutOrNullTest.kt +++ b/common/kotlinx-coroutines-core-common/test/WithTimeoutOrNullTest.kt @@ -7,6 +7,7 @@ package kotlinx.coroutines.experimental +import kotlinx.coroutines.experimental.channels.* import kotlin.coroutines.experimental.* import kotlin.test.* @@ -81,6 +82,23 @@ class WithTimeoutOrNullTest : TestBase() { finish(2) } + @Test + fun testSmallTimeout() = runTest { + val channel = Channel(1) + val value = withTimeoutOrNull(1) { + channel.receive() + } + + assertNull(value) + } + + @Test + fun testThrowException() = runTest(expected = {it is AssertionError}) { + withTimeoutOrNull(Long.MAX_VALUE) { + throw AssertionError() + } + } + @Test fun testInnerTimeoutTest() = runTest( expected = { it is CancellationException } @@ -96,6 +114,19 @@ class WithTimeoutOrNullTest : TestBase() { expectUnreached() // will timeout } + @Test + fun testNestedTimeout() = runTest(expected = { it is TimeoutCancellationException }) { + withTimeoutOrNull(Long.MAX_VALUE) { + // Exception from this withTimeout is not suppressed by withTimeoutOrNull + withTimeout(10) { + delay(Long.MAX_VALUE) + 1 + } + } + + expectUnreached() + } + @Test fun testOuterTimeoutTest() = runTest { var counter = 0