Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Http2 incoming-side stream state machine #1064

Merged
merged 3 commits into from
Apr 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ private[http2] trait GenericOutlet[T] {
def setHandler(handler: OutHandler): Unit
def push(elem: T): Unit
def complete(): Unit
def fail(cause: Throwable): Unit
def canBePushed: Boolean
}

Expand Down Expand Up @@ -47,6 +48,11 @@ private[http2] class BufferedOutlet[T](outlet: GenericOutlet[T]) extends OutHand
if (buffer.isEmpty) outlet.complete()
}

def fail(cause: Throwable): Unit = {
buffer.clear()
outlet.fail(cause)
}

def tryFlush(): Unit = {
if (outlet.canBePushed && !buffer.isEmpty)
doPush(buffer.pop())
Expand Down Expand Up @@ -97,13 +103,15 @@ private[http2] trait GenericOutletSupport { logic: GraphStageLogic ⇒
def setHandler(handler: OutHandler): Unit = subSourceOutlet.setHandler(handler)
def push(elem: T): Unit = subSourceOutlet.push(elem)
def complete(): Unit = subSourceOutlet.complete()
def fail(cause: Throwable): Unit = subSourceOutlet.fail(cause)
def canBePushed: Boolean = subSourceOutlet.isAvailable
}
implicit def fromOutlet[T](outlet: Outlet[T]): GenericOutlet[T] =
new GenericOutlet[T] {
def setHandler(handler: OutHandler): Unit = logic.setHandler(outlet, handler)
def push(elem: T): Unit = logic.emit(outlet, elem)
def complete(): Unit = logic.complete(outlet)
def fail(cause: Throwable): Unit = logic.fail(outlet, cause)
def canBePushed: Boolean = logic.isAvailable(outlet)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import akka.util.ByteString

import scala.collection.immutable

sealed trait FrameEvent
sealed trait StreamFrameEvent extends FrameEvent {
sealed trait FrameEvent { self: Product ⇒
def frameTypeName: String = productPrefix
}
sealed trait StreamFrameEvent extends FrameEvent { self: Product ⇒
def streamId: Int
}

Expand Down Expand Up @@ -79,4 +81,4 @@ final case class ParsedHeadersFrame(
endStream: Boolean,
keyValuePairs: Seq[(String, String)],
priorityInfo: Option[PriorityFrame]
) extends FrameEvent
) extends StreamFrameEvent
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
package akka.http.impl.engine.http2

import akka.NotUsed
import akka.http.impl.engine.http2.Http2Compliance.Http2ProtocolException
import akka.http.impl.engine.http2.Http2Protocol.ErrorCode
import akka.http.impl.engine.http2.Http2Protocol.ErrorCode.{ COMPRESSION_ERROR, FLOW_CONTROL_ERROR, FRAME_SIZE_ERROR }
import akka.http.scaladsl.model.http2.{ Http2Exception, PeerClosedStreamException }
import akka.stream.Attributes
import akka.stream.BidiShape
import akka.stream.Inlet
Expand All @@ -17,7 +17,7 @@ import akka.stream.scaladsl.Source
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, StageLogging }
import akka.util.ByteString

import scala.collection.mutable
import scala.collection.immutable
import scala.util.control.NonFatal

/**
Expand Down Expand Up @@ -68,9 +68,6 @@ import scala.util.control.NonFatal
* only available in this stage.
*/
class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent, FrameEvent, Http2SubStream]] {

import Http2ServerDemux._

val frameIn = Inlet[FrameEvent]("Demux.frameIn")
val frameOut = Outlet[FrameEvent]("Demux.frameOut")

Expand All @@ -81,14 +78,10 @@ class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent,
BidiShape(substreamIn, frameOut, frameIn, substreamOut)

def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with GenericOutletSupport with Http2MultiplexerSupport with StageLogging {
new GraphStageLogic(shape) with Http2MultiplexerSupport with Http2StreamHandling with GenericOutletSupport with StageLogging {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good, I like these kinds of splitting up with traits 👍

logic ⇒

final case class SubStream(
streamId: Int,
state: StreamState,
outlet: Option[BufferedOutlet[ByteString]]
)
override protected def logSource: Class[_] = classOf[Http2ServerDemux]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


val multiplexer = createMultiplexer(frameOut, StreamPrioritizer.first())

Expand All @@ -101,15 +94,15 @@ class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent,

// we should not handle streams later than the GOAWAY told us about with lastStreamId
private var closedAfter: Option[Int] = None
private var incomingStreams = mutable.Map.empty[Int, SubStream]
private var maxConcurrentStreams: Option[Int] = None

/**
* The "last peer-initiated stream that was or might be processed on the sending endpoint in this connection"
* @see http://httpwg.org/specs/rfc7540.html#rfc.section.6.8
*
* We currently don't support tracking that value accurately.
* TODO: track more accurately
*/
def lastStreamId: Int = {
incomingStreams.keys.toList.sortBy(-_).headOption.getOrElse(0) // FIXME should be optimised
}
def lastStreamId: Int = 1

def pushGOAWAY(errorCode: ErrorCode, debug: String): Unit = {
// http://httpwg.org/specs/rfc7540.html#rfc.section.6.8
Expand All @@ -125,62 +118,9 @@ class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent,
def onPush(): Unit = {
val in = grab(frameIn)
in match {
case WindowUpdateFrame(streamId, increment) ⇒ multiplexer.updateWindow(streamId, increment)

case priorityInfo: PriorityFrame ⇒
multiplexer.updatePriority(priorityInfo)

case e: StreamFrameEvent if !Http2Compliance.isClientInitiatedStreamId(e.streamId) ⇒
pushGOAWAY(ErrorCode.PROTOCOL_ERROR, "Not a valid client initiated stream id! Was: " + e.streamId)

case e: StreamFrameEvent if e.streamId > closedAfter.getOrElse(Int.MaxValue) ⇒
// streams that will have a greater stream id than the one we sent with GOAWAY will be ignored

case frame @ ParsedHeadersFrame(streamId, endStream, headers, prioInfo) if lastStreamId < streamId ⇒
// TODO: process priority information
val (data: Source[ByteString, NotUsed], outlet: Option[BufferedOutlet[ByteString]]) =
if (endStream) (Source.empty, None)
else {
val subSource = new SubSourceOutlet[ByteString](s"substream-out-$streamId")
(Source.fromGraph(subSource.source), Some(new BufferedOutlet[ByteString](subSource) {
override def onDownstreamFinish(): Unit =
// FIXME: when substream (= request entity) is cancelled, we need to RST_STREAM
// if the stream is finished and sent a RST_STREAM we can just remove the incoming stream from our map
incomingStreams.remove(streamId)
}))
}
val entry = SubStream(streamId, StreamState.Open /* FIXME stream state */ , outlet)
incomingStreams += streamId → entry // TODO optimise for lookup later on

dispatchSubstream(Http2SubStream(frame, data))

prioInfo.foreach(multiplexer.updatePriority)

case e: StreamFrameEvent if !incomingStreams.contains(e.streamId) ⇒
// if a stream is invalid we will GO_AWAY
pushGOAWAY(ErrorCode.PROTOCOL_ERROR, "Unknown stream id: " + e.streamId)

case h: ParsedHeadersFrame ⇒
if (h.endStream)
incomingStreams(h.streamId).outlet match {
case Some(outlet) ⇒ outlet.complete()
case None ⇒ failSubstream(h.streamId, ErrorCode.STREAM_CLOSED, "Got HEADERS frame on closed stream")
}
// else just ignore intermediate HEADERS frames

case DataFrame(streamId, endStream, payload) ⇒
// technically this case is the same as StreamFrameEvent, however we're handling it earlier in the match here for efficiency
// pushing http entity, TODO: handle flow control from here somehow?
incomingStreams(streamId).outlet match {
case Some(outlet) ⇒
outlet.push(payload)
if (endStream) outlet.complete()
case None ⇒ failSubstream(streamId, ErrorCode.STREAM_CLOSED, "Got DATA frame on closed stream")
}

case RstStreamFrame(streamId, errorCode) ⇒
// FIXME: also need to handle the other case when no response has been produced yet (inlet still None)
multiplexer.cancelSubStream(streamId)
case WindowUpdateFrame(streamId, increment) ⇒ multiplexer.updateWindow(streamId, increment) // handled specially
case p: PriorityFrame ⇒ multiplexer.updatePriority(p)
case s: StreamFrameEvent ⇒ handleStreamEvent(s)

case SettingsFrame(settings) ⇒
if (settings.nonEmpty) debug(s"Got ${settings.length} settings!")
Expand All @@ -200,7 +140,6 @@ class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent,
multiplexer.updateMaxFrameSize(value)
case Setting(Http2Protocol.SettingIdentifier.SETTINGS_MAX_CONCURRENT_STREAMS, value) ⇒
debug(s"Setting max concurrent streams to $value (not enforced)")
maxConcurrentStreams = Some(value)
case Setting(id, value) ⇒
debug(s"Ignoring setting $id -> $value (in Demux)")
}
Expand Down Expand Up @@ -231,8 +170,7 @@ class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent,
pushGOAWAY(e.errorCode, e.getMessage)

case e: Http2Compliance.Http2ProtocolStreamException ⇒
incomingStreams.remove(e.streamId)
multiplexer.pushControlFrame(RstStreamFrame(e.streamId, e.errorCode))
resetStream(e.streamId, e.errorCode)

case e: ParsingException ⇒
e.getCause match {
Expand Down Expand Up @@ -267,18 +205,3 @@ class Http2ServerDemux extends GraphStage[BidiShape[Http2SubStream, FrameEvent,
}

}

object Http2ServerDemux {
sealed trait StreamState
object StreamState {
case object Idle extends StreamState
case object Open extends StreamState
case object Closed extends StreamState
case object HalfClosedLocal extends StreamState
case object HalfClosedRemote extends StreamState

// for PUSH_PROMISE
// case object ReservedLocal extends StreamState
// case object ReservedRemote extends StreamState
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (C) 2009-2017 Lightbend Inc. <http://www.lightbend.com>
*/

package akka.http.impl.engine.http2

import akka.NotUsed
import akka.annotation.InternalApi
import akka.http.impl.engine.http2.Http2Protocol.ErrorCode
import akka.http.scaladsl.model.http2.PeerClosedStreamException
import akka.stream.scaladsl.Source
import akka.stream.stage.{ GraphStageLogic, StageLogging }
import akka.util.ByteString

import scala.collection.immutable

/** INTERNAL API */
@InternalApi
private[http2] trait Http2StreamHandling { self: GraphStageLogic with GenericOutletSupport with StageLogging ⇒
// required API from demux
def multiplexer: Http2Multiplexer
def pushGOAWAY(errorCode: ErrorCode, debug: String): Unit
def dispatchSubstream(sub: Http2SubStream): Unit

private var incomingStreams = new immutable.TreeMap[Int, IncomingStreamState]
private var largestIncomingStreamId = 0
private def streamFor(streamId: Int): IncomingStreamState =
incomingStreams.get(streamId) match {
case Some(state) ⇒ state
case None ⇒
if (streamId <= largestIncomingStreamId) Closed // closed streams are never put into the map
else {
largestIncomingStreamId = streamId
incomingStreams += streamId → Idle
Idle
}
}
def handleStreamEvent(e: StreamFrameEvent): Unit = {
val newState = streamFor(e.streamId).handle(e)
if (newState == Closed) incomingStreams -= e.streamId
else incomingStreams += e.streamId → newState
}
def resetStream(streamId: Int, errorCode: ErrorCode): Unit = {
// FIXME: put this stream into an extra state where we allow some frames still to be received
incomingStreams -= streamId
multiplexer.pushControlFrame(RstStreamFrame(streamId, errorCode))
}

sealed abstract class IncomingStreamState { _: Product ⇒
def handle(event: StreamFrameEvent): IncomingStreamState

def stateName: String = productPrefix
def receivedUnexpectedFrame(e: StreamFrameEvent): IncomingStreamState = {
pushGOAWAY(ErrorCode.PROTOCOL_ERROR, s"Received unexpected frame of type ${e.frameTypeName} for stream ${e.streamId} in state $stateName")
Closed
}
}
case object Idle extends IncomingStreamState {
def handle(event: StreamFrameEvent): IncomingStreamState = event match {
case frame @ ParsedHeadersFrame(streamId, endStream, headers, prioInfo) ⇒
val (data: Source[ByteString, NotUsed], nextState) =
if (endStream) (Source.empty, HalfClosedRemote)
else {
val subSource = new SubSourceOutlet[ByteString](s"substream-out-$streamId")
(Source.fromGraph(subSource.source), Open(new BufferedOutlet[ByteString](subSource) {
override def onDownstreamFinish(): Unit =
// FIXME: when substream (= request entity) is cancelled, we need to RST_STREAM
// if the stream is finished and sent a RST_STREAM we can just remove the incoming stream from our map
incomingStreams -= streamId
}))
}

// FIXME: after multiplexer PR is merged
// prioInfo.foreach(multiplexer.updatePriority)
dispatchSubstream(Http2SubStream(frame, data))
nextState

case x ⇒ receivedUnexpectedFrame(x)
}
}
sealed abstract class ReceivingData(outlet: BufferedOutlet[ByteString], afterEndStreamReceived: IncomingStreamState) extends IncomingStreamState { _: Product ⇒
def handle(event: StreamFrameEvent): IncomingStreamState = event match {
case d: DataFrame ⇒
outlet.push(d.payload)
maybeFinishStream(d.endStream)
case r: RstStreamFrame ⇒
outlet.fail(new PeerClosedStreamException(r.streamId, r.errorCode.toString))
multiplexer.cancelSubStream(r.streamId)
Closed

case h: ParsedHeadersFrame ⇒
// ignored
log.debug(s"Ignored intermediate HEADERS frame: $h")

maybeFinishStream(h.endStream)
}

protected def maybeFinishStream(endStream: Boolean): IncomingStreamState =
if (endStream) {
outlet.complete()
afterEndStreamReceived
} else this
}
// on the incoming side there's (almost) no difference between Open and HalfClosedLocal
case class Open(outlet: BufferedOutlet[ByteString]) extends ReceivingData(outlet, HalfClosedRemote)
case class HalfClosedLocal(outlet: BufferedOutlet[ByteString]) extends ReceivingData(outlet, Closed)
case object HalfClosedRemote extends IncomingStreamState {
def handle(event: StreamFrameEvent): IncomingStreamState = event match {
case r: RstStreamFrame ⇒
multiplexer.cancelSubStream(r.streamId)
Closed
case _ ⇒ receivedUnexpectedFrame(event)
}
}
case object Closed extends IncomingStreamState {
def handle(event: StreamFrameEvent): IncomingStreamState = receivedUnexpectedFrame(event)
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nicely modeled states 👍

// needed once PUSH_PROMISE support was added
//case object ReservedLocal extends IncomingStreamState
//case object ReservedRemote extends IncomingStreamState
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright (C) 2009-2017 Lightbend Inc. <http://www.lightbend.com>
*/

package akka.http.scaladsl.model.http2

import scala.util.control.NoStackTrace

/**
* Base class for HTTP2 exceptions.
*/
class Http2Exception(msg: String) extends RuntimeException(msg)

/**
* Exception that will be reported on the request entity stream when the peer closed the stream.
*/
class PeerClosedStreamException(val streamId: Int, val errorCode: String)
extends Http2Exception(s"Stream with ID [$streamId] was closed by peer with code $errorCode") with NoStackTrace