diff --git a/source/mqtt_client_connection.c b/source/mqtt_client_connection.c index 526157f29..5cf5e6f96 100644 --- a/source/mqtt_client_connection.c +++ b/source/mqtt_client_connection.c @@ -109,6 +109,63 @@ static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) { } } +static void s_on_connection_success( + struct aws_mqtt_client_connection *connection, + enum aws_mqtt_connect_return_code return_code, + bool session_present, + void *user_data) { + + if (connection == NULL || user_data == NULL) { + return; // The connection is dead - skip! + } + + struct mqtt_connection_binding *py_connection = user_data; + + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + + PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */ + if (self != Py_None) { + PyObject *success_result = + PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present)); + if (success_result) { + Py_DECREF(success_result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + } + + PyGILState_Release(state); +} + +static void s_on_connection_failure(struct aws_mqtt_client_connection *connection, int error_code, void *user_data) { + + if (connection == NULL || user_data == NULL) { + return; // The connection is dead - skip! + } + + struct mqtt_connection_binding *py_connection = user_data; + + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + + PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */ + if (self != Py_None) { + PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code); + if (success_result) { + Py_DECREF(success_result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + } + + PyGILState_Release(state); +} + static void s_on_connection_interrupted(struct aws_mqtt_client_connection *connection, int error_code, void *userdata) { if (connection == NULL || userdata == NULL) { @@ -167,15 +224,6 @@ static void s_on_connection_resumed( } } - /* call _on_connection_success */ - PyObject *success_result = - PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present)); - if (success_result) { - Py_DECREF(success_result); - } else { - PyErr_WriteUnraisable(PyErr_Occurred()); - } - PyGILState_Release(state); } @@ -250,6 +298,12 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { goto connection_new_failed; } + if (aws_mqtt_client_connection_set_connection_result_handlers( + py_connection->native, s_on_connection_success, py_connection, s_on_connection_failure, py_connection)) { + PyErr_SetAwsLastError(); + goto set_connection_handlers_failed; + } + if (aws_mqtt_client_connection_set_connection_interruption_handlers( py_connection->native, s_on_connection_interrupted, @@ -306,6 +360,7 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { proxy_new_failed: use_websockets_failed: set_interruption_failed: +set_connection_handlers_failed: aws_mqtt_client_connection_release(py_connection->native); connection_new_failed: aws_mem_release(allocator, py_connection); @@ -355,29 +410,6 @@ static void s_on_connect( Py_XDECREF(callback); } - /* Call on_connection_success or failure based on the result */ - PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */ - if (self != Py_None) { - /* Successful connection - call _on_connection_success */ - if (error_code == AWS_ERROR_SUCCESS) { - PyObject *success_result = PyObject_CallMethod( - self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present)); - if (success_result) { - Py_DECREF(success_result); - } else { - PyErr_WriteUnraisable(PyErr_Occurred()); - } - /* Unsuccessful connection - call _on_connection_failure */ - } else { - PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code); - if (success_result) { - Py_DECREF(success_result); - } else { - PyErr_WriteUnraisable(PyErr_Occurred()); - } - } - } - PyGILState_Release(state); } diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 8397c6807..41d722e48 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -35,9 +35,11 @@ def _create_connection( tls_context, port=8883, use_static_singletons=False, + client_id=None, on_connection_success_callback=None, on_connection_failure_callback=None, - on_connection_closed_callback=None): + on_connection_closed_callback=None, + on_connection_resumed_callback=None): if use_static_singletons: client = Client(tls_ctx=tls_context) else: @@ -48,12 +50,13 @@ def _create_connection( connection = Connection( client=client, - client_id=create_client_id(), + client_id=client_id if client_id else create_client_id(), host_name=endpoint, port=port, on_connection_closed=on_connection_closed_callback, on_connection_failure=on_connection_failure_callback, - on_connection_success=on_connection_success_callback) + on_connection_success=on_connection_success_callback, + on_connection_resumed=on_connection_resumed_callback) return connection def test_connect_disconnect(self): @@ -400,18 +403,18 @@ def test_connect_disconnect_with_callbacks_happy(self): test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key) test_tls = ClientTlsContext(test_tls_opts) - onConnectionSuccessFuture = Future() - onConnectionClosedFuture = Future() + on_connection_success_future = Future() + on_connection_closed_future = Future() def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData): - onConnectionSuccessFuture.set_result( + on_connection_success_future.set_result( {'return_code': callback_data.return_code, "session_present": callback_data.session_present}) def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData): pass def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData): - onConnectionClosedFuture.set_result({}) + on_connection_closed_future.set_result({}) connection = self._create_connection( endpoint=test_input_endpoint, @@ -420,11 +423,11 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD on_connection_failure_callback=on_connection_failure_callback, on_connection_closed_callback=on_connection_closed_callback) connection.connect().result(TIMEOUT) - successData = onConnectionSuccessFuture.result(TIMEOUT) - self.assertEqual(successData['return_code'], ConnectReturnCode.ACCEPTED) - self.assertEqual(successData['session_present'], False) + success_data = on_connection_success_future.result(TIMEOUT) + self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED) + self.assertEqual(success_data['session_present'], False) connection.disconnect().result(TIMEOUT) - onConnectionClosedFuture.result(TIMEOUT) + on_connection_closed_future.result(TIMEOUT) def test_connect_disconnect_with_callbacks_unhappy(self): test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST") @@ -433,13 +436,13 @@ def test_connect_disconnect_with_callbacks_unhappy(self): test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key) test_tls = ClientTlsContext(test_tls_opts) - onConnectionFailureFuture = Future() + on_onnection_failure_future = Future() def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData): pass def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData): - onConnectionFailureFuture.set_result({'error': callback_data.error}) + on_onnection_failure_future.set_result({'error': callback_data.error}) def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData): pass @@ -459,8 +462,77 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD exception_occurred = True self.assertTrue(exception_occurred, "Exception did not occur when connecting with invalid arguments!") - failureData = onConnectionFailureFuture.result(TIMEOUT) - self.assertTrue(failureData['error'] is not None) + failure_data = on_onnection_failure_future.result(TIMEOUT) + self.assertTrue(failure_data['error'] is not None) + + def test_connect_disconnect_with_callbacks_happy_on_resume(self): + # Check that an on_connection_success callback fires on a resumed connection. + + # NOTE Since there is no mocked server available on this abstraction level, the only sensible approach + # is to interrupt a connection, and wait for it to be resumed automatically. For that, another client + # with the same client_id connects to the server and then immediately disconnects. + + test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST") + test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT") + test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY") + test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key) + test_tls = ClientTlsContext(test_tls_opts) + + on_connection_success_future = Future() + on_connection_closed_future = Future() + on_connection_resumed_future = Future() + + def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData): + on_connection_success_future.set_result( + {'return_code': callback_data.return_code, "session_present": callback_data.session_present}) + + def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData): + on_connection_closed_future.set_result({}) + + def on_connection_resumed_callback(connection, return_code: ConnectReturnCode, session_present): + on_connection_resumed_future.set_result( + {'return_code': return_code, "session_present": session_present}) + + connection = self._create_connection( + endpoint=test_input_endpoint, + tls_context=test_tls, + on_connection_success_callback=on_connection_success_callback, + on_connection_closed_callback=on_connection_closed_callback, + on_connection_resumed_callback=on_connection_resumed_callback) + connection.connect().result(TIMEOUT) + success_data = on_connection_success_future.result(TIMEOUT) + self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED) + self.assertEqual(success_data['session_present'], False) + + # Reset the future for the reconnect attempt. + on_connection_success_future = Future() + + on_connection_success_future_dup = Future() + + def on_connection_success_callback_dup(connection, callback_data: OnConnectionSuccessData): + on_connection_success_future_dup.set_result({}) + + # Reuse the same client_id to displace the first client. + connection_dup = self._create_connection( + endpoint=test_input_endpoint, + tls_context=test_tls, + client_id=connection.client_id, + on_connection_success_callback=on_connection_success_callback_dup) + + connection_dup.connect().result(TIMEOUT) + on_connection_success_future_dup.result(TIMEOUT) + connection_dup.disconnect().result(TIMEOUT) + + # After the second client disconnects, the first one should reconnect, + # and on_connection_success callback should be fired once again. + on_connection_resumed_future.result(TIMEOUT) + success_data = on_connection_success_future.result(TIMEOUT) + + self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED) + self.assertEqual(success_data['session_present'], False) + + connection.disconnect().result(TIMEOUT) + on_connection_closed_future.result(TIMEOUT) # ============================================================== # MOSQUITTO CONNECTION TESTS