diff --git a/include/aws/http/http.h b/include/aws/http/http.h index 05f2024e5..6af06ed9e 100644 --- a/include/aws/http/http.h +++ b/include/aws/http/http.h @@ -30,6 +30,7 @@ enum aws_http_errors { AWS_ERROR_HTTP_OUTGOING_STREAM_LENGTH_INCORRECT, AWS_ERROR_HTTP_CALLBACK_FAILURE, AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT, + AWS_ERROR_HTTP_WEBSOCKET_IS_MIDCHANNEL_HANDLER, AWS_ERROR_HTTP_END_RANGE = 0x0C00, }; diff --git a/include/aws/http/private/websocket_impl.h b/include/aws/http/private/websocket_impl.h index 0ce6bf381..ce2e54d7e 100644 --- a/include/aws/http/private/websocket_impl.h +++ b/include/aws/http/private/websocket_impl.h @@ -61,7 +61,7 @@ struct aws_websocket_frame { struct aws_websocket_handler_options { struct aws_allocator *allocator; - struct aws_channel_slot *channel_slot; + struct aws_channel *channel; size_t initial_window_size; void *user_data; @@ -88,12 +88,10 @@ AWS_HTTP_API uint64_t aws_websocket_frame_encoded_size(const struct aws_websocket_frame *frame); /** - * Returns channel-handler for websocket. - * handler->impl is the aws_websocket* - * To destroy a handler that was never put into a channel, invoke: `handler->vtable.destroy(handler)` + * Create a websocket channel-handler and insert it into the channel. */ AWS_HTTP_API -struct aws_channel_handler *aws_websocket_handler_new(const struct aws_websocket_handler_options *options); +struct aws_websocket *aws_websocket_handler_new(const struct aws_websocket_handler_options *options); AWS_EXTERN_C_END #endif /* AWS_HTTP_WEBSOCKET_IMPL_H */ diff --git a/include/aws/http/websocket.h b/include/aws/http/websocket.h index e242c506c..af82b0d11 100644 --- a/include/aws/http/websocket.h +++ b/include/aws/http/websocket.h @@ -17,10 +17,8 @@ #include -struct aws_channel_handler; -struct aws_http_header; - /* TODO: Document lifetime stuff */ +/* TODO: Should shutdown callback fire when it's a midchannel handler? */ /* TODO: Document CLOSE frame behavior (when auto-sent during close, when auto-closed) */ /* TODO: Document auto-pong behavior */ @@ -236,25 +234,15 @@ bool aws_websocket_is_data_frame(uint8_t opcode); AWS_HTTP_API int aws_websocket_client_connect(const struct aws_websocket_client_connection_options *options); -/* TODO: Require all users to manually grab a hold? Http doesn't work like that... */ -/* TODO: should the last release trigger a shutdown automatically? http does that, channel doesn't. */ - -/** - * Ensure that the websocket cannot be destroyed until aws_websocket_release_hold() is called. - * The websocket might still shutdown/close, but the public API will not crash when this websocket pointer is used. - * If acquire_hold() is never called, the websocket is destroyed when its channel its channel is destroyed. - * This function may be called from any thread. - */ -AWS_HTTP_API -void aws_websocket_acquire_hold(struct aws_websocket *websocket); - /** - * See aws_websocket_acquire_hold(). - * The websocket will shut itself down when the last hold is released. + * Users must release the websocket when they are done with it (unless it's been converted to a mid-channel handler). + * The websocket's memory cannot be reclaimed until this is done. + * If the websocket connection was not already shutting down, it will be shut down. + * Callbacks may continue firing after this is called, with "shutdown" being the final callback. * This function may be called from any thread. */ AWS_HTTP_API -void aws_websocket_release_hold(struct aws_websocket *websocket); +void aws_websocket_release(struct aws_websocket *websocket); /** * Close the websocket connection. @@ -285,11 +273,29 @@ int aws_websocket_send_frame(struct aws_websocket *websocket, const struct aws_w AWS_HTTP_API void aws_websocket_increment_read_window(struct aws_websocket *websocket, size_t size); -/* WIP */ +/** + * Convert the websocket into a mid-channel handler. + * The websocket will stop being usable via its public API and become just another handler in the channel. + * The caller will likely install a channel handler to the right. + * This must not be called in the middle of an incoming frame (between "frame begin" and "frame complete" callbacks). + * This MUST be called from the websocket's thread. + * + * If successful, the channel that the websocket belongs to is returned and: + * - The websocket will ignore all further calls to aws_websocket_X() functions. + * - The websocket will no longer invoke any "incoming frame" callbacks. + * - There is no need to invoke aws_websocket_release(), the websocket will be destroyed when the channel is destroyed. + * The caller should acquire a hold on the channel if they need to prevent its destruction. + * - aws_io_messages written by a downstream handler will be wrapped in binary data frames and sent upstream. + * The data may be split/combined as it is sent along. + * - aws_io_messages read from upstream handlers will be scanned for binary data frames. + * The payloads of these frames will be sent downstream. + * The payloads may be split/combined as they are sent along. + * - An incoming close frame will automatically result in channel-shutdown. + * + * If unsuccessful, NULL is returned and the websocket is unchanged. + */ AWS_HTTP_API -int aws_websocket_install_channel_handler_to_right( - struct aws_websocket *websocket, - struct aws_channel_handler *right_handler); +struct aws_channel *aws_websocket_convert_to_midchannel_handler(struct aws_websocket *websocket); AWS_EXTERN_C_END diff --git a/source/http.c b/source/http.c index feab11365..29c7153c3 100644 --- a/source/http.c +++ b/source/http.c @@ -58,6 +58,9 @@ static struct aws_error_info s_errors[] = { AWS_DEFINE_ERROR_INFO_HTTP( AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT, "Websocket has sent CLOSE frame, no more data will be sent."), + AWS_DEFINE_ERROR_INFO_HTTP( + AWS_ERROR_HTTP_WEBSOCKET_IS_MIDCHANNEL_HANDLER, + "Operation cannot be performed because websocket has been converted to a midchannel handler."), AWS_DEFINE_ERROR_INFO_HTTP( AWS_ERROR_HTTP_END_RANGE, "Not a real error and should never be seen."), diff --git a/source/websocket.c b/source/websocket.c index 5d1c9b04c..23a478095 100644 --- a/source/websocket.c +++ b/source/websocket.c @@ -66,7 +66,7 @@ struct aws_websocket { struct aws_channel_task move_synced_data_to_thread_task; struct aws_channel_task shutdown_channel_task; struct aws_channel_task increment_read_window_task; - struct aws_atomic_var refcount; + struct aws_channel_task finish_midchannel_conversion_task; bool is_server; struct { @@ -101,6 +101,10 @@ struct aws_websocket { /* Wait until each aws_io_message is completely written to * the socket before sending the next aws_io_message */ bool is_waiting_for_write_completion; + + /* True if this websocket is being used as a dumb mid-channel handler. + * The websocket will no longer respond to its public API or invoke callbacks. */ + bool is_midchannel_handler; } thread_data; struct { @@ -119,6 +123,12 @@ struct aws_websocket { bool is_shutdown_channel_task_scheduled; bool is_move_synced_data_to_thread_task_scheduled; + + /* Mirrors variable from thread_data */ + bool is_midchannel_handler; + + /* Whether aws_websocket_release() has been called */ + bool is_released; } synced_data; }; @@ -161,8 +171,15 @@ static void s_io_message_write_completed( struct aws_io_message *message, int err_code, void *user_data); +static int s_send_frame( + struct aws_websocket *websocket, + const struct aws_websocket_send_frame_options *options, + bool from_public_api); +static bool s_midchannel_send_payload(struct aws_websocket *websocket, struct aws_byte_buf *out_buf, void *user_data); +static void s_midchannel_send_complete(struct aws_websocket *websocket, int error_code, void *user_data); static void s_move_synced_data_to_thread_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); static void s_increment_read_window_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); +static void s_finish_midchannel_conversion_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, enum aws_task_status status); static void s_schedule_channel_shutdown(struct aws_websocket *websocket, int error_code); static void s_shutdown_due_to_write_err(struct aws_websocket *websocket, int error_code); @@ -216,12 +233,26 @@ void s_unlock_synced_data(struct aws_websocket *websocket) { (void)err; } -struct aws_channel_handler *aws_websocket_handler_new(const struct aws_websocket_handler_options *options) { +struct aws_websocket *aws_websocket_handler_new(const struct aws_websocket_handler_options *options) { /* TODO: validate options */ - struct aws_websocket *websocket = aws_mem_acquire(options->allocator, sizeof(struct aws_websocket)); + struct aws_channel_slot *slot = NULL; + struct aws_websocket *websocket = NULL; + int err; + + slot = aws_channel_slot_new(options->channel); + if (!slot) { + goto error; + } + + err = aws_channel_slot_insert_end(options->channel, slot); + if (err) { + goto error; + } + + websocket = aws_mem_acquire(options->allocator, sizeof(struct aws_websocket)); if (!websocket) { - return NULL; + goto error; } AWS_ZERO_STRUCT(*websocket); @@ -230,7 +261,7 @@ struct aws_channel_handler *aws_websocket_handler_new(const struct aws_websocket websocket->channel_handler.alloc = options->allocator; websocket->channel_handler.impl = websocket; - websocket->channel_slot = options->channel_slot; + websocket->channel_slot = slot; websocket->initial_window_size = options->initial_window_size; @@ -240,13 +271,13 @@ struct aws_channel_handler *aws_websocket_handler_new(const struct aws_websocket websocket->on_incoming_frame_payload = options->on_incoming_frame_payload; websocket->on_incoming_frame_complete = options->on_incoming_frame_complete; - aws_atomic_init_int(&websocket->refcount, 0); - websocket->is_server = options->is_server; aws_channel_task_init(&websocket->move_synced_data_to_thread_task, s_move_synced_data_to_thread_task, websocket); aws_channel_task_init(&websocket->shutdown_channel_task, s_shutdown_channel_task, websocket); aws_channel_task_init(&websocket->increment_read_window_task, s_increment_read_window_task, websocket); + aws_channel_task_init( + &websocket->finish_midchannel_conversion_task, s_finish_midchannel_conversion_task, websocket); aws_linked_list_init(&websocket->thread_data.outgoing_frame_list); @@ -254,7 +285,9 @@ struct aws_channel_handler *aws_websocket_handler_new(const struct aws_websocket aws_websocket_decoder_init(&websocket->thread_data.decoder, s_decoder_on_frame, s_decoder_on_payload, websocket); - int err = aws_mutex_init(&websocket->synced_data.lock); + aws_linked_list_init(&websocket->synced_data.outgoing_frame_list); + + err = aws_mutex_init(&websocket->synced_data.lock); if (err) { AWS_LOGF_ERROR( AWS_LS_HTTP_WEBSOCKET, @@ -265,12 +298,23 @@ struct aws_channel_handler *aws_websocket_handler_new(const struct aws_websocket goto error; } - aws_linked_list_init(&websocket->synced_data.outgoing_frame_list); + err = aws_channel_slot_set_handler(slot, &websocket->channel_handler); + if (err) { + goto error; + } + + /* Ensure websocket (and the rest of the channel) can't be destroyed until aws_websocket_release() is called */ + aws_channel_acquire_hold(options->channel); - return &websocket->channel_handler; + return websocket; error: - websocket->channel_handler.vtable->destroy(&websocket->channel_handler); + if (slot) { + if (websocket && !slot->handler) { + websocket->channel_handler.vtable->destroy(&websocket->channel_handler); + } + aws_channel_slot_remove(slot); + } return NULL; } @@ -285,46 +329,118 @@ static void s_handler_destroy(struct aws_channel_handler *handler) { aws_mem_release(websocket->alloc, websocket); } -void aws_websocket_acquire_hold(struct aws_websocket *websocket) { - size_t prev_refcount = aws_atomic_fetch_add(&websocket->refcount, 1); - AWS_LOGF_TRACE( - AWS_LS_HTTP_WEBSOCKET, - "id=%p: Websocket refcount increased, currently %zu.", - (void *)websocket, - prev_refcount + 1); +void aws_websocket_release(struct aws_websocket *websocket) { + AWS_ASSERT(websocket); + AWS_ASSERT(websocket->channel_slot); + + enum { OK, IS_MIDCHANNEL_HANDLER, ALREADY_RELEASED } outcome; - if (prev_refcount == 0) { - /* Prevent channel from destroying the websocket unexpectedly */ - aws_channel_acquire_hold(websocket->channel_slot->channel); + /* BEGIN CRITICAL SECTION */ + s_lock_synced_data(websocket); + if (websocket->synced_data.is_released) { + outcome = ALREADY_RELEASED; + } else if (websocket->synced_data.is_midchannel_handler) { + outcome = IS_MIDCHANNEL_HANDLER; + } else { + websocket->synced_data.is_released = true; + outcome = OK; } -} + s_unlock_synced_data(websocket); + /* END CRITICAL SECTION */ -void aws_websocket_release_hold(struct aws_websocket *websocket) { - AWS_ASSERT(websocket); - AWS_ASSERT(websocket->channel_slot); + if (outcome == IS_MIDCHANNEL_HANDLER) { + AWS_LOGF_TRACE( + AWS_LS_HTTP_WEBSOCKET, + "id=%p: Ignoring release call, websocket has converted to mid-channel handler" + " and will be destroyed when its channel is destroyed.", + (void *)websocket); + return; + } - size_t prev_refcount = aws_atomic_fetch_sub(&websocket->refcount, 1); - if (prev_refcount == 1) { + if (outcome == ALREADY_RELEASED) { AWS_LOGF_TRACE( + AWS_LS_HTTP_WEBSOCKET, "id=%p: Ignoring multiple calls to websocket release.", (void *)websocket); + return; + } + + AWS_LOGF_TRACE(AWS_LS_HTTP_WEBSOCKET, "id=%p: Websocket released, shut down if necessary.", (void *)websocket); + + /* Channel might already be shut down, but make sure */ + s_schedule_channel_shutdown(websocket, AWS_ERROR_SUCCESS); + + /* Channel won't destroy its slots/handlers until its refcount reaches 0 */ + aws_channel_release_hold(websocket->channel_slot->channel); +} + +struct aws_channel *aws_websocket_convert_to_midchannel_handler(struct aws_websocket *websocket) { + if (!aws_channel_thread_is_callers_thread(websocket->channel_slot->channel)) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET, "id=%p: Cannot convert to midchannel handler on this thread.", (void *)websocket); + aws_raise_error(AWS_ERROR_IO_EVENT_LOOP_THREAD_ONLY); + return NULL; + } + + if (websocket->thread_data.is_midchannel_handler) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET, "id=%p: Websocket has already converted to midchannel handler.", (void *)websocket); + aws_raise_error(AWS_ERROR_HTTP_WEBSOCKET_IS_MIDCHANNEL_HANDLER); + return NULL; + } + + if (websocket->thread_data.is_reading_stopped || websocket->thread_data.is_writing_stopped) { + AWS_LOGF_ERROR( AWS_LS_HTTP_WEBSOCKET, - "id=%p: Final websocket refcount released, shut down if necessary.", + "id=%p: Cannot convert websocket to midchannel handler because it is closed or closing.", (void *)websocket); + aws_raise_error(AWS_ERROR_HTTP_CONNECTION_CLOSED); + } - /* Channel might already be shut down, but make sure */ - aws_channel_shutdown(websocket->channel_slot->channel, AWS_ERROR_SUCCESS); + if (websocket->thread_data.current_incoming_frame) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET, + "id=%p: Cannot convert to midchannel handler in the middle of an incoming frame.", + (void *)websocket); + aws_raise_error(AWS_ERROR_INVALID_STATE); + return NULL; + } - /* Channel won't destroy its slots/handlers until its refcount reaches 0 */ - aws_channel_release_hold(websocket->channel_slot->channel); + bool was_released = false; + /* BEGIN CRITICAL SECTION */ + s_lock_synced_data(websocket); + if (websocket->synced_data.is_released) { + was_released = true; } else { - AWS_ASSERT(prev_refcount != 0); + websocket->synced_data.is_midchannel_handler = true; + } + s_unlock_synced_data(websocket); + /* END CRITICAL SECTION */ - AWS_LOGF_TRACE( + if (was_released) { + AWS_LOGF_ERROR( AWS_LS_HTTP_WEBSOCKET, - "id=%p: Websocket refcount released, %zu remaining.", - (void *)websocket, - prev_refcount - 1); + "id=%p: Cannot convert websocket to midchannel handler because it was already released.", + (void *)websocket); + aws_raise_error(AWS_ERROR_HTTP_CONNECTION_CLOSED); + return NULL; } + + websocket->thread_data.is_midchannel_handler = true; + + aws_channel_schedule_task_now(websocket->channel_slot->channel, &websocket->finish_midchannel_conversion_task); + + return websocket->channel_slot->channel; +} + +static void s_finish_midchannel_conversion_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { + (void)task; + (void)status; + struct aws_websocket *websocket = arg; + + /* Once websocket is converted into a midchannel handler, it no longer prevents the channel from being destroyed. + * The channel hold is released as a post-conversion task so that whoever initiated the conversion has a chance to + * put their own hold on the channel. */ + aws_channel_release_hold(websocket->channel_slot->channel); } /* Insert frame into list, sorting by priority, then by age (high-priority and older frames towards the front) */ @@ -343,7 +459,10 @@ static void s_enqueue_prioritized_frame(struct aws_linked_list *list, struct out aws_linked_list_insert_after(rev_iter, &to_add->node); } -int aws_websocket_send_frame(struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options) { +static int s_send_frame( + struct aws_websocket *websocket, + const struct aws_websocket_send_frame_options *options, + bool from_public_api) { AWS_ASSERT(websocket); AWS_ASSERT(options); @@ -376,7 +495,9 @@ int aws_websocket_send_frame(struct aws_websocket *websocket, const struct aws_w /* BEGIN CRITICAL SECTION */ s_lock_synced_data(websocket); - if (websocket->synced_data.send_frame_error_code) { + if (websocket->synced_data.is_midchannel_handler && from_public_api) { + send_error = AWS_ERROR_HTTP_WEBSOCKET_IS_MIDCHANNEL_HANDLER; + } else if (websocket->synced_data.send_frame_error_code) { send_error = websocket->synced_data.send_frame_error_code; } else { aws_linked_list_push_back(&websocket->synced_data.outgoing_frame_list, &frame->node); @@ -419,6 +540,10 @@ int aws_websocket_send_frame(struct aws_websocket *websocket, const struct aws_w return AWS_OP_SUCCESS; } +int aws_websocket_send_frame(struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options) { + return s_send_frame(websocket, options, true); +} + static void s_move_synced_data_to_thread_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) { (void)task; if (status != AWS_TASK_STATUS_RUN_READY) { @@ -702,10 +827,59 @@ static int s_handler_process_write_message( struct aws_channel_slot *slot, struct aws_io_message *message) { - (void)handler; (void)slot; - (void)message; - return aws_raise_error(AWS_ERROR_UNIMPLEMENTED); + struct aws_websocket *websocket = handler->impl; + AWS_ASSERT(aws_channel_thread_is_callers_thread(websocket->channel_slot->channel)); + + /* For each aws_io_message headed in the write direction, send a BINARY frame, + * where the frame's payload is the data from this aws_io_message. */ + struct aws_websocket_send_frame_options options = { + .payload_length = message->message_data.len, + .user_data = message, + .stream_outgoing_payload = s_midchannel_send_payload, + .on_complete = s_midchannel_send_complete, + .opcode = AWS_WEBSOCKET_OPCODE_BINARY, + .fin = true, + }; + + /* Use copy_mark to track progress as the data is streamed out */ + message->copy_mark = 0; + + int err = s_send_frame(websocket, &options, false); + if (err) { + /* TODO: mqtt handler needs to clean up messsages that fail to send. */ + return AWS_OP_ERR; + } + + return AWS_OP_SUCCESS; +} + +/* Callback for writing data from downstream aws_io_messages into payload of BINARY frames headed upstream */ +static bool s_midchannel_send_payload(struct aws_websocket *websocket, struct aws_byte_buf *out_buf, void *user_data) { + (void)websocket; + struct aws_io_message *io_msg = user_data; + + /* copy_mark is used to track progress */ + size_t src_available = io_msg->message_data.capacity - io_msg->copy_mark; + size_t dst_available = out_buf->capacity - out_buf->len; + size_t sending = dst_available < src_available ? dst_available : src_available; + + bool success = aws_byte_buf_write(out_buf, io_msg->message_data.buffer + io_msg->copy_mark, sending); + + io_msg->copy_mark += sending; + return success; +} + +/* Callback when data from downstream aws_io_messages, finishes being sent as a BINARY frame upstream. */ +static void s_midchannel_send_complete(struct aws_websocket *websocket, int error_code, void *user_data) { + (void)websocket; + struct aws_io_message *io_msg = user_data; + + if (io_msg->on_completion) { + io_msg->on_completion(io_msg->owning_channel, io_msg, error_code, io_msg->user_data); + } + + aws_mem_release(io_msg->allocator, io_msg); } static void s_destroy_outgoing_frame(struct aws_websocket *websocket, struct outgoing_frame *frame, int error_code) { @@ -847,10 +1021,25 @@ static void s_schedule_channel_shutdown(struct aws_websocket *websocket, int err } void aws_websocket_close(struct aws_websocket *websocket, bool free_scarce_resources_immediately) { - int error_code = AWS_ERROR_SUCCESS; + bool is_midchannel_handler; + + /* BEGIN CRITICAL SECTION */ + s_lock_synced_data(websocket); + is_midchannel_handler = websocket->synced_data.is_midchannel_handler; + s_unlock_synced_data(websocket); + /* END CRITICAL SECTION */ + + if (is_midchannel_handler) { + AWS_LOGF_ERROR( + AWS_LS_HTTP_WEBSOCKET, + "id=%p: Ignoring close call, websocket has converted to midchannel handler.", + (void *)websocket); + return; + } /* TODO: aws_channel_shutdown() should let users specify error_code and "immediate" as separate parameters. * Currently, any non-zero error_code results in "immediate" shutdown */ + int error_code = AWS_ERROR_SUCCESS; if (free_scarce_resources_immediately) { error_code = AWS_ERROR_HTTP_CONNECTION_CLOSED; } @@ -964,7 +1153,7 @@ static void s_finish_shutdown(struct aws_websocket *websocket) { s_destroy_outgoing_frame(websocket, frame, AWS_ERROR_HTTP_CONNECTION_CLOSED); } - if (websocket->on_connection_shutdown) { + if (websocket->on_connection_shutdown && !websocket->thread_data.is_midchannel_handler) { AWS_LOGF_TRACE(AWS_LS_HTTP_WEBSOCKET, "id=%p: Invoking user's shutdown callback.", (void *)websocket); websocket->on_connection_shutdown( websocket, websocket->thread_data.channel_shutdown_error_code, websocket->user_data); @@ -1034,7 +1223,8 @@ static int s_handler_process_read_message( if (err) { AWS_LOGF_ERROR( AWS_LS_HTTP_WEBSOCKET, - "id=%p: Failed to increment read window after message processing, error %d (%s). Closing connection.", + "id=%p: Failed to increment read window after message processing, error %d (%s). Closing " + "connection.", (void *)websocket, aws_last_error(), aws_error_name(aws_last_error())); @@ -1078,7 +1268,7 @@ static int s_decoder_on_frame(const struct aws_websocket_frame *frame, void *use /* Invoke user cb */ bool callback_result = true; - if (websocket->on_incoming_frame_begin) { + if (websocket->on_incoming_frame_begin && !websocket->thread_data.is_midchannel_handler) { callback_result = websocket->on_incoming_frame_begin( websocket, websocket->thread_data.current_incoming_frame, websocket->user_data); } @@ -1100,7 +1290,7 @@ static int s_decoder_on_payload(struct aws_byte_cursor data, void *user_data) { /* Invoke user cb */ bool callback_result = true; - if (websocket->on_incoming_frame_payload) { + if (websocket->on_incoming_frame_payload && !websocket->thread_data.is_midchannel_handler) { size_t window_update_size = data.len; callback_result = websocket->on_incoming_frame_payload( @@ -1150,7 +1340,7 @@ static void s_complete_incoming_frame(struct aws_websocket *websocket, int error /* Invoke user cb */ bool callback_result = true; - if (websocket->on_incoming_frame_complete) { + if (websocket->on_incoming_frame_complete && !websocket->thread_data.is_midchannel_handler) { callback_result = websocket->on_incoming_frame_complete( websocket, websocket->thread_data.current_incoming_frame, error_code, websocket->user_data); } @@ -1180,7 +1370,8 @@ static int s_handler_increment_read_window( (void)handler; (void)slot; (void)size; - return aws_raise_error(AWS_ERROR_UNIMPLEMENTED); + /* TODO: implement */ + return AWS_OP_SUCCESS; } static void s_increment_read_window_action(struct aws_websocket *websocket, size_t size) { @@ -1230,25 +1421,17 @@ void aws_websocket_increment_read_window(struct aws_websocket *websocket, size_t return; } - /* If we're on thread just do it. */ - if (aws_channel_thread_is_callers_thread(websocket->channel_slot->channel)) { - AWS_LOGF_TRACE( - AWS_LS_HTTP_WEBSOCKET, - "id=%p: Incrementing read window immediately with size %zu.", - (void *)websocket, - size); - s_increment_read_window_action(websocket, size); - return; - } - - /* Otherwise schedule a task to do it. + /* Schedule a task to do the increment. * If task is already scheduled, just increase size to be incremented */ + bool is_midchannel_handler = false; bool should_schedule_task = false; /* BEGIN CRITICAL SECTION */ s_lock_synced_data(websocket); - if (websocket->synced_data.window_increment_size == 0) { + if (websocket->synced_data.is_midchannel_handler) { + is_midchannel_handler = true; + } else if (websocket->synced_data.window_increment_size == 0) { should_schedule_task = true; websocket->synced_data.window_increment_size = size; } else { @@ -1259,7 +1442,12 @@ void aws_websocket_increment_read_window(struct aws_websocket *websocket, size_t s_unlock_synced_data(websocket); /* END CRITICAL SECTION */ - if (should_schedule_task) { + if (is_midchannel_handler) { + AWS_LOGF_TRACE( + AWS_LS_HTTP_WEBSOCKET, + "id=%p: Ignoring window increment call, websocket has converted to midchannel handler.", + (void *)websocket); + } else if (should_schedule_task) { AWS_LOGF_TRACE( AWS_LS_HTTP_WEBSOCKET, "id=%p: Scheduling task to increment read window by %zu.", (void *)websocket, size); aws_channel_schedule_task_now(websocket->channel_slot->channel, &websocket->increment_read_window_task); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d82ee6c44..faa7ec712 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -98,7 +98,6 @@ add_test_case(websocket_handler_shutdown_automatically_sends_close_frame) add_test_case(websocket_handler_shutdown_handles_queued_close_frame) add_test_case(websocket_handler_shutdown_immediately_in_emergency) add_test_case(websocket_handler_shutdown_handles_unexpected_write_error) -add_test_case(websocket_handler_shutdown_on_zero_refcount) add_test_case(websocket_handler_close_on_thread) add_test_case(websocket_handler_close_off_thread) add_test_case(websocket_handler_read_frame) @@ -111,6 +110,13 @@ add_test_case(websocket_handler_read_halts_if_complete_fn_returns_false) add_test_case(websocket_handler_window_reopens_by_default) add_test_case(websocket_handler_window_manual_increment) add_test_case(websocket_handler_window_manual_increment_off_thread) +add_test_case(websocket_midchannel_sanity_check) +add_test_case(websocket_midchannel_write_message) +add_test_case(websocket_midchannel_write_multiple_messages) +add_test_case(websocket_midchannel_write_huge_message) +#add_test_case(websocket_midchannel_frames_sent_before_conversion_succeed) +#add_test_case(websocket_midchannel_no_more_incoming_frame_callbacks) +#add_test_case(websocket_midchannel_no_more_shutdown_callback) ??? add_test_case(hpack_encode_integer) add_test_case(hpack_decode_integer) diff --git a/tests/test_websocket_handler.c b/tests/test_websocket_handler.c index 79bcc6f5f..729c92288 100644 --- a/tests/test_websocket_handler.c +++ b/tests/test_websocket_handler.c @@ -50,6 +50,8 @@ struct tester { struct testing_channel testing_channel; struct aws_websocket *websocket; + bool is_midchannel_handler; + size_t on_shutdown_count; int shutdown_error_code; @@ -60,6 +62,7 @@ struct tester { * We're not testing the decoder here, just using it as a tool (decoder tests go in test_websocket_decoder.c). */ struct written_frame written_frames[100]; size_t num_written_frames; + size_t num_written_io_messages; struct aws_websocket_decoder written_frame_decoder; /* Frames reported via the websocket's on_incoming_frame callbacks are recorded here */ @@ -77,6 +80,10 @@ struct tester { size_t num_readpush_frames; size_t readpush_frame_index; struct aws_websocket_encoder readpush_encoder; + + /* For pushing messages upstream, to test a websocket that's been converted to midchannel handler. */ + size_t num_writepush_messages; + struct aws_byte_buf all_writepush_data; /* All data that's been writepushed, concatenated together */ }; /* Helps track the progress of a frame being sent. */ @@ -136,6 +143,8 @@ static int s_drain_written_messages(struct tester *tester) { struct aws_linked_list_node *node = aws_linked_list_pop_front(io_msgs); struct aws_io_message *msg = AWS_CONTAINER_OF(node, struct aws_io_message, queueing_handle); + tester->num_written_io_messages++; + struct aws_byte_cursor msg_cursor = aws_byte_cursor_from_buf(&msg->message_data); while (msg_cursor.len) { /* Make sure our arbitrarily sized buffer hasn't overflowed. */ @@ -399,6 +408,44 @@ static int s_readpush_check(struct tester *tester, size_t frame_i, int expected_ return AWS_OP_SUCCESS; } +static int s_writepush(struct tester *tester, struct aws_byte_cursor data) { + if (!tester->all_writepush_data.allocator) { + ASSERT_SUCCESS(aws_byte_buf_init(&tester->all_writepush_data, tester->alloc, data.len)); + } + + while (data.len) { + struct aws_io_message *msg = aws_channel_acquire_message_from_pool( + tester->testing_channel.channel, AWS_IO_MESSAGE_APPLICATION_DATA, data.len); + ASSERT_NOT_NULL(msg); + size_t chunk_size = msg->message_data.capacity < data.len ? msg->message_data.capacity : data.len; + struct aws_byte_cursor chunk = aws_byte_cursor_advance(&data, chunk_size); + ASSERT_NOT_NULL(chunk.ptr); + ASSERT_TRUE(aws_byte_buf_write_from_whole_cursor(&msg->message_data, chunk)); + ASSERT_SUCCESS(testing_channel_push_write_message(&tester->testing_channel, msg)); + + /* Update tracking data in tester */ + tester->num_writepush_messages++; + ASSERT_SUCCESS(aws_byte_buf_append_dynamic(&tester->all_writepush_data, &chunk)); + } + return AWS_OP_SUCCESS; +} + +/* Scan all written_frames, and ensure that payloads of the binary frames match data */ +static int s_writepush_check(struct tester *tester, size_t ignore_n_written_frames) { + struct aws_byte_cursor expected_cursor = aws_byte_cursor_from_buf(&tester->all_writepush_data); + for (size_t i = ignore_n_written_frames; i < tester->num_written_frames; ++i) { + struct written_frame *frame_i = &tester->written_frames[i]; + if (aws_websocket_is_data_frame(frame_i->def.opcode)) { + ASSERT_UINT_EQUALS(AWS_WEBSOCKET_OPCODE_BINARY, frame_i->def.opcode); + struct aws_byte_cursor expected_i = + aws_byte_cursor_advance(&expected_cursor, (size_t)frame_i->def.payload_length); + ASSERT_TRUE(expected_i.len > 0); + ASSERT_TRUE(aws_byte_cursor_eq_byte_buf(&expected_i, &frame_i->payload)); + } + } + return AWS_OP_SUCCESS; +} + static int s_tester_init(struct tester *tester, struct aws_allocator *alloc) { aws_load_error_strings(); aws_io_load_error_strings(); @@ -417,12 +464,9 @@ static int s_tester_init(struct tester *tester, struct aws_allocator *alloc) { ASSERT_SUCCESS(testing_channel_init(&tester->testing_channel, alloc)); - struct aws_channel_slot *channel_slot = aws_channel_slot_new(tester->testing_channel.channel); - ASSERT_NOT_NULL(channel_slot); - struct aws_websocket_handler_options ws_options = { .allocator = alloc, - .channel_slot = channel_slot, + .channel = tester->testing_channel.channel, .initial_window_size = SIZE_MAX, .user_data = tester, .on_incoming_frame_begin = s_on_incoming_frame_begin, @@ -430,13 +474,8 @@ static int s_tester_init(struct tester *tester, struct aws_allocator *alloc) { .on_incoming_frame_complete = s_on_incoming_frame_complete, .on_connection_shutdown = s_on_connection_shutdown, }; - struct aws_channel_handler *channel_handler = aws_websocket_handler_new(&ws_options); - ASSERT_NOT_NULL(channel_handler); - - tester->websocket = channel_handler->impl; - - ASSERT_SUCCESS(aws_channel_slot_insert_end(tester->testing_channel.channel, channel_slot)); - ASSERT_SUCCESS(aws_channel_slot_set_handler(channel_slot, channel_handler)); + tester->websocket = aws_websocket_handler_new(&ws_options); + ASSERT_NOT_NULL(tester->websocket); aws_websocket_decoder_init(&tester->written_frame_decoder, s_on_written_frame, s_on_written_frame_payload, tester); aws_websocket_encoder_init(&tester->readpush_encoder, s_stream_readpush_payload, tester); @@ -445,7 +484,11 @@ static int s_tester_init(struct tester *tester, struct aws_allocator *alloc) { } static int s_tester_clean_up(struct tester *tester) { - aws_channel_shutdown(tester->testing_channel.channel, AWS_ERROR_SUCCESS); + if (tester->is_midchannel_handler) { + aws_channel_shutdown(tester->testing_channel.channel, AWS_ERROR_SUCCESS); + } else { + aws_websocket_release(tester->websocket); + } ASSERT_SUCCESS(s_drain_written_messages(tester)); ASSERT_SUCCESS(testing_channel_clean_up(&tester->testing_channel)); @@ -458,11 +501,21 @@ static int s_tester_clean_up(struct tester *tester) { aws_byte_buf_clean_up(&tester->incoming_frames[i].payload); } + aws_byte_buf_clean_up(&tester->all_writepush_data); + aws_http_library_clean_up(); aws_logger_clean_up(&tester->logger); return AWS_OP_SUCCESS; } +static int s_install_downstream_handler(struct tester *tester, size_t initial_window) { + ASSERT_NOT_NULL(aws_websocket_convert_to_midchannel_handler(tester->websocket)); + tester->is_midchannel_handler = true; + + ASSERT_SUCCESS(testing_channel_install_downstream_handler(&tester->testing_channel, initial_window)); + return AWS_OP_SUCCESS; +} + static bool s_on_stream_outgoing_payload( struct aws_websocket *websocket, struct aws_byte_buf *out_buf, @@ -689,6 +742,8 @@ TEST_CASE(websocket_handler_send_huge_frame) { /* transmit giant buffer with random contents */ struct aws_byte_buf giant_buf; ASSERT_SUCCESS(aws_byte_buf_init(&giant_buf, allocator, 100000)); + while (aws_byte_buf_write_be32(&giant_buf, (uint32_t)rand())) { + } while (aws_byte_buf_write_u8(&giant_buf, (uint8_t)rand())) { } @@ -1249,21 +1304,6 @@ TEST_CASE(websocket_handler_shutdown_handles_unexpected_write_error) { return AWS_OP_SUCCESS; } -TEST_CASE(websocket_handler_shutdown_on_zero_refcount) { - (void)ctx; - struct tester tester; - ASSERT_SUCCESS(s_tester_init(&tester, allocator)); - - aws_websocket_acquire_hold(tester.websocket); - aws_websocket_release_hold(tester.websocket); - - ASSERT_SUCCESS(s_drain_written_messages(&tester)); - ASSERT_UINT_EQUALS(1, tester.on_shutdown_count); - - ASSERT_SUCCESS(s_tester_clean_up(&tester)); - return AWS_OP_SUCCESS; -} - TEST_CASE(websocket_handler_close_on_thread) { (void)ctx; struct tester tester; @@ -1658,3 +1698,82 @@ TEST_CASE(websocket_handler_window_manual_increment_off_thread) { (void)ctx; return s_window_manual_increment_common(allocator, false); } + +TEST_CASE(websocket_midchannel_sanity_check) { + (void)ctx; + struct tester tester; + ASSERT_SUCCESS(s_tester_init(&tester, allocator)); + ASSERT_SUCCESS(s_install_downstream_handler(&tester, SIZE_MAX)); + ASSERT_SUCCESS(s_tester_clean_up(&tester)); + return AWS_OP_SUCCESS; +} + +TEST_CASE(websocket_midchannel_write_message) { + (void)ctx; + struct tester tester; + ASSERT_SUCCESS(s_tester_init(&tester, allocator)); + ASSERT_SUCCESS(s_install_downstream_handler(&tester, SIZE_MAX)); + + /* Write data */ + struct aws_byte_cursor writing = aws_byte_cursor_from_c_str("My hat it has three corners"); + ASSERT_SUCCESS(s_writepush(&tester, writing)); + + /* Compare results */ + ASSERT_SUCCESS(s_drain_written_messages(&tester)); + ASSERT_SUCCESS(s_writepush_check(&tester, 0)); + + ASSERT_SUCCESS(s_tester_clean_up(&tester)); + return AWS_OP_SUCCESS; +} + +TEST_CASE(websocket_midchannel_write_multiple_messages) { + (void)ctx; + struct tester tester; + ASSERT_SUCCESS(s_tester_init(&tester, allocator)); + ASSERT_SUCCESS(s_install_downstream_handler(&tester, SIZE_MAX)); + + struct aws_byte_cursor writing[] = { + aws_byte_cursor_from_c_str("My hat it has three corners."), + aws_byte_cursor_from_c_str("Three corners has my hat."), + aws_byte_cursor_from_c_str("And had it not three corners, it would not be my hat."), + }; + + /* Write data */ + for (size_t i = 0; i < AWS_ARRAY_SIZE(writing); ++i) { + ASSERT_SUCCESS(s_writepush(&tester, writing[i])); + } + + /* Compare results */ + ASSERT_SUCCESS(s_drain_written_messages(&tester)); + ASSERT_SUCCESS(s_writepush_check(&tester, 0)); + + ASSERT_SUCCESS(s_tester_clean_up(&tester)); + return AWS_OP_SUCCESS; +} + +TEST_CASE(websocket_midchannel_write_huge_message) { + (void)ctx; + struct tester tester; + ASSERT_SUCCESS(s_tester_init(&tester, allocator)); + ASSERT_SUCCESS(s_install_downstream_handler(&tester, SIZE_MAX)); + + /* Fill big buffer with random data */ + struct aws_byte_buf writing; + ASSERT_SUCCESS(aws_byte_buf_init(&writing, allocator, 1000000)); + while (aws_byte_buf_write_be32(&writing, (uint32_t)rand())) { + } + while (aws_byte_buf_write_u8(&writing, (uint8_t)rand())) { + } + + /* Send as multiple aws_io_messages that are as full as they can be */ + ASSERT_SUCCESS(s_writepush(&tester, aws_byte_cursor_from_buf(&writing))); + + /* Compare results */ + ASSERT_SUCCESS(s_drain_written_messages(&tester)); + ASSERT_TRUE(tester.num_written_io_messages > 1); /* Assert that message was huge enough to stress limits */ + ASSERT_SUCCESS(s_writepush_check(&tester, 0)); + + aws_byte_buf_clean_up(&writing); + ASSERT_SUCCESS(s_tester_clean_up(&tester)); + return AWS_OP_SUCCESS; +}