From e2b16e2887aab16f07663ece76ae202e51c634c6 Mon Sep 17 00:00:00 2001 From: kerr Date: Fri, 28 Oct 2022 17:01:05 +0800 Subject: [PATCH] =str Avoid subMaterialization when the provided recover source is empty. (#31669) --- .../stream/scaladsl/FlowRecoverWithSpec.scala | 13 +++++++++++++ .../scala/akka/stream/impl/fusing/Ops.scala | 18 ++++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowRecoverWithSpec.scala b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowRecoverWithSpec.scala index c5f2ea0530d..25cdd9b545f 100644 --- a/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowRecoverWithSpec.scala +++ b/akka-stream-tests/src/test/scala/akka/stream/scaladsl/FlowRecoverWithSpec.scala @@ -39,6 +39,19 @@ class FlowRecoverWithSpec extends StreamSpec { .expectComplete() } + "recover with empty source" in { + Source(1 to 4) + .map { a => + if (a == 3) throw ex else a + } + .recoverWith { case _: Throwable => Source.empty } + .runWith(TestSink[Int]()) + .request(2) + .expectNextN(1 to 2) + .request(1) + .expectComplete() + } + "cancel substream if parent is terminated when there is a handler" in { Source(1 to 4) .map { a => diff --git a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala index a5b225c40ca..15fb0d617e7 100644 --- a/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala +++ b/akka-stream/src/main/scala/akka/stream/impl/fusing/Ops.scala @@ -29,6 +29,7 @@ import akka.stream.OverflowStrategies._ import akka.stream.Supervision.Decider import akka.stream.impl.{ ContextPropagation, ReactiveStreamsCompliance, Buffer => BufferImpl } import akka.stream.impl.Stages.DefaultAttributes +import akka.stream.impl.TraversalBuilder import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage import akka.stream.scaladsl.{ DelayStrategy, Source } import akka.stream.stage._ @@ -2147,12 +2148,21 @@ private[akka] object TakeWithin { override def onPush(): Unit = push(out, grab(in)) override def onUpstreamFailure(ex: Throwable): Unit = onFailure(ex) override def onPull(): Unit = pull(in) - def onFailure(ex: Throwable): Unit = - if ((maximumRetries < 0 || attempt < maximumRetries) && pf.isDefinedAt(ex)) { - switchTo(pf(ex)) - attempt += 1 + def onFailure(ex: Throwable): Unit = { + import Collect.NotApplied + if (maximumRetries < 0 || attempt < maximumRetries) { + pf.applyOrElse(ex, NotApplied) match { + case NotApplied => failStage(ex) + case source: Graph[SourceShape[T] @unchecked, M @unchecked] if TraversalBuilder.isEmptySource(source) => + completeStage() + case other: Graph[SourceShape[T] @unchecked, M @unchecked] => + switchTo(other) + attempt += 1 + case _ => throw new IllegalStateException() // won't happen, compiler exhaustiveness check pleaser + } } else failStage(ex) + } def switchTo(source: Graph[SourceShape[T], M]): Unit = { val sinkIn = new SubSinkInlet[T]("RecoverWithSink")