Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@

package org.apache.pekko.stream.scaladsl

import java.util.Collections
import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicInteger

import scala.annotation.switch
import scala.concurrent._
import scala.concurrent.duration._
import scala.util.control.NoStackTrace

import org.apache.pekko
import pekko.NotUsed
import pekko.pattern.FutureTimeoutSupport
import pekko.stream._
import pekko.stream.stage.GraphStage
import pekko.stream.stage.GraphStageLogic
Expand All @@ -30,9 +37,11 @@ import pekko.testkit.TestLatch

import org.scalatest.exceptions.TestFailedException

class FlowFlattenMergeSpec extends StreamSpec {
class FlowFlattenMergeSpec extends StreamSpec with FutureTimeoutSupport {
import system.dispatcher

class BoomException extends RuntimeException("BOOM~~") with NoStackTrace

def src10(i: Int) = Source(i until (i + 10))
def blocked = Source.future(Promise[Int]().future)

Expand Down Expand Up @@ -280,5 +289,153 @@ class FlowFlattenMergeSpec extends StreamSpec {
probe.expectComplete()
}

val checkBreadths = List(1, 2, 4, 8, 16, 32, 64, 128)

for (b <- checkBreadths) {
s"work with value presented sources with breadth: $b" in {
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.successful(5)),
Source.lazyFuture(() => Future.successful(6)),
Source.future(after(1.millis)(Future.successful(7)))))
.flatMapMerge(b, identity)
.runWith(toSet)
.futureValue should ===((1 to 7).toSet)
}
}

def generateRandomValuePresentedSources(nums: Int): (Int, List[Source[Int, NotUsed]]) = {
val seq = List.tabulate(nums) { _ =>
val random = ThreadLocalRandom.current().nextInt(1, 10)
(random: @switch) match {
case 1 => Source.single(1)
case 2 => Source(List(1))
case 3 => Source.fromJavaStream(() => Collections.singleton(1).stream())
case 4 => Source.future(Future.successful(1))
case 5 => Source.future(after(1.millis)(Future.successful(1)))
case _ => Source.empty[Int]
}
}
val sum = seq.filterNot(_.eq(Source.empty[Int])).size
(sum, seq)
}

for (b <- checkBreadths) {
s"work with generated value presented sources with breadth: $b " in {
val (sum, sources @ _) = generateRandomValuePresentedSources(10000)
Source(sources)
.flatMapMerge(b, identity(_))
.runWith(Sink.seq)
.map(_.sum)(scala.concurrent.ExecutionContext.parasitic)
.futureValue shouldBe sum
}
}

"work with value presented failed sources" in {
val ex = new BoomException
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.failed(ex)),
Source.lazyFuture(() => Future.successful(5))))
.flatMapMerge(ThreadLocalRandom.current().nextInt(1, 129), identity)
.onErrorComplete[BoomException]()
.runWith(toSet)
.futureValue.subsetOf((1 to 5).toSet) should ===(true)
}

val breadth = ThreadLocalRandom.current().nextInt(4, 65)
s"avoid pre-materialization for value-presented sources, breadth = $breadth" in {
val materializationCounter = new AtomicInteger(0)
val n = breadth * 3
val probe = Source(1 to n)
.flatMapMerge(
breadth,
value =>
Source.lazySingle(() => {
materializationCounter.incrementAndGet()
value
}))
.runWith(TestSink())

probe.request(n.toLong)
probe.expectNextN(n.toLong).toSet should ===((1 to n).toSet)
probe.expectComplete()
// Source.lazySingle is not a value-presented source, so each is materialized.
materializationCounter.get() shouldBe n
}

s"only materialize non-value-presented inner sources, breadth = $breadth" in {
val materializationCounter = new AtomicInteger(0)
val n = breadth * 3
// Mix value-presented (Source.single, fast path) with non-value-presented
// (lazySingle.buffer, slow path). The counter sits inside the lazySingle
// factory and only fires when the inner source is materialized as a substream.
val probe = Source(1 to (n * 2))
.flatMapMerge(
breadth,
value =>
if (value % 2 == 0) Source.single(value)
else
Source
.lazySingle(() => {
materializationCounter.incrementAndGet()
value
})
.buffer(1, overflowStrategy = OverflowStrategy.backpressure))
.runWith(TestSink())

probe.request(n.toLong * 2)
probe.expectNextN(n.toLong * 2).toSet should ===((1 to (n * 2)).toSet)
probe.expectComplete()
// Only odd values (non-VP) take the substream materialization path.
materializationCounter.get() shouldBe n
}

