Skip to content

Commit

Permalink
Restore context preservation invariant in flatMapMerge (#1452)
Browse files Browse the repository at this point in the history
* Introduce (again) flowProduce in order to properly propagate cancellation to the upstream in flatMapMerge.

Previously this issue was masked by SerializingCollector fast-path

* Re-implement flatMapMerge via the channel to have context preservation property

Fixes #1440
  • Loading branch information
qwwdfsad committed Aug 22, 2019
1 parent bcf4a8c commit 0342a0a
Show file tree
Hide file tree
Showing 14 changed files with 216 additions and 151 deletions.
35 changes: 0 additions & 35 deletions benchmarks/src/jmh/kotlin/benchmarks/YieldRelativeCostBenchmark.kt

This file was deleted.

47 changes: 47 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/flow/FlatMapMergeBenchmark.kt
@@ -0,0 +1,47 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*

@Warmup(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class FlatMapMergeBenchmark {

// Note: tests only absence of contention on downstream

@Param("10", "100", "1000")
private var iterations = 100

@Benchmark
fun flatMapUnsafe() = runBlocking {
benchmarks.flow.scrabble.flow {
repeat(iterations) { emit(it) }
}.flatMapMerge { value ->
flowOf(value)
}.collect {
if (it == -1) error("")
}
}

@Benchmark
fun flatMapSafe() = runBlocking {
kotlinx.coroutines.flow.flow {
repeat(iterations) { emit(it) }
}.flatMapMerge { value ->
flowOf(value)
}.collect {
if (it == -1) error("")
}
}

}
Expand Up @@ -3,7 +3,7 @@
*/


package benchmarks.flow.misc
package benchmarks.flow

import benchmarks.flow.scrabble.flow
import io.reactivex.*
Expand Down Expand Up @@ -35,7 +35,7 @@ import java.util.concurrent.*
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class Numbers {
open class NumbersBenchmark {

companion object {
private const val primes = 100
Expand Down
Expand Up @@ -2,7 +2,7 @@
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks.flow.misc
package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
Expand Down
Expand Up @@ -993,7 +993,7 @@ public final class kotlinx/coroutines/flow/internal/SafeCollectorKt {
public static final fun unsafeFlow (Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
}

public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/internal/ConcurrentFlowCollector {
public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/FlowCollector {
public fun <init> (Lkotlinx/coroutines/channels/SendChannel;)V
public fun emit (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}
Expand Down
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Produce.kt
Expand Up @@ -126,7 +126,7 @@ public fun <E> CoroutineScope.produce(
return coroutine
}

private class ProducerCoroutine<E>(
internal open class ProducerCoroutine<E>(
parentContext: CoroutineContext, channel: Channel<E>
) : ChannelCoroutine<E>(parentContext, channel, active = true), ProducerScope<E> {
override val isActive: Boolean
Expand Down
12 changes: 2 additions & 10 deletions kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt
Expand Up @@ -58,7 +58,7 @@ public abstract class ChannelFlow<T>(
protected abstract suspend fun collectTo(scope: ProducerScope<T>)

// shared code to create a suspend lambda from collectTo function in one place
private val collectToFun: suspend (ProducerScope<T>) -> Unit
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
get() = { collectTo(it) }

private val produceCapacity: Int
Expand Down Expand Up @@ -140,13 +140,11 @@ internal class ChannelFlowOperatorImpl<T>(
private fun <T> FlowCollector<T>.withUndispatchedContextCollector(emitContext: CoroutineContext): FlowCollector<T> = when (this) {
// SendingCollector & NopCollector do not care about the context at all and can be used as is
is SendingCollector, is NopCollector -> this
// Original collector is concurrent, so wrap into ConcurrentUndispatchedContextCollector (also concurrent)
is ConcurrentFlowCollector -> ConcurrentUndispatchedContextCollector(this, emitContext)
// Otherwise just wrap into UndispatchedContextCollector interface implementation
else -> UndispatchedContextCollector(this, emitContext)
}

private open class UndispatchedContextCollector<T>(
private class UndispatchedContextCollector<T>(
downstream: FlowCollector<T>,
private val emitContext: CoroutineContext
) : FlowCollector<T> {
Expand All @@ -157,12 +155,6 @@ private open class UndispatchedContextCollector<T>(
withContextUndispatched(emitContext, countOrElement, emitRef, value)
}

// named class for a combination of UndispatchedContextCollector & ConcurrentFlowCollector interface
private class ConcurrentUndispatchedContextCollector<T>(
downstream: ConcurrentFlowCollector<T>,
emitContext: CoroutineContext
) : UndispatchedContextCollector<T>(downstream, emitContext), ConcurrentFlowCollector<T>

// Efficiently computes block(value) in the newContext
private suspend fun <T, V> withContextUndispatched(
newContext: CoroutineContext,
Expand Down
81 changes: 0 additions & 81 deletions kotlinx-coroutines-core/common/src/flow/internal/Concurrent.kt

This file was deleted.

22 changes: 22 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt
Expand Up @@ -52,6 +52,18 @@ internal fun <R> scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo
flowScope { block(collector) }
}

internal fun <T> CoroutineScope.flowProduce(
context: CoroutineContext,
capacity: Int = 0,
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
): ReceiveChannel<T> {
val channel = Channel<T>(capacity)
val newContext = newCoroutineContext(context)
val coroutine = FlowProduceCoroutine(newContext, channel)
coroutine.start(CoroutineStart.DEFAULT, coroutine, block)
return coroutine
}

private class FlowCoroutine<T>(
context: CoroutineContext,
uCont: Continuation<T>
Expand All @@ -61,3 +73,13 @@ private class FlowCoroutine<T>(
return cancelImpl(cause)
}
}

private class FlowProduceCoroutine<T>(
parentContext: CoroutineContext,
channel: Channel<T>
) : ProducerCoroutine<T>(parentContext, channel) {
public override fun childCancelled(cause: Throwable): Boolean {
if (cause is ChildCancelledException) return true
return cancelImpl(cause)
}
}
27 changes: 9 additions & 18 deletions kotlinx-coroutines-core/common/src/flow/internal/Merge.kt
Expand Up @@ -38,17 +38,21 @@ internal class ChannelFlowTransformLatest<T, R>(
}

internal class ChannelFlowMerge<T>(
flow: Flow<Flow<T>>,
private val flow: Flow<Flow<T>>,
private val concurrency: Int,
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = Channel.OPTIONAL_CHANNEL
) : ChannelFlowOperator<Flow<T>, T>(flow, context, capacity) {
capacity: Int = Channel.BUFFERED
) : ChannelFlow<T>(context, capacity) {
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
ChannelFlowMerge(flow, concurrency, context, capacity)

// The actual merge implementation with concurrency limit
private suspend fun mergeImpl(scope: CoroutineScope, collector: ConcurrentFlowCollector<T>) {
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
return scope.flowProduce(context, capacity, block = collectToFun)
}

override suspend fun collectTo(scope: ProducerScope<T>) {
val semaphore = Semaphore(concurrency)
val collector = SendingCollector(scope)
val job: Job? = coroutineContext[Job]
flow.collect { inner ->
/*
Expand All @@ -68,19 +72,6 @@ internal class ChannelFlowMerge<T>(
}
}

// Fast path in ChannelFlowOperator calls this function (channel was not created yet)
override suspend fun flowCollect(collector: FlowCollector<T>) {
// this function should not have been invoked when channel was explicitly requested
assert { capacity == Channel.OPTIONAL_CHANNEL }
flowScope {
mergeImpl(this, collector.asConcurrentFlowCollector())
}
}

// Slow path when output channel is required (and was created)
override suspend fun collectTo(scope: ProducerScope<T>) =
mergeImpl(scope, SendingCollector(scope))

override fun additionalToStringProps(): String =
"concurrency=$concurrency, "
}
Expand Up @@ -4,7 +4,9 @@

package kotlinx.coroutines.flow.internal

internal object NopCollector : ConcurrentFlowCollector<Any?> {
import kotlinx.coroutines.flow.*

internal object NopCollector : FlowCollector<Any?> {
override suspend fun emit(value: Any?) {
// does nothing
}
Expand Down
@@ -0,0 +1,20 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*

/**
* Collection that sends to channel
* @suppress **This an internal API and should not be used from general code.**
*/
@InternalCoroutinesApi
public class SendingCollector<T>(
private val channel: SendChannel<T>
) : FlowCollector<T> {
override suspend fun emit(value: T) = channel.send(value)
}

0 comments on commit 0342a0a

Please sign in to comment.