Skip to content

Commit

Permalink
+str Add flatmapConcat with parallelism.
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin committed Aug 10, 2023
1 parent 45e73c3 commit 9813a33
Show file tree
Hide file tree
Showing 11 changed files with 473 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util.concurrent.TimeUnit

import scala.concurrent.Await
import scala.concurrent.duration._
import scala.concurrent.Future

import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._
Expand Down Expand Up @@ -88,6 +89,18 @@ class FlatMapConcatBenchmark {
awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def completedFuture(): Unit = {
val latch = new CountDownLatch(1)

testSource
.flatMapConcat(n => Source.future(Future.successful(n)))
.runWith(new LatchSink(OperationsPerInvocation, latch))

awaitLatch(latch)
}

@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def mapBaseline(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (C) 2014-2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.stream.scaladsl

import akka.stream.OverflowStrategy
import akka.stream.testkit._
import akka.stream.testkit.scaladsl.TestSink

import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import scala.util.control.NoStackTrace

class FlowFlatMapConcatSpec extends StreamSpec("""
akka.stream.materializer.initial-input-buffer-size = 2
""") with ScriptedTest {
val toSeq = Flow[Int].grouped(1000).toMat(Sink.head)(Keep.right)

class BoomException extends RuntimeException("BOOM~~") with NoStackTrace
"A flatMapConcat" must {

"work with value presented sources" in {
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.successful(5)),
Source.lazyFuture(() => Future.successful(6))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.runWith(toSeq)
.futureValue should ===(1 to 6)
}

"work with value presented failed sources" in {
val ex = new BoomException
Source(
List(
Source.empty[Int],
Source.single(1),
Source.empty[Int],
Source(List(2, 3, 4)),
Source.future(Future.failed(ex)),
Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.onErrorComplete[BoomException]()
.runWith(toSeq)
.futureValue should ===(1 to 4)
}

"work with value presented sources when demands slow" in {
val prob = Source(
List(Source.empty[Int], Source.single(1), Source(List(2, 3, 4)), Source.lazyFuture(() => Future.successful(5))))
.flatMapConcat(ThreadLocalRandom.current().nextInt(1, 129), identity)
.runWith(TestSink())

prob.request(1)
prob.expectNext(1)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(2, 3)
prob.expectNoMessage(1.seconds)
prob.request(2)
prob.expectNext(4, 5)
prob.expectComplete()
}

"can do pre materialization when parallelism > 1" in {
val materializationCounter = new AtomicInteger(0)
val randomParallelism = ThreadLocalRandom.current().nextInt(4, 65)
val prob = Source(1 to (randomParallelism * 3))
.flatMapConcat(
randomParallelism,
value => {
Source
.lazySingle(() => {
materializationCounter.incrementAndGet()
value
})
.buffer(1, overflowStrategy = OverflowStrategy.backpressure)
})
.runWith(TestSink())

expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0

prob.request(1)
prob.expectNext(1.seconds, 1)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (randomParallelism + 1)
materializationCounter.set(0)

prob.request(2)
prob.expectNextN(List(2, 3))
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 2
materializationCounter.set(0)

prob.request(randomParallelism - 3)
prob.expectNextN(4 to randomParallelism)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe (randomParallelism - 3)
materializationCounter.set(0)

prob.request(randomParallelism)
prob.expectNextN(randomParallelism + 1 to randomParallelism * 2)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe randomParallelism
materializationCounter.set(0)

prob.request(randomParallelism)
prob.expectNextN(randomParallelism * 2 + 1 to randomParallelism * 3)
expectNoMessage(1.seconds)
materializationCounter.get() shouldBe 0
prob.expectComplete()
}

}

}
1 change: 1 addition & 0 deletions akka-stream/src/main/scala/akka/stream/impl/Stages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import akka.stream.Attributes._
val mergePreferred = name("mergePreferred")
val mergePrioritized = name("mergePrioritized")
val flattenMerge = name("flattenMerge")
val flattenConcat = name("flattenConcat")
val recoverWith = name("recoverWith")
val onErrorComplete = name("onErrorComplete")
val broadcast = name("broadcast")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ package akka.stream.impl

import scala.collection.immutable.Map.Map1
import scala.language.existentials

import akka.annotation.{ DoNotInherit, InternalApi }
import akka.stream._
import akka.stream.impl.StreamLayout.AtomicModule
import akka.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
import akka.stream.impl.fusing.GraphStageModule
import akka.stream.impl.fusing.GraphStages.IterableSource
import akka.stream.impl.fusing.GraphStages.SingleSource
import akka.stream.scaladsl.Keep
import akka.util.OptionVal
Expand Down Expand Up @@ -371,6 +371,37 @@ import akka.util.unused
}
}

def getValuePresentedSource[A >: Null](graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match {
case _: SingleSource[_] | _: IterableSource[_] | EmptySource => true
case _ => false
}
graph match {
case _ if isValuePresentedSource(graph) => OptionVal.Some(graph)
case _ =>
graph.traversalBuilder match {
case l: LinearTraversalBuilder =>
l.pendingBuilder match {
case OptionVal.Some(a: AtomicTraversalBuilder) =>
a.module match {
case m: GraphStageModule[_, _] =>
m.stage match {
case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) =>
// It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]])
else OptionVal.None
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
}
}

/**
* Test if a Graph is an empty Source.
* */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,7 @@ private[stream] object Collect {
*/
@InternalApi private[akka] final case class MapAsync[In, Out](parallelism: Int, f: In => Future[Out])
extends GraphStage[FlowShape[In, Out]] {
require(parallelism >= 1, "parallelism should >= 1")

import MapAsync._

Expand Down

0 comments on commit 9813a33

Please sign in to comment.