From e0da12368e13549081dfa7095eb2b2dfda79bb87 Mon Sep 17 00:00:00 2001 From: Johannes Rudolph Date: Thu, 12 Nov 2020 15:37:56 +0100 Subject: [PATCH] http2: more robust handling of IncomingStreamBuffer.onPull --- .../engine/http2/Http2StreamHandling.scala | 61 ++++++++++++------- .../impl/engine/http2/Http2ServerSpec.scala | 24 ++++++++ 2 files changed, 63 insertions(+), 22 deletions(-) diff --git a/akka-http-core/src/main/scala/akka/http/impl/engine/http2/Http2StreamHandling.scala b/akka-http-core/src/main/scala/akka/http/impl/engine/http2/Http2StreamHandling.scala index 58f04beaf60..2e802723977 100644 --- a/akka-http-core/src/main/scala/akka/http/impl/engine/http2/Http2StreamHandling.scala +++ b/akka-http-core/src/main/scala/akka/http/impl/engine/http2/Http2StreamHandling.scala @@ -137,6 +137,10 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper def pullNextFrame(streamId: Int, maxSize: Int): PullFrameResult = updateStateAndReturn(streamId, _.pullNextFrame(maxSize)) + /** Entry-point to handle IncomingStreamBuffer.onPull through the state machine */ + def incomingStreamPulled(streamId: Int): Unit = + updateState(streamId, _.incomingStreamPulled()) + private def updateAllStates(handle: StreamState => StreamState): Unit = streamStates.keys.foreach(updateState(_, handle)) @@ -257,6 +261,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper } def pullNextFrame(maxSize: Int): (StreamState, PullFrameResult) = throw new IllegalStateException(s"pullNextFrame not supported in state $stateName") + def incomingStreamPulled(): StreamState = throw new IllegalStateException(s"incomingStreamPulled not supported in state $stateName") /** Called to cleanup any state when the connection is torn down */ def shutdown(): Unit = () @@ -359,8 +364,8 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper outstandingConnectionLevelWindow += windowSizeIncrement } - buffer.onDataFrame(d).getOrElse( - maybeFinishStream(d.endStream)) + buffer.onDataFrame(d) + afterBufferEvent } case r: RstStreamFrame => buffer.onRstStreamFrame(r) @@ -369,8 +374,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper case h: ParsedHeadersFrame => buffer.onTrailingHeaders(h) - - maybeFinishStream(h.endStream) + afterBufferEvent case w: WindowUpdateFrame => incrementWindow(w.windowSizeIncrement) @@ -379,8 +383,10 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper } protected def onReset(streamId: Int): Unit - protected def maybeFinishStream(endStream: Boolean): StreamState = - if (endStream) afterEndStreamReceived else this + override def incomingStreamPulled(): StreamState = { + buffer.dispatchNextChunk() + afterBufferEvent + } override def shutdown(): Unit = { buffer.shutdown() @@ -388,6 +394,8 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper } def incrementWindow(delta: Int): StreamState + + def afterBufferEvent: StreamState = if (buffer.isDone) afterEndStreamReceived else this } // on the incoming side there's (almost) no difference between Open and HalfClosedLocal @@ -473,31 +481,38 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper private var outstandingStreamWindow: Int = Http2Protocol.InitialWindowSize // adapt if we negotiate greater sizes by settings outlet.setHandler(this) - def onPull(): Unit = dispatchNextChunk() + def onPull(): Unit = incomingStreamPulled(streamId) override def onDownstreamFinish(): Unit = { debug(s"Incoming side of stream [$streamId]: cancelling because downstream finished") multiplexer.pushControlFrame(RstStreamFrame(streamId, ErrorCode.CANCEL)) // FIXME: go through state machine and don't manipulate vars directly here streamStates -= streamId + wasClosed = true + buffer = ByteString.empty + trailingHeaders = None } - def onDataFrame(data: DataFrame): Option[StreamState] = { - if (data.endStream) wasClosed = true + def isDone: Boolean = outlet.isClosed - outstandingStreamWindow -= data.sizeInWindow - if (outstandingStreamWindow < 0) { + def onDataFrame(data: DataFrame): Unit = + if (wasClosed) { shutdown() - multiplexer.pushControlFrame(RstStreamFrame(streamId, ErrorCode.FLOW_CONTROL_ERROR)) - // also close response delivery if that has already started - multiplexer.closeStream(streamId) - Some(Closed) + pushGOAWAY(ErrorCode.PROTOCOL_ERROR, s"Received unexpected DATA frame after stream was already (half-)closed") } else { - buffer ++= data.payload - debug(s"Received DATA ${data.sizeInWindow} for stream [$streamId], remaining window space now $outstandingStreamWindow, buffered: ${buffer.size}") - dispatchNextChunk() - None // don't change state + if (data.endStream) wasClosed = true + + outstandingStreamWindow -= data.sizeInWindow + if (outstandingStreamWindow < 0) { + shutdown() + multiplexer.pushControlFrame(RstStreamFrame(streamId, ErrorCode.FLOW_CONTROL_ERROR)) + // also close response delivery if that has already started + multiplexer.closeStream(streamId) + } else { + buffer ++= data.payload + debug(s"Received DATA ${data.sizeInWindow} for stream [$streamId], remaining window space now $outstandingStreamWindow, buffered: ${buffer.size}") + dispatchNextChunk() + } } - } def onTrailingHeaders(headers: ParsedHeadersFrame): Unit = { trailingHeaders = wrapTrailingHeaders(headers) if (headers.endStream) @@ -513,7 +528,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper wasClosed = true } - private def dispatchNextChunk(): Unit = { + def dispatchNextChunk(): Unit = { if (buffer.nonEmpty && outlet.isAvailable) { val dataSize = buffer.size min settings.requestEntityChunkSize outlet.push(wrapData(buffer.take(dataSize))) @@ -535,6 +550,7 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper case None => outlet.complete() } + } } @@ -558,7 +574,8 @@ private[http2] trait Http2StreamHandling { self: GraphStageLogic with LogHelper s"remaining connection window space now $outstandingConnectionLevelWindow, total buffered: $totalBufferedData") } - def shutdown(): Unit = outlet.fail(Http2StreamHandling.ConnectionWasAbortedException) + def shutdown(): Unit = + if (!outlet.isClosed) outlet.fail(Http2StreamHandling.ConnectionWasAbortedException) } trait OutStream { diff --git a/akka-http2-support/src/test/scala/akka/http/impl/engine/http2/Http2ServerSpec.scala b/akka-http2-support/src/test/scala/akka/http/impl/engine/http2/Http2ServerSpec.scala index e0c06b52a76..aa4cb567641 100644 --- a/akka-http2-support/src/test/scala/akka/http/impl/engine/http2/Http2ServerSpec.scala +++ b/akka-http2-support/src/test/scala/akka/http/impl/engine/http2/Http2ServerSpec.scala @@ -414,6 +414,30 @@ class Http2ServerSpec extends AkkaSpecWithMaterializer(""" sendFrame(DataFrame(TheStreamId, endStream = false, ByteString("0" * 512001))) // more than default `incoming-stream-level-buffer-size = 512kB` expectRST_STREAM(TheStreamId, ErrorCode.FLOW_CONTROL_ERROR) } + "not leak stream if request entity is not fully pulled when connection dies" inAssertAllStagesStopped new WaitingForRequestData { + sendDATA(TheStreamId, endStream = false, ByteString("0000")) + entityDataIn.expectUtf8EncodedString("0000") + pollForWindowUpdates(500.millis) + + sendDATA(TheStreamId, endStream = false, ByteString("1111")) + sendDATA(TheStreamId, endStream = true, ByteString.empty) + + // DATA is left in IncomingStreamBuffer because we never pulled + // test infra closes connection + } + "fail if DATA frame arrives after incoming stream has already been closed (before response was sent)" inAssertAllStagesStopped new WaitingForRequestData { + sendDATA(TheStreamId, endStream = false, ByteString("0000")) + entityDataIn.expectUtf8EncodedString("0000") + pollForWindowUpdates(500.millis) + + sendDATA(TheStreamId, endStream = false, ByteString("1111")) + sendDATA(TheStreamId, endStream = true, ByteString.empty) // close stream + + // now send more DATA: checks that we have moved into a state where DATA is not expected any more + sendDATA(TheStreamId, endStream = false, ByteString("more data")) + val (_, errorCode) = expectGOAWAY() + errorCode shouldEqual ErrorCode.PROTOCOL_ERROR + } "fail entity stream if advertised content-length doesn't match" in pending }