Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

+str Add flatmapConcat with parallelism. #32024

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util.concurrent.TimeUnit

import scala.concurrent.Await
import scala.concurrent.duration._
import scala.concurrent.Future

import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._
Expand Down Expand Up @@ -88,6 +89,18 @@ class FlatMapConcatBenchmark {
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def completedFuture(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(n => Source.future(Future.successful(n)))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def mapBaseline(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (C) 2014-2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.stream.scaladsl

import akka.stream.OverflowStrategy
import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink

import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import scala.util.control.NoStackTrace

class FlowFlatMapConcatSpec extends StreamSpec("""
akka.stream.materializer.initial-input-buffer-size = 2
""") with ScriptedTest {
val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)

class BoomException extends RuntimeException("BOOM~~") with NoStackTrace
"A flatMapConcat" must {

"work with value presented sources" in {
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.successful(5)),
Source.lazyFuture(() => Future.successful(6))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random is good for trying it out while developing but means that we'd have no idea what value it fails for if it fails. Parameterize over a few values or use a single random chosen and logged on test class instantiation instead, so failures can be repeated.

.runWith(toSeq)
.futureValue should ===(1 to 6)
}

"work with value presented failed sources" in {
val ex = new BoomException
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.failed(ex)),
Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.onErrorComplete[BoomException]()
.runWith(toSeq)
.futureValue should ===(1 to 4)
}

"work with value presented sources when demands slow" in {
val prob = Source(
List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.runWith(TestSink())

prob.request(1)
prob.expectNext(1)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(2, 3)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(4, 5)
prob.expectComplete()
}

"can do pre materialization when parallelism > 1" in {
val materializationCounter = new AtomicInteger(0)
val randomParallelism = ThreadLocalRandom.current().nextInt(4, 65)
val prob = Source(1 to (randomParallelism * 3))
.flatMapConcat(
randomParallelism,
value => {
Source
.lazySingle(() => {
materializationCounter.incrementAndGet()
value
})
.buffer(1, overflowStrategy = OverflowStrategy.backpressure)
})
.runWith(TestSink())

expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0

prob.request(1)
prob.expectNext(1.seconds, 1)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (randomParallelism + 1)
materializationCounter.set(0)

prob.request(2)
prob.expectNextN(List(2, 3))
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 2
materializationCounter.set(0)

prob.request(randomParallelism - 3)
prob.expectNextN(4 to randomParallelism)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (randomParallelism - 3)
materializationCounter.set(0)

prob.request(randomParallelism)
prob.expectNextN(randomParallelism + 1 to randomParallelism * 2)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe randomParallelism
materializationCounter.set(0)

prob.request(randomParallelism)
prob.expectNextN(randomParallelism * 2 + 1 to randomParallelism * 3)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0
prob.expectComplete()
}

}

}
1 change: 1 addition & 0 deletions akka-stream/src/main/scala/akka/stream/impl/Stages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import akka.stream.Attributes._
val mergePreferred = name("mergePreferred")
val mergePrioritized = name("mergePrioritized")
val flattenMerge = name("flattenMerge")
val flattenConcat = name("flattenConcat")
val recoverWith = name("recoverWith")
val onErrorComplete = name("onErrorComplete")
val broadcast = name("broadcast")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ package akka.stream.impl

import scala.collection.immutable.Map.Map1
import scala.language.existentials

import akka.annotation.{ DoNotInherit, InternalApi }
import akka.stream._
import akka.stream.impl.StreamLayout.AtomicModule
import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
import akka.stream.impl.fusing.GraphStageModule
import akka.stream.impl.fusing.GraphStages.IterableSource
import akka.stream.impl.fusing.GraphStages.SingleSource
import akka.stream.scaladsl.Keep
import akka.util.OptionVal
Expand Down Expand Up @@ -370,6 +370,37 @@ import akka.util.unused
}
}

def getValuePresentedSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match {
case _: SingleSource[_] | _: IterableSource[_] | EmptySource => true
case _ => false
}
graph match {
case _ if isValuePresentedSource(graph) => OptionVal.Some(graph)
case _ =>
graph.traversalBuilder match {
case l: LinearTraversalBuilder =>
l.pendingBuilder match {
case OptionVal.Some(a: AtomicTraversalBuilder) =>
a.module match {
case m: GraphStageModule[_, _] =>
m.stage match {
case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) =>
// It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]])
else OptionVal.None
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
}
}

/**
* Test if a Graph is an empty Source.
* */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,7 @@ private[stream] object Collect {
*/
@InternalApi private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In => Future[Out])
extends GraphStage[FlowShape[In, Out]] {
require(parallelism >= 1, "parallelism should >= 1")

import MapAsync._

Expand Down