"close JavaStream-backed inner sources on exhaustion" in {
val closeCount = new AtomicInteger(0)
val streams = (1 to 4).toList
Source(streams)
.flatMapMerge(
4,
(n: Int) =>
Source.fromJavaStream(() =>
java.util.stream.Stream.of((1 to n).map(Integer.valueOf): _*).onClose(() =>
closeCount.incrementAndGet())))
.runWith(Sink.ignore)
.futureValue
closeCount.get() shouldBe streams.size
}

"close JavaStream-backed inner sources on downstream cancel" in {
val closeCount = new AtomicInteger(0)
// Endless inner streams; when downstream cancels, the inflight wrappers
// queued in FlattenMerge must close their underlying Java streams.
val probe = Source
.repeat(())
.flatMapMerge(
4,
_ =>
Source.fromJavaStream(() =>
java.util.stream.Stream
.generate[Integer](() => 1)
.onClose(() => closeCount.incrementAndGet())))
.runWith(TestSink())

probe.request(8)
probe.expectNextN(8)
probe.cancel()
awaitAssert {
closeCount.get() should be >= 1
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,103 +27,11 @@ import pekko.stream.{ Attributes, FlowShape, Graph, Inlet, Outlet, SourceShape,
import pekko.stream.impl.{ Buffer => BufferImpl, FailedSource, JavaStreamSource, TraversalBuilder }
import pekko.stream.impl.Stages.DefaultAttributes
import pekko.stream.impl.fusing.GraphStages.{ FutureSource, RepeatSource, SingleSource }
import pekko.stream.impl.fusing.InflightSources._
import pekko.stream.scaladsl.Source
import pekko.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import pekko.util.OptionVal

/**
* INTERNAL API
*/
@InternalApi
private[pekko] object FlattenConcat {
private sealed abstract class InflightSource[T] {
def hasNext: Boolean
def next(): T
def tryPull(): Unit
def cancel(cause: Throwable): Unit
def isClosed: Boolean
def hasFailed: Boolean = failure.isDefined
def failure: Option[Throwable] = None
def materialize(): Unit = ()
}

private final class InflightIteratorSource[T](iterator: Iterator[T]) extends InflightSource[T] {
override def hasNext: Boolean = iterator.hasNext
override def next(): T = iterator.next()
override def tryPull(): Unit = ()
override def cancel(cause: Throwable): Unit = ()
override def isClosed: Boolean = !hasNext
}

private final class InflightRangeSource[T](range: immutable.Range) extends InflightSource[T] {
private val isEmptyRange = range.isEmpty
private val rangeLast = if (isEmptyRange) 0 else range.last
private val rangeStep = range.step
private var nextElement = range.start
private var closed = isEmptyRange

override def hasNext: Boolean = !closed
override def next(): T =
if (closed) throw new NoSuchElementException("next called after completion")
else {
val current = nextElement
if (current == rangeLast) closed = true
else nextElement = current + rangeStep
current.asInstanceOf[T]
}
override def tryPull(): Unit = ()
override def cancel(cause: Throwable): Unit = ()
override def isClosed: Boolean = closed
}

private final class InflightRepeatSource[T](elem: T) extends InflightSource[T] {
override def hasNext: Boolean = true
override def next(): T = elem
override def tryPull(): Unit = ()
override def cancel(cause: Throwable): Unit = ()
override def isClosed: Boolean = false
}

private final class InflightCompletedFutureSource[T](result: Try[T]) extends InflightSource[T] {
private var _hasNext = result.isSuccess
override def hasNext: Boolean = _hasNext
override def next(): T = {
if (_hasNext) {
_hasNext = false
result.get
} else throw new NoSuchElementException("next called after completion")
}
override def hasFailed: Boolean = result.isFailure
override def failure: Option[Throwable] = result.failed.toOption
override def tryPull(): Unit = ()
override def cancel(cause: Throwable): Unit = ()
override def isClosed: Boolean = true
}

private final class InflightPendingFutureSource[T](cb: InflightSource[T] => Unit)
extends InflightSource[T]
with (Try[T] => Unit) {
private var result: Try[T] = MapAsync.NotYetThere
private var consumed = false
override def apply(result: Try[T]): Unit = {
this.result = result
cb(this)
}
override def hasNext: Boolean = (result ne MapAsync.NotYetThere) && !consumed && result.isSuccess
override def next(): T = {
if (!consumed) {
consumed = true
result.get
} else throw new NoSuchElementException("next called after completion")
}
override def hasFailed: Boolean = (result ne MapAsync.NotYetThere) && result.isFailure
override def failure: Option[Throwable] = if (result eq MapAsync.NotYetThere) None else result.failed.toOption
override def tryPull(): Unit = ()
override def cancel(cause: Throwable): Unit = ()
override def isClosed: Boolean = consumed || hasFailed
}
}

/**
* INTERNAL API
*/
Expand All @@ -138,7 +46,6 @@ private[pekko] final class FlattenConcat[T, M](parallelism: Int)
override val shape: FlowShape[Graph[SourceShape[T], M], T] = FlowShape(in, out)
override def createLogic(enclosingAttributes: Attributes) = {
object FlattenConcatLogic extends GraphStageLogic(shape) with InHandler with OutHandler {
import FlattenConcat._
// InflightSource[T] or SingleSource[T]
// AnyRef here to avoid lift the SingleSource[T] to InflightSource[T]
private var queue: BufferImpl[AnyRef] = _
Expand Down Expand Up @@ -269,6 +176,20 @@ private[pekko] final class FlattenConcat[T, M](parallelism: Int)
queue.enqueue(inflightSource)
}

private def addJavaStreamSource(javaStream: JavaStreamSource[T, _]): Unit = {
val inflightSource = new InflightJavaStreamSource[T](javaStream.open)
if (isAvailable(out) && queue.isEmpty) {
if (inflightSource.hasNext) {
push(out, inflightSource.next())
if (inflightSource.hasNext) {
queue.enqueue(inflightSource)
}
}
} else if (inflightSource.hasNext) {
queue.enqueue(inflightSource)
}
}

private def addCompletedFutureElem(elem: Try[T]): Unit = {
if (isAvailable(out) && queue.isEmpty) {
elem match {
Expand Down Expand Up @@ -336,13 +257,11 @@ private[pekko] final class FlattenConcat[T, M](parallelism: Int)
case Some(elem) => addCompletedFutureElem(elem)
case None => addPendingFutureElem(future)
}
case iterable: IterableSource[T] @unchecked => addSourceElements(iterable.elements.iterator)
case iterator: IteratorSource[T] @unchecked => addSourceElements(iterator.createIterator())
case range: RangeSource[T] @unchecked => addRangeSource(range.range)
case repeat: RepeatSource[T] @unchecked => addRepeatSource(repeat.elem)
case javaStream: JavaStreamSource[T, _] @unchecked =>
import scala.jdk.CollectionConverters._
addSourceElements(javaStream.open().iterator.asScala)
case iterable: IterableSource[T] @unchecked => addSourceElements(iterable.elements.iterator)
case iterator: IteratorSource[T] @unchecked => addSourceElements(iterator.createIterator())
case range: RangeSource[T] @unchecked => addRangeSource(range.range)
case repeat: RepeatSource[T] @unchecked => addRepeatSource(repeat.elem)
case javaStream: JavaStreamSource[T, _] @unchecked => addJavaStreamSource(javaStream)
case failed: FailedSource[T] @unchecked => addCompletedFutureElem(Failure(failed.failure))
case maybeEmpty if TraversalBuilder.isEmptySource(maybeEmpty) => // Empty source is discarded
case _ => attachAndMaterializeSource(source)
Expand Down
Loading