diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt b/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt index f3a112682a..6a910764b7 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt @@ -128,13 +128,26 @@ public fun Flow.onStart( public fun Flow.onCompletion( action: suspend FlowCollector.(cause: Throwable?) -> Unit ): Flow = unsafeFlow { // Note: unsafe flow is used here, but safe collector is used to invoke completion action - var exception: Throwable? = null - try { - exception = catchImpl(this) - } finally { - // Separate method because of KT-32220 - SafeCollector(this, coroutineContext).invokeSafely(action, exception) - exception?.let { throw it } + val exception = try { + catchImpl(this) + } catch (e: Throwable) { + /* + * Exception from the downstream. + * Use throwing collector to prevent any emissions from the + * completion sequence when downstream has failed, otherwise it may + * lead to a non-sequential behaviour impossible with `finally` + */ + ThrowingCollector(e).invokeSafely(action, null) + throw e + } + // Exception from the upstream or normal completion + SafeCollector(this, coroutineContext).invokeSafely(action, exception) + exception?.let { throw it } +} + +private class ThrowingCollector(private val e: Throwable) : FlowCollector { + override suspend fun emit(value: Any?) { + throw e } } @@ -155,5 +168,3 @@ private suspend fun FlowCollector.invokeSafely( throw e } } - - diff --git a/kotlinx-coroutines-core/common/test/flow/operators/OnCompletionTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/OnCompletionTest.kt index af50608a2a..c079500ef7 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/OnCompletionTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/OnCompletionTest.kt @@ -5,6 +5,7 @@ package kotlinx.coroutines.flow import kotlinx.coroutines.* +import kotlinx.coroutines.flow.internal.* import kotlin.test.* class OnCompletionTest : TestBase() { @@ -171,14 +172,91 @@ class OnCompletionTest : TestBase() { .onCompletion { e -> expect(8) assertNull(e) - emit(TestData.Done(e)) + try { + emit(TestData.Done(e)) + expectUnreached() + } finally { + expect(9) + } }.collect { collected += it } } } - val expected = (1..5).map { TestData.Value(it) } + TestData.Done(null) + val expected = (1..5).map { TestData.Value(it) } assertEquals(expected, collected) - finish(9) + finish(10) + } + + @Test + fun testFailedEmit() = runTest { + val cause = TestException() + assertFailsWith { + flow { + expect(1) + emit(TestData.Value(2)) + expectUnreached() + }.onCompletion { + assertNull(it) + expect(3) + try { + emit(TestData.Done(it)) + expectUnreached() + } catch (e: TestException) { + assertSame(cause, e) + finish(4) + } + }.collect { + expect((it as TestData.Value).i) + throw cause + } + } + } + + @Test + fun testFirst() = runTest { + val value = flowOf(239).onCompletion { + assertNull(it) + expect(1) + try { + emit(42) + expectUnreached() + } catch (e: Throwable) { + assertTrue { e is AbortFlowException } + } + }.first() + assertEquals(239, value) + finish(2) + } + + @Test + fun testSingle() = runTest { + assertFailsWith { + flowOf(239).onCompletion { + assertNull(it) + expect(1) + try { + emit(42) + expectUnreached() + } catch (e: Throwable) { + // Second emit -- failure + assertTrue { e is IllegalStateException } + throw e + } + }.single() + expectUnreached() + } + finish(2) + } + + @Test + fun testEmptySingleInterference() = runTest { + val value = emptyFlow().onCompletion { + assertNull(it) + expect(1) + emit(42) + }.single() + assertEquals(42, value) + finish(2) } }