Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fail queued requests on connection failure #1423 #1851

Merged
merged 3 commits into from Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,66 @@
/*
* Copyright (C) 2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.grpc.scaladsl

import akka.actor.ActorSystem
import akka.grpc.{ GrpcClientSettings, GrpcServiceException }
import akka.http.scaladsl.model.HttpResponse
import akka.testkit.TestKit
import com.typesafe.config.ConfigFactory
import example.myapp.helloworld.grpc.helloworld.{ GreeterServiceClient, HelloRequest }
import io.grpc.Status
import org.scalatest.Inspectors.forAll
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike

import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import scala.util.{ Failure, Success }

class AkkaHttpClientConnectionFailSpec
extends TestKit(
ActorSystem(
"GrpcExceptionHandlerSpec",
ConfigFactory
.parseString("""
akka.grpc.client."*".backend = "akka-http"
akka.http.client.http2.max-persistent-attempts = 2
""".stripMargin)
.withFallback(ConfigFactory.load())))
with AnyWordSpecLike
with Matchers
with ScalaFutures {

"The Akka HTTP client backend" should {
"fail queued requests when connection fails" in {

// Note that the Akka HTTP client does not strictly adhere to the gRPC backoff protocol but has its own
// backoff algorithm
val client = GreeterServiceClient(GrpcClientSettings.connectToServiceAt("127.0.0.1", 5).withTls(false))

val futures = (1 to 10).map { _ =>
client.sayHello(HelloRequest())
}
// all should be failed
import system.dispatcher
val lifted = Future.sequence(futures.map(_.map(Success(_)).recover {
case th: Throwable => Failure[HttpResponse](th)
}))
johanandren marked this conversation as resolved.
Show resolved Hide resolved
val results = lifted.futureValue(timeout(5.seconds))
forAll(results) { it =>
it.isFailure should be(true)
it.failed.get match {
case ex: GrpcServiceException =>
ex.status.getCode shouldBe (Status.Code.UNAVAILABLE)
case unexpected =>
unexpected.printStackTrace()
fail(s"Exception ${unexpected} was not a GrpcServiceException")
}
}
}
}

}
90 changes: 73 additions & 17 deletions runtime/src/main/scala/akka/grpc/internal/AkkaHttpClientUtils.scala
Expand Up @@ -4,32 +4,62 @@

package akka.grpc.internal

import java.net.InetSocketAddress
import java.security.SecureRandom
import java.util.concurrent.CompletionStage
import scala.concurrent.duration._
import akka.{ Done, NotUsed }
import akka.Done
import akka.NotUsed
import akka.actor.ClassicActorSystemProvider
import akka.annotation.InternalApi
import akka.event.LoggingAdapter
import akka.grpc.GrpcClientSettings
import akka.grpc.GrpcProtocol.GrpcProtocolReader
import akka.grpc.{ GrpcClientSettings, GrpcResponseMetadata, GrpcSingleResponse, ProtobufSerializer }
import akka.grpc.GrpcResponseMetadata
import akka.grpc.GrpcServiceException
import akka.grpc.GrpcSingleResponse
import akka.grpc.ProtobufSerializer
import akka.grpc.scaladsl.StringEntry
import akka.http.scaladsl.model.HttpEntity.{ Chunk, Chunked, LastChunk, Strict }
import akka.http.scaladsl.{ ClientTransport, ConnectionContext, Http }
import akka.http.scaladsl.model._
import akka.http.scaladsl.ClientTransport
import akka.http.scaladsl.ConnectionContext
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.AttributeKey
import akka.http.scaladsl.model.AttributeKeys
import akka.http.scaladsl.model.HttpEntity.Chunk
import akka.http.scaladsl.model.HttpEntity.Chunked
import akka.http.scaladsl.model.HttpEntity.LastChunk
import akka.http.scaladsl.model.HttpEntity.Strict
import akka.http.scaladsl.model.HttpHeader
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.model.HttpResponse
import akka.http.scaladsl.model.RequestResponseAssociation
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.settings.ClientConnectionSettings
import akka.stream.{ Materializer, OverflowStrategy }
import akka.stream.scaladsl.{ Keep, Sink, Source }
import akka.stream.FlowShape
import akka.stream.Materializer
import akka.stream.OverflowStrategy
import akka.stream.scaladsl.Flow
import akka.stream.scaladsl.GraphDSL
import akka.stream.scaladsl.Keep
import akka.stream.scaladsl.Sink
import akka.stream.scaladsl.Source
import akka.util.ByteString
import io.grpc.{ CallOptions, MethodDescriptor, Status, StatusRuntimeException }
import io.grpc.CallOptions
import io.grpc.MethodDescriptor
import io.grpc.Status
import io.grpc.StatusRuntimeException

import javax.net.ssl.{ KeyManager, SSLContext, TrustManager }
import java.net.InetSocketAddress
import java.security.SecureRandom
import java.util.concurrent.CompletionStage
import javax.net.ssl.KeyManager
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
import scala.collection.immutable
import scala.compat.java8.FutureConverters.FutureOps
import scala.concurrent.{ ExecutionContext, Future, Promise }
import scala.util.{ Failure, Success }
import akka.http.scaladsl.model.StatusCodes
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.util.Failure
import scala.util.Success

/**
* INTERNAL API
Expand Down Expand Up @@ -103,9 +133,29 @@ object AkkaHttpClientUtils {
builder.managedPersistentHttp2WithPriorKnowledge()
}

// make sure we always fail all queued on http client fail to connect
val cancelFailed: Flow[HttpRequest, HttpRequest, NotUsed] = {
Flow.fromGraph(GraphDSL.create() { implicit b =>
import GraphDSL.Implicits._

val switch = b.add(new SwitchOnCancel[HttpRequest])

// when failed over
val failover = b.add(Sink.foreach[(Throwable, HttpRequest)] {
case (error, request) =>
request.entity.discardBytes()
request.getAttribute(ResponsePromise.Key).get().promise.tryFailure(error)
})
switch.out1 ~> failover.in

FlowShape[HttpRequest, HttpRequest](switch.in, switch.out0)
})
}

val (queue, doneFuture) =
Source
.queue[HttpRequest](4242, OverflowStrategy.fail)
.via(cancelFailed)
.via(http2client)
.toMat(Sink.foreach { res =>
res.attribute(ResponsePromise.Key).get.promise.trySuccess(res)
Expand All @@ -114,7 +164,13 @@ object AkkaHttpClientUtils {

def singleRequest(request: HttpRequest): Future[HttpResponse] = {
val p = Promise[HttpResponse]()
queue.offer(request.addAttribute(ResponsePromise.Key, ResponsePromise(p))).flatMap(_ => p.future)
queue.offer(request.addAttribute(ResponsePromise.Key, ResponsePromise(p))).flatMap(_ => p.future).recover {
case ex: RuntimeException if ex.getMessage.contains("Connection failed") =>
throw new GrpcServiceException(
Status.UNAVAILABLE
.withCause(ex)
.withDescription(s"Connection to ${settings.serviceName}:${settings.defaultPort} failed"))
}
}

def serializerFromMethodDescriptor[I, O](descriptor: MethodDescriptor[I, O]): ProtobufSerializer[I] =
Expand Down
74 changes: 74 additions & 0 deletions runtime/src/main/scala/akka/grpc/internal/SwitchOnCancel.scala
@@ -0,0 +1,74 @@
/*
* Copyright (C) 2009-2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.grpc.internal

import akka.stream.Attributes
import akka.stream.FanOutShape2
import akka.stream.Inlet
import akka.stream.Outlet
import akka.stream.stage.GraphStage
import akka.stream.stage.GraphStageLogic
import akka.stream.stage.InHandler
import akka.stream.stage.OutHandler
import akka.util.OptionVal

/**
* Identity stage that feeds all incoming events to output 1 until it cancels, then switches over to output 2, after
* completing the materialized value future.
*
* INTERNAL API
* @tparam T
*/
final private[akka] class SwitchOnCancel[T] extends GraphStage[FanOutShape2[T, T, (Throwable, T)]] {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be useful enough that we should backport it to Akka proper. It's kind of like a Source.recover but in the other direction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the status of this PR? Should it be fixed in Akka or here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could keep it as a custom stage here and, maybe decide to port back to Akka as public operator in the future.


val in = Inlet[T]("in")
val mainOut = Outlet[T]("mainOut")
val failoverOut = Outlet[(Throwable, T)]("failoverOut")

override def shape: FanOutShape2[T, T, (Throwable, T)] = new FanOutShape2(in, mainOut, failoverOut)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
var failedOver: OptionVal[Throwable] = OptionVal.None

setHandler(
in,
new InHandler {
override def onPush(): Unit = {
val elem = grab(in)
failedOver match {
case OptionVal.Some(error) => push(failoverOut, (error, elem))
case _ => push(mainOut, elem)
}

}
})

setHandler(
mainOut,
new OutHandler {
override def onPull(): Unit =
pull(in)

override def onDownstreamFinish(cause: Throwable): Unit = {
// on downstream cancel or failure switch to second out
failedOver = OptionVal.Some(cause)
if (isAvailable(failoverOut) && !hasBeenPulled(in)) {
pull(in)
}
}
})

setHandler(
failoverOut,
new OutHandler {
override def onPull(): Unit = {
// may have been pulled and then failed over
if (!hasBeenPulled(in)) pull(in)
}
})

}

}