diff --git a/src/h2/connection.py b/src/h2/connection.py index 25251e20..15d7536f 100644 --- a/src/h2/connection.py +++ b/src/h2/connection.py @@ -1486,16 +1486,17 @@ def _receive_frame(self, frame): # I don't love using __class__ here, maybe reconsider it. frames, events = self._frame_dispatch_table[frame.__class__](frame) except StreamClosedError as e: - # If the stream was closed by RST_STREAM, we just send a RST_STREAM - # to the remote peer. Otherwise, this is a connection error, and so - # we will re-raise to trigger one. - if self._stream_is_closed_by_reset(e.stream_id): + if e._connection_error: + raise + else: + # A StreamClosedError is raised when a stream wants to send a + # RST_STREAM frame. Since the H2Stream is the authoritative source + # of its own state, we always respect its wishes here. + f = RstStreamFrame(e.stream_id) f.error_code = e.error_code self._prepare_for_sending([f]) events = e._events - else: - raise except StreamIDTooLowError as e: # The stream ID seems invalid. This may happen when the closed # stream has been cleaned up, or when the remote peer has opened a @@ -1506,10 +1507,18 @@ def _receive_frame(self, frame): # is either a stream error or a connection error. if self._stream_is_closed_by_reset(e.stream_id): # Closed by RST_STREAM is a stream error. - f = RstStreamFrame(e.stream_id) - f.error_code = ErrorCodes.STREAM_CLOSED - self._prepare_for_sending([f]) - events = [] + if self._stream_is_closed_by_peer_reset(e.stream_id): + self._closed_streams[e.stream_id] = StreamClosedBy.SEND_RST_STREAM + + f = RstStreamFrame(e.stream_id) + f.error_code = ErrorCodes.STREAM_CLOSED + self._prepare_for_sending([f]) + events = [] + else: + # Stream was closed by a local reset. A stream SHOULD NOT + # send additional RST_STREAM frames. Ignore. + events = [] + pass elif self._stream_is_closed_by_end(e.stream_id): # Closed by END_STREAM is a connection error. raise StreamClosedError(e.stream_id) @@ -1655,13 +1664,32 @@ def _handle_data_on_closed_stream(self, events, exc, frame): "auto-emitted a WINDOW_UPDATE by %d", frame.stream_id, conn_increment ) - f = RstStreamFrame(exc.stream_id) - f.error_code = exc.error_code - frames.append(f) - self.config.logger.debug( - "Stream %d already CLOSED or cleaned up - " - "auto-emitted a RST_FRAME" % frame.stream_id - ) + + send_rst_frame = False + + if frame.stream_id in self._closed_streams: + closed_by = self._closed_streams[frame.stream_id] + + if closed_by == StreamClosedBy.RECV_RST_STREAM: + self._closed_streams[frame.stream_id] = StreamClosedBy.SEND_RST_STREAM + send_rst_frame = True + elif closed_by == StreamClosedBy.SEND_RST_STREAM: + # Do not send additional RST_STREAM frames + pass + else: + # Protocol error + raise StreamClosedError(frame.stream_id) + else: + send_rst_frame = True + + if send_rst_frame: + f = RstStreamFrame(exc.stream_id) + f.error_code = exc.error_code + frames.append(f) + self.config.logger.debug( + "Stream %d already CLOSED or cleaned up - " + "auto-emitted a RST_FRAME" % frame.stream_id + ) return frames, events + exc._events def _receive_data_frame(self, frame): @@ -1677,6 +1705,8 @@ def _receive_data_frame(self, frame): flow_controlled_length ) + stream = None + try: stream = self._get_stream_by_id(frame.stream_id) frames, stream_events = stream.receive_data( @@ -1685,6 +1715,11 @@ def _receive_data_frame(self, frame): flow_controlled_length ) except StreamClosedError as e: + # If this exception originated from a yet-to-be clenaed up stream, + # check if it should be a connection error + if stream is not None and e._connection_error: + raise + # This stream is either marked as CLOSED or already gone from our # internal state. return self._handle_data_on_closed_stream(events, e, frame) @@ -1962,7 +1997,7 @@ def _stream_closed_by(self, stream_id): before opening this one. """ if stream_id in self.streams: - return self.streams[stream_id].closed_by + return self.streams[stream_id].closed_by # pragma: no cover if stream_id in self._closed_streams: return self._closed_streams[stream_id] return None @@ -1976,6 +2011,14 @@ def _stream_is_closed_by_reset(self, stream_id): StreamClosedBy.RECV_RST_STREAM, StreamClosedBy.SEND_RST_STREAM ) + def _stream_is_closed_by_peer_reset(self, stream_id): + """ + Returns ``True`` if the stream was closed by sending or receiving a + RST_STREAM frame. Returns ``False`` otherwise. + """ + return (self._stream_closed_by(stream_id) == + StreamClosedBy.RECV_RST_STREAM) + def _stream_is_closed_by_end(self, stream_id): """ Returns ``True`` if the stream was closed by sending or receiving an diff --git a/src/h2/exceptions.py b/src/h2/exceptions.py index e22bebc0..0576be0b 100644 --- a/src/h2/exceptions.py +++ b/src/h2/exceptions.py @@ -104,7 +104,7 @@ class StreamClosedError(NoSuchStreamError): that the stream has since been closed, and that all state relating to that stream has been removed. """ - def __init__(self, stream_id): + def __init__(self, stream_id, connection_error=True): #: The stream ID corresponds to the nonexistent stream. self.stream_id = stream_id @@ -115,6 +115,12 @@ def __init__(self, stream_id): # external users that may receive a StreamClosedError. self._events = [] + # If this is a connection error or a stream error. This exception + # is used to send a `RST_STREAM` frame on stream errors. If + # connection_error is false, H2Connection will suppress this + # exception after sending the reset frame. + self._connection_error = connection_error + class InvalidSettingsValueError(ProtocolError, ValueError): """ diff --git a/src/h2/stream.py b/src/h2/stream.py index 817636f8..e9af0641 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -316,7 +316,7 @@ def reset_stream_on_error(self, previous_state): """ self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM - error = StreamClosedError(self.stream_id) + error = StreamClosedError(self.stream_id, connection_error=False) event = StreamReset() event.stream_id = self.stream_id @@ -334,8 +334,31 @@ def recv_on_closed_stream(self, previous_state): a stream error or connection error with type STREAM_CLOSED, depending on the specific frame. The error handling is done at a higher level: this just raises the appropriate error. - """ - raise StreamClosedError(self.stream_id) + + RFC: + Normally, an endpoint SHOULD NOT send more than one RST_STREAM + frame for any stream. However, an endpoint MAY send additional + RST_STREAM frames if it receives frames on a closed stream after + more than a round-trip time. This behavior is permitted to deal + with misbehaving implementations. + + Implementation: + Raising StreamClosedError causes the RST_STREAM frame to be sent. + If the stream closed_by value is SEND_RST_STREAM, ignore this + instead of raising, such that only one RST_STREAM frame is sent. + There is currently now latency tracking, and as such measuring + round-trip time for allowed additional RST_STREAM frames which + MAY be sent cannot be implemented. + """ + + if self.stream_closed_by == StreamClosedBy.RECV_RST_STREAM: + self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM + raise StreamClosedError(self.stream_id, connection_error=False) + elif self.stream_closed_by in (StreamClosedBy.RECV_END_STREAM, + StreamClosedBy.SEND_END_STREAM): + raise StreamClosedError(self.stream_id) + + return [] def send_on_closed_stream(self, previous_state): """ @@ -1040,23 +1063,24 @@ def receive_headers(self, headers, end_stream, header_encoding): events = self.state_machine.process_input(input_) - if end_stream: - es_events = self.state_machine.process_input( - StreamInputs.RECV_END_STREAM - ) - events[0].stream_ended = es_events[0] - events += es_events + if len(events) > 0: + if end_stream: + es_events = self.state_machine.process_input( + StreamInputs.RECV_END_STREAM + ) + events[0].stream_ended = es_events[0] + events += es_events - self._initialize_content_length(headers) + self._initialize_content_length(headers) - if isinstance(events[0], TrailersReceived): - if not end_stream: - raise ProtocolError("Trailers must have END_STREAM set") + if isinstance(events[0], TrailersReceived): + if not end_stream: + raise ProtocolError("Trailers must have END_STREAM set") - hdr_validation_flags = self._build_hdr_validation_flags(events) - events[0].headers = self._process_received_headers( - headers, hdr_validation_flags, header_encoding - ) + hdr_validation_flags = self._build_hdr_validation_flags(events) + events[0].headers = self._process_received_headers( + headers, hdr_validation_flags, header_encoding + ) return [], events def receive_data(self, data, end_stream, flow_control_len): @@ -1068,18 +1092,20 @@ def receive_data(self, data, end_stream, flow_control_len): "set to %d", self, end_stream, flow_control_len ) events = self.state_machine.process_input(StreamInputs.RECV_DATA) - self._inbound_window_manager.window_consumed(flow_control_len) - self._track_content_length(len(data), end_stream) - if end_stream: - es_events = self.state_machine.process_input( - StreamInputs.RECV_END_STREAM - ) - events[0].stream_ended = es_events[0] - events.extend(es_events) + if len(events) > 0: + self._inbound_window_manager.window_consumed(flow_control_len) + self._track_content_length(len(data), end_stream) + + if end_stream: + es_events = self.state_machine.process_input( + StreamInputs.RECV_END_STREAM + ) + events[0].stream_ended = es_events[0] + events.extend(es_events) - events[0].data = data - events[0].flow_controlled_length = flow_control_len + events[0].data = data + events[0].flow_controlled_length = flow_control_len return [], events def receive_window_update(self, increment): diff --git a/test/test_closed_streams.py b/test/test_closed_streams.py index ef88d8e4..c009a8d9 100644 --- a/test/test_closed_streams.py +++ b/test/test_closed_streams.py @@ -207,6 +207,8 @@ class TestStreamsClosedByEndStream(object): self.example_request_headers, flags=['END_STREAM']), lambda self, ff: ff.build_headers_frame( self.example_request_headers), + lambda self, ff: ff.build_data_frame( + data=b'some data') ] ) @pytest.mark.parametrize("clear_streams", [True, False]) @@ -352,9 +354,13 @@ class TestStreamsClosedByRstStream(object): self.example_request_headers, flags=['END_STREAM']), ] ) + @pytest.mark.parametrize( + "clear_streams_before_send", [True, False] + ) def test_resets_further_frames_after_recv_reset(self, frame_factory, - frame): + frame, + clear_streams_before_send): """ A stream that is closed by receive RST_STREAM can receive further frames: it simply sends RST_STREAM for it, and additionally @@ -381,6 +387,9 @@ def test_resets_further_frames_after_recv_reset(self, c.receive_data(rst_frame.serialize()) c.clear_outbound_data_buffer() + if clear_streams_before_send: + c.open_outbound_streams + f = frame(self, frame_factory) events = c.receive_data(f.serialize()) @@ -390,20 +399,29 @@ def test_resets_further_frames_after_recv_reset(self, assert not events assert c.data_to_send() == rst_frame.serialize() + # "An endpoint MUST ignore frames that it receives on closed streams + # after it has sent a RST_STREAM frame." + # The initial RST_STREAM was seen in the previous assert. Additional + # frames should be ignored. events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == rst_frame.serialize() * 3 + assert c.data_to_send() == b"" # Iterate over the streams to make sure it's gone, then confirm the # behaviour is unchanged. c.open_outbound_streams + # Additional frames should continue to be ignored events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == rst_frame.serialize() * 3 + assert c.data_to_send() == b"" + @pytest.mark.parametrize( + "clear_streams_before_send", [True, False] + ) def test_resets_further_data_frames_after_recv_reset(self, - frame_factory): + frame_factory, + clear_streams_before_send): """ A stream that is closed by receive RST_STREAM can receive further DATA frames: it simply sends WINDOW_UPDATE for the connection flow @@ -430,6 +448,9 @@ def test_resets_further_data_frames_after_recv_reset(self, c.receive_data(rst_frame.serialize()) c.clear_outbound_data_buffer() + if clear_streams_before_send: + c.open_outbound_streams + f = frame_factory.build_data_frame( data=b'some data' ) @@ -445,7 +466,7 @@ def test_resets_further_data_frames_after_recv_reset(self, events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == expected * 3 + assert c.data_to_send() == b"" # Iterate over the streams to make sure it's gone, then confirm the # behaviour is unchanged. @@ -453,7 +474,7 @@ def test_resets_further_data_frames_after_recv_reset(self, events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == expected * 3 + assert c.data_to_send() == b"" @pytest.mark.parametrize( "frame", @@ -486,25 +507,23 @@ def test_resets_further_frames_after_send_reset(self, end_stream=False ) + # Send initial RST_STREAM c.reset_stream(1, h2.errors.ErrorCodes.INTERNAL_ERROR) - - rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED - ) c.clear_outbound_data_buffer() f = frame(self, frame_factory) events = c.receive_data(f.serialize()) - rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED - ) + # "An endpoint MUST ignore frames that it receives on closed streams + # after it has sent a RST_STREAM frame." + # The initial RST_STREAM was sent in the test setup. Additional frames + # should be ignored. assert not events - assert c.data_to_send() == rst_frame.serialize() + assert c.data_to_send() == b"" events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == rst_frame.serialize() * 3 + assert c.data_to_send() == b"" # Iterate over the streams to make sure it's gone, then confirm the # behaviour is unchanged. @@ -512,7 +531,7 @@ def test_resets_further_frames_after_send_reset(self, events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == rst_frame.serialize() * 3 + assert c.data_to_send() == b"" def test_resets_further_data_frames_after_send_reset(self, frame_factory): @@ -535,6 +554,7 @@ def test_resets_further_data_frames_after_send_reset(self, end_stream=False ) + # Send initial RST_STREAM c.reset_stream(1, h2.errors.ErrorCodes.INTERNAL_ERROR) c.clear_outbound_data_buffer() @@ -544,15 +564,11 @@ def test_resets_further_data_frames_after_send_reset(self, ) events = c.receive_data(f.serialize()) assert not events - expected = frame_factory.build_rst_stream_frame( - stream_id=1, - error_code=h2.errors.ErrorCodes.STREAM_CLOSED, - ).serialize() - assert c.data_to_send() == expected + assert c.data_to_send() == b"" events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == expected * 3 + assert c.data_to_send() == b"" # Iterate over the streams to make sure it's gone, then confirm the # behaviour is unchanged. @@ -560,4 +576,4 @@ def test_resets_further_data_frames_after_send_reset(self, events = c.receive_data(f.serialize() * 3) assert not events - assert c.data_to_send() == expected * 3 + assert c.data_to_send() == b"" diff --git a/test/test_flow_control_window.py b/test/test_flow_control_window.py index 223cf39f..e3d99f6c 100644 --- a/test/test_flow_control_window.py +++ b/test/test_flow_control_window.py @@ -652,15 +652,9 @@ def test_send_update_on_closed_streams(self, frame_factory): events = c.receive_data(f.serialize()*3) assert not events - expected = frame_factory.build_rst_stream_frame( - stream_id=1, - error_code=h2.errors.ErrorCodes.STREAM_CLOSED, - ).serialize() * 2 + frame_factory.build_window_update_frame( + expected = frame_factory.build_window_update_frame( stream_id=0, increment=40500, - ).serialize() + frame_factory.build_rst_stream_frame( - stream_id=1, - error_code=h2.errors.ErrorCodes.STREAM_CLOSED, ).serialize() assert c.data_to_send() == expected @@ -668,11 +662,8 @@ def test_send_update_on_closed_streams(self, frame_factory): events = c.receive_data(f.serialize()) assert not events - expected = frame_factory.build_rst_stream_frame( - stream_id=1, - error_code=h2.errors.ErrorCodes.STREAM_CLOSED, - ).serialize() - assert c.data_to_send() == expected + # RST_STREAM has already been sent. Expect no data here. + assert c.data_to_send() == b"" class TestAutomaticFlowControl(object): diff --git a/test/test_invalid_frame_sequences.py b/test/test_invalid_frame_sequences.py index 05832cbb..379c10fa 100644 --- a/test/test_invalid_frame_sequences.py +++ b/test/test_invalid_frame_sequences.py @@ -153,6 +153,7 @@ def test_reject_data_on_closed_streams(self, frame_factory): bad_frame = frame_factory.build_data_frame( data=b'some data' ) + c.receive_data(bad_frame.serialize()) expected = frame_factory.build_rst_stream_frame( @@ -349,15 +350,25 @@ def test_one_one_stream_reset(self, frame_factory): bad_frame = frame_factory.build_data_frame( data=b'some data' ) - # Receive 5 frames. - events = c.receive_data(bad_frame.serialize() * 5) + + rst_frame = frame_factory.build_rst_stream_frame( + stream_id=1, + error_code=h2.errors.ErrorCodes.STREAM_CLOSED, + ) expected = frame_factory.build_rst_stream_frame( stream_id=1, error_code=h2.errors.ErrorCodes.STREAM_CLOSED, ).serialize() - assert c.data_to_send() == expected * 5 + # Receive 5 frames. + events = c.receive_data(bad_frame.serialize()) + assert len(events) == 1 + assert c.data_to_send() == expected + + events += c.receive_data(rst_frame.serialize() * 4) + + assert c.data_to_send() == b"" assert len(events) == 1 event = events[0] assert isinstance(event, h2.events.StreamReset) diff --git a/test/test_stream_reset.py b/test/test_stream_reset.py index 77844551..faf8e7dd 100644 --- a/test/test_stream_reset.py +++ b/test/test_stream_reset.py @@ -39,6 +39,7 @@ def test_reset_stream_keeps_header_state_correct(self, frame_factory): c = h2.connection.H2Connection() c.initiate_connection() c.send_headers(stream_id=1, headers=self.example_request_headers) + # Send initial RST_STREAM c.reset_stream(stream_id=1) c.send_headers(stream_id=3, headers=self.example_request_headers) c.clear_outbound_data_buffer() @@ -46,12 +47,11 @@ def test_reset_stream_keeps_header_state_correct(self, frame_factory): f = frame_factory.build_headers_frame( headers=self.example_response_headers, stream_id=1 ) - rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED - ) + + # RST_STREAM already sent. Expect no data here. events = c.receive_data(f.serialize()) assert not events - assert c.data_to_send() == rst_frame.serialize() + assert c.data_to_send() == b"" # This works because the header state should be intact from the headers # frame that was send on stream 1, so they should decode cleanly. @@ -85,6 +85,7 @@ def test_reset_stream_keeps_flow_control_correct(self, headers=self.example_response_headers, stream_id=close_id ) c.receive_data(f.serialize()) + # Send initial reset c.reset_stream(stream_id=close_id) c.clear_outbound_data_buffer() @@ -94,11 +95,8 @@ def test_reset_stream_keeps_flow_control_correct(self, ) c.receive_data(f.serialize()) - expected = frame_factory.build_rst_stream_frame( - stream_id=close_id, - error_code=h2.errors.ErrorCodes.STREAM_CLOSED, - ).serialize() - assert c.data_to_send() == expected + # RST_STREAM already sent. Expect no data here. + assert c.data_to_send() == b"" new_window = c.remote_flow_control_window(stream_id=other_id) assert initial_window - len(b'some data') == new_window