diff --git a/include/aws/http/private/connection_impl.h b/include/aws/http/private/connection_impl.h index f8d2da255..4d46a27a2 100644 --- a/include/aws/http/private/connection_impl.h +++ b/include/aws/http/private/connection_impl.h @@ -66,7 +66,6 @@ struct aws_http_connection { struct aws_channel_slot *channel_slot; struct aws_allocator *alloc; enum aws_http_version http_version; - size_t initial_window_size; aws_http_proxy_request_transform_fn *proxy_request_transform; void *user_data; diff --git a/include/aws/http/private/h2_stream.h b/include/aws/http/private/h2_stream.h index 2b0f4cba4..c5009915b 100644 --- a/include/aws/http/private/h2_stream.h +++ b/include/aws/http/private/h2_stream.h @@ -93,6 +93,7 @@ int aws_h2_stream_on_decoder_headers_end( bool malformed, enum aws_http_header_block block_type); +int aws_h2_stream_on_decoder_data(struct aws_h2_stream *stream, struct aws_byte_cursor data); int aws_h2_stream_on_decoder_end_stream(struct aws_h2_stream *stream); int aws_h2_stream_activate(struct aws_http_stream *stream); diff --git a/source/h1_connection.c b/source/h1_connection.c index 6e56c5fcb..8a6f93b2c 100644 --- a/source/h1_connection.c +++ b/source/h1_connection.c @@ -113,6 +113,8 @@ static const struct aws_h1_decoder_vtable s_h1_decoder_vtable = { struct h1_connection { struct aws_http_connection base; + size_t initial_window_size; + /* Single task used repeatedly for sending data from streams. */ struct aws_channel_task outgoing_stream_task; @@ -1257,7 +1259,6 @@ static struct h1_connection *s_connection_new( connection->base.channel_handler.alloc = alloc; connection->base.channel_handler.impl = connection; connection->base.http_version = AWS_HTTP_VERSION_1_1; - connection->base.initial_window_size = initial_window_size; connection->base.manual_window_management = manual_window_management; /* Init the next stream id (server must use even ids, client odd [RFC 7540 5.1.1])*/ @@ -1266,6 +1267,8 @@ static struct h1_connection *s_connection_new( /* 1 refcount for user */ aws_atomic_init_int(&connection->base.refcount, 1); + connection->initial_window_size = initial_window_size; + aws_h1_encoder_init(&connection->thread_data.encoder, alloc); aws_channel_task_init( @@ -1825,7 +1828,7 @@ static int s_handler_shutdown( static size_t s_handler_initial_window_size(struct aws_channel_handler *handler) { struct h1_connection *connection = handler->impl; - return connection->base.initial_window_size; + return connection->initial_window_size; } static size_t s_handler_message_overhead(struct aws_channel_handler *handler) { diff --git a/source/h2_connection.c b/source/h2_connection.c index 0912d710c..a855bdee3 100644 --- a/source/h2_connection.c +++ b/source/h2_connection.c @@ -74,6 +74,7 @@ static int s_decoder_on_headers_end( bool malformed, enum aws_http_header_block block_type, void *userdata); +static int s_decoder_on_data(uint32_t stream_id, struct aws_byte_cursor data, void *userdata); static int s_decoder_on_end_stream(uint32_t stream_id, void *userdata); static int s_decoder_on_ping(uint8_t opaque_data[AWS_H2_PING_DATA_SIZE], void *userdata); static int s_decoder_on_settings( @@ -107,6 +108,7 @@ static const struct aws_h2_decoder_vtable s_h2_decoder_vtable = { .on_headers_begin = s_decoder_on_headers_begin, .on_headers_i = s_decoder_on_headers_i, .on_headers_end = s_decoder_on_headers_end, + .on_data = s_decoder_on_data, .on_end_stream = s_decoder_on_end_stream, .on_ping = s_decoder_on_ping, .on_settings = s_decoder_on_settings, @@ -187,6 +189,7 @@ static struct aws_h2_connection *s_connection_new( bool server) { (void)server; + (void)initial_window_size; /* #TODO use this for our initial settings */ struct aws_h2_connection *connection = aws_mem_calloc(alloc, 1, sizeof(struct aws_h2_connection)); if (!connection) { @@ -199,7 +202,6 @@ static struct aws_h2_connection *s_connection_new( connection->base.channel_handler.alloc = alloc; connection->base.channel_handler.impl = connection; connection->base.http_version = AWS_HTTP_VERSION_2; - connection->base.initial_window_size = initial_window_size; /* Init the next stream id (server must use even ids, client odd [RFC 7540 5.1.1])*/ connection->base.next_stream_id = (server ? 2 : 1); connection->base.manual_window_management = manual_window_management; @@ -713,6 +715,26 @@ int s_decoder_on_headers_end( return AWS_OP_SUCCESS; } +int s_decoder_on_data(uint32_t stream_id, struct aws_byte_cursor data, void *userdata) { + struct aws_h2_connection *connection = userdata; + + /* #TODO Update connection's flow-control window */ + + /* Pass data to stream */ + struct aws_h2_stream *stream; + if (s_get_active_stream_for_incoming_frame(connection, stream_id, AWS_H2_FRAME_T_DATA, &stream)) { + return AWS_OP_ERR; + } + + if (stream) { + if (aws_h2_stream_on_decoder_data(stream, data)) { + return AWS_OP_ERR; + } + } + + return AWS_OP_SUCCESS; +} + int s_decoder_on_end_stream(uint32_t stream_id, void *userdata) { struct aws_h2_connection *connection = userdata; diff --git a/source/h2_stream.c b/source/h2_stream.c index ef8e8c4a4..6267757b8 100644 --- a/source/h2_stream.c +++ b/source/h2_stream.c @@ -430,6 +430,32 @@ int aws_h2_stream_on_decoder_headers_end( return AWS_OP_SUCCESS; } +int aws_h2_stream_on_decoder_data(struct aws_h2_stream *stream, struct aws_byte_cursor data) { + AWS_PRECONDITION_ON_CHANNEL_THREAD(stream); + + if (s_check_state_allows_frame_type(stream, AWS_H2_FRAME_T_DATA)) { + return s_send_rst_and_close_stream(stream, aws_last_error()); + } + + if (!stream->thread_data.received_main_headers) { + /* #TODO Not 100% sure whether this is Stream Error or Connection Error. */ + AWS_H2_STREAM_LOG(ERROR, stream, "Malformed message, received DATA before main HEADERS"); + return s_send_rst_and_close_stream(stream, AWS_ERROR_HTTP_PROTOCOL_ERROR); + } + + /* #TODO Update stream's flow-control window */ + + if (stream->base.on_incoming_body) { + if (stream->base.on_incoming_body(&stream->base, &data, stream->base.user_data)) { + AWS_H2_STREAM_LOGF( + ERROR, stream, "Incoming body callback raised error, %s", aws_error_name(aws_last_error())); + return AWS_OP_ERR; + } + } + + return AWS_OP_SUCCESS; +} + int aws_h2_stream_on_decoder_end_stream(struct aws_h2_stream *stream) { AWS_PRECONDITION_ON_CHANNEL_THREAD(stream); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 396de54d9..8ce11c59f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -339,6 +339,8 @@ add_test_case(h2_client_stream_ignores_some_frames_received_soon_after_closing) #TODO add_test_case(h2_client_stream_err_receive_info_headers_after_main) #TODO add_test_case(h2_client_stream_receive_trailing_headers) #TODO add_test_case(h2_client_stream_err_receive_trailing_before_main) +add_test_case(h2_client_stream_receive_data) +add_test_case(h2_client_stream_err_receive_data_before_headers) add_test_case(server_new_destroy) diff --git a/tests/h2_test_helper.c b/tests/h2_test_helper.c index 506980ac3..08824cfcd 100644 --- a/tests/h2_test_helper.c +++ b/tests/h2_test_helper.c @@ -16,6 +16,7 @@ #include "h2_test_helper.h" #include +#include #include /******************************************************************************* @@ -522,6 +523,35 @@ int h2_fake_peer_send_frame(struct h2_fake_peer *peer, struct aws_h2_frame *fram return AWS_OP_SUCCESS; } +int h2_fake_peer_send_data_frame( + struct h2_fake_peer *peer, + uint32_t stream_id, + struct aws_byte_cursor data, + bool end_stream) { + + struct aws_input_stream *body_stream = aws_input_stream_new_from_cursor(peer->alloc, &data); + ASSERT_NOT_NULL(body_stream); + + struct aws_io_message *msg = aws_channel_acquire_message_from_pool( + peer->testing_channel->channel, AWS_IO_MESSAGE_APPLICATION_DATA, g_aws_channel_max_fragment_size); + ASSERT_NOT_NULL(msg); + + bool body_complete; + ASSERT_SUCCESS(aws_h2_encode_data_frame( + &peer->encoder, stream_id, body_stream, end_stream, 0, &msg->message_data, &body_complete)); + + ASSERT_TRUE(body_complete); + ASSERT_TRUE(msg->message_data.len != 0); + + ASSERT_SUCCESS(testing_channel_push_read_message(peer->testing_channel, msg)); + aws_input_stream_destroy(body_stream); + return AWS_OP_SUCCESS; +} + +int h2_fake_peer_send_data_frame_str(struct h2_fake_peer *peer, uint32_t stream_id, const char *data, bool end_stream) { + return h2_fake_peer_send_data_frame(peer, stream_id, aws_byte_cursor_from_c_str(data), end_stream); +} + int h2_fake_peer_send_connection_preface(struct h2_fake_peer *peer, struct aws_h2_frame *settings) { if (!peer->is_server) { /* Client must first send magic string */ diff --git a/tests/h2_test_helper.h b/tests/h2_test_helper.h index 5804536ff..db0adcf09 100644 --- a/tests/h2_test_helper.h +++ b/tests/h2_test_helper.h @@ -147,6 +147,22 @@ int h2_fake_peer_decode_messages_from_testing_channel(struct h2_fake_peer *peer) */ int h2_fake_peer_send_frame(struct h2_fake_peer *peer, struct aws_h2_frame *frame); +/** + * Encode the entire byte cursor into a single DATA frame. + * Fails if the cursor is too large for this to work. + */ +int h2_fake_peer_send_data_frame( + struct h2_fake_peer *peer, + uint32_t stream_id, + struct aws_byte_cursor data, + bool end_stream); + +/** + * Encode the entire string into a single DATA frame. + * Fails if the string is too large for this to work. + */ +int h2_fake_peer_send_data_frame_str(struct h2_fake_peer *peer, uint32_t stream_id, const char *data, bool end_stream); + /** * Peer sends the connection preface with specified settings. * Takes ownership of frame and destroys after sending diff --git a/tests/test_h2_client.c b/tests/test_h2_client.c index c122c051f..1056924a6 100644 --- a/tests/test_h2_client.c +++ b/tests/test_h2_client.c @@ -451,3 +451,110 @@ TEST_CASE(h2_client_stream_ignores_some_frames_received_soon_after_closing) { client_stream_tester_clean_up(&stream_tester); return s_tester_clean_up(); } + +/* Test receiving a response with DATA frames */ +TEST_CASE(h2_client_stream_receive_data) { + ASSERT_SUCCESS(s_tester_init(allocator, ctx)); + + /* fake peer sends connection preface */ + ASSERT_SUCCESS(h2_fake_peer_send_connection_preface_default_settings(&s_tester.peer)); + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + + /* send request */ + struct aws_http_message *request = aws_http_message_new_request(allocator); + ASSERT_NOT_NULL(request); + + struct aws_http_header request_headers_src[] = { + DEFINE_HEADER(":method", "GET"), + DEFINE_HEADER(":scheme", "https"), + DEFINE_HEADER(":path", "/"), + }; + aws_http_message_add_header_array(request, request_headers_src, AWS_ARRAY_SIZE(request_headers_src)); + + struct client_stream_tester stream_tester; + ASSERT_SUCCESS(s_stream_tester_init(&stream_tester, request)); + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + + uint32_t stream_id = aws_http_stream_get_id(stream_tester.stream); + + /* fake peer sends response headers */ + struct aws_http_header response_headers_src[] = { + DEFINE_HEADER(":status", "200"), + }; + + struct aws_http_headers *response_headers = aws_http_headers_new(allocator); + aws_http_headers_add_array(response_headers, response_headers_src, AWS_ARRAY_SIZE(response_headers_src)); + + struct aws_h2_frame *response_frame = + aws_h2_frame_new_headers(allocator, stream_id, response_headers, false /*end_stream*/, 0, NULL); + ASSERT_SUCCESS(h2_fake_peer_send_frame(&s_tester.peer, response_frame)); + + /* fake peer sends response body */ + const char *body_src = "hello"; + ASSERT_SUCCESS(h2_fake_peer_send_data_frame_str(&s_tester.peer, stream_id, body_src, true /*end_stream*/)); + + /* validate that client received complete response */ + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + ASSERT_TRUE(stream_tester.complete); + ASSERT_INT_EQUALS(AWS_ERROR_SUCCESS, stream_tester.on_complete_error_code); + ASSERT_INT_EQUALS(200, stream_tester.response_status); + ASSERT_SUCCESS(s_compare_headers(response_headers, stream_tester.response_headers)); + ASSERT_TRUE(aws_byte_buf_eq_c_str(&stream_tester.response_body, body_src)); + + ASSERT_TRUE(aws_http_connection_is_open(s_tester.connection)); + + /* clean up */ + aws_http_headers_release(response_headers); + aws_http_message_release(request); + client_stream_tester_clean_up(&stream_tester); + return s_tester_clean_up(); +} + +/* A message is malformed if DATA is received before HEADERS */ +TEST_CASE(h2_client_stream_err_receive_data_before_headers) { + ASSERT_SUCCESS(s_tester_init(allocator, ctx)); + + /* fake peer sends connection preface */ + ASSERT_SUCCESS(h2_fake_peer_send_connection_preface_default_settings(&s_tester.peer)); + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + + /* send request */ + struct aws_http_message *request = aws_http_message_new_request(allocator); + ASSERT_NOT_NULL(request); + + struct aws_http_header request_headers_src[] = { + DEFINE_HEADER(":method", "GET"), + DEFINE_HEADER(":scheme", "https"), + DEFINE_HEADER(":path", "/"), + }; + aws_http_message_add_header_array(request, request_headers_src, AWS_ARRAY_SIZE(request_headers_src)); + + struct client_stream_tester stream_tester; + ASSERT_SUCCESS(s_stream_tester_init(&stream_tester, request)); + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + + uint32_t stream_id = aws_http_stream_get_id(stream_tester.stream); + + /* fake peer sends response body BEFORE any response headers */ + const char *body_src = "hello"; + ASSERT_SUCCESS(h2_fake_peer_send_data_frame_str(&s_tester.peer, stream_id, body_src, true /*end_stream*/)); + + /* validate that stream completed with error */ + testing_channel_drain_queued_tasks(&s_tester.testing_channel); + ASSERT_TRUE(stream_tester.complete); + ASSERT_INT_EQUALS(AWS_ERROR_HTTP_PROTOCOL_ERROR, stream_tester.on_complete_error_code); + + /* a stream error should not affect the connection */ + ASSERT_TRUE(aws_http_connection_is_open(s_tester.connection)); + + /* validate that stream sent RST_STREAM */ + ASSERT_SUCCESS(h2_fake_peer_decode_messages_from_testing_channel(&s_tester.peer)); + struct h2_decoded_frame *rst_stream_frame = h2_decode_tester_latest_frame(&s_tester.peer.decode); + ASSERT_INT_EQUALS(AWS_H2_FRAME_T_RST_STREAM, rst_stream_frame->type); + ASSERT_UINT_EQUALS(AWS_H2_ERR_PROTOCOL_ERROR, rst_stream_frame->error_code); + + /* clean up */ + aws_http_message_release(request); + client_stream_tester_clean_up(&stream_tester); + return s_tester_clean_up(); +} \ No newline at end of file