diff --git a/awscrt/mqtt.py b/awscrt/mqtt.py index 0c24ae79d..6990aa34f 100644 --- a/awscrt/mqtt.py +++ b/awscrt/mqtt.py @@ -127,6 +127,39 @@ def __init__(self, topic, qos, payload, retain): self.retain = retain +@dataclass +class OnConnectionSuccessData: + """Dataclass containing data related to a on_connection_success Callback + + Args: + return_code (ConnectReturnCode): Connect return. code received from the server. + session_present (bool): True if the connection resumes an existing session. + False if new session. Note that the server has forgotten all previous subscriptions + if this is False. + Subscriptions can be re-established via resubscribe_existing_topics() if the connection was a reconnection. + """ + return_code: ConnectReturnCode = None + session_present: bool = False + + +@dataclass +class OnConnectionFailureData: + """Dataclass containing data related to a on_connection_failure Callback + + Args: + error (ConnectReturnCode): Error code with reason for connection failure + """ + error: awscrt.exceptions.AwsCrtError = None + + +@dataclass +class OnConnectionClosedData: + """Dataclass containing data related to a on_connection_closed Callback. + Currently unused. + """ + pass + + class Client(NativeResource): """MQTT client. @@ -213,6 +246,31 @@ class Connection(NativeResource): * `**kwargs` (dict): Forward-compatibility kwargs. + on_connection_success: Optional callback invoked whenever the connection successfully connects. + This callback is invoked for every successful connect and every successful reconnect. + + Function should take the following arguments and return nothing: + + * `connection` (:class:`Connection`): This MQTT Connection + + * `callback_data` (:class:`OnConnectionSuccessData`): The data returned from the connection success. + + on_connection_failure: Optional callback invoked whenever the connection fails to connect. + This callback is invoked for every failed connect and every failed reconnect. + + Function should take the following arguments and return nothing: + + * `connection` (:class:`Connection`): This MQTT Connection + + * `callback_data` (:class:`OnConnectionFailureData`): The data returned from the connection failure. + + on_connection_closed: Optional callback invoked whenever the connection has been disconnected and shutdown successfully. + Function should take the following arguments and return nothing: + + * `connection` (:class:`Connection`): This MQTT Connection + + * `callback_data` (:class:`OnConnectionClosedData`): The data returned from the connection close. + reconnect_min_timeout_secs (int): Minimum time to wait between reconnect attempts. Must be <= `reconnect_max_timeout_secs`. Wait starts at min and doubles with each attempt until max is reached. @@ -286,7 +344,10 @@ def __init__(self, use_websockets=False, websocket_proxy_options=None, websocket_handshake_transform=None, - proxy_options=None + proxy_options=None, + on_connection_success=None, + on_connection_failure=None, + on_connection_closed=None ): assert isinstance(client, Client) @@ -297,6 +358,9 @@ def __init__(self, assert isinstance(websocket_proxy_options, HttpProxyOptions) or websocket_proxy_options is None assert isinstance(proxy_options, HttpProxyOptions) or proxy_options is None assert callable(websocket_handshake_transform) or websocket_handshake_transform is None + assert callable(on_connection_success) or on_connection_success is None + assert callable(on_connection_failure) or on_connection_failure is None + assert callable(on_connection_closed) or on_connection_closed is None if reconnect_min_timeout_secs > reconnect_max_timeout_secs: raise ValueError("'reconnect_min_timeout_secs' cannot exceed 'reconnect_max_timeout_secs'") @@ -316,6 +380,9 @@ def __init__(self, self._on_connection_resumed_cb = on_connection_resumed self._use_websockets = use_websockets self._ws_handshake_transform_cb = websocket_handshake_transform + self._on_connection_success_cb = on_connection_success + self._on_connection_failure_cb = on_connection_failure + self._on_connection_closed_cb = on_connection_closed # may be changed at runtime, take effect the the next time connect/reconnect occurs self.client_id = client_id @@ -385,6 +452,26 @@ def _on_complete(f): if not future.done(): transform_args.set_done(e) + def _on_connection_closed(self): + if self: + if self._on_connection_closed_cb: + data = OnConnectionClosedData() + self._on_connection_closed_cb(connection=self, callback_data=data) + + def _on_connection_success(self, return_code, session_present): + if self: + if self._on_connection_success_cb: + data = OnConnectionSuccessData( + return_code=ConnectReturnCode(return_code), + session_present=session_present) + self._on_connection_success_cb(connection=self, callback_data=data) + + def _on_connection_failure(self, error_code): + if self: + if self._on_connection_failure_cb: + data = OnConnectionFailureData(error=awscrt.exceptions.from_code(error_code)) + self._on_connection_failure_cb(connection=self, callback_data=data) + def connect(self): """Open the actual connection to the server (async). diff --git a/source/mqtt_client_connection.c b/source/mqtt_client_connection.c index bd06d290e..8a78615e5 100644 --- a/source/mqtt_client_connection.c +++ b/source/mqtt_client_connection.c @@ -48,6 +48,10 @@ struct mqtt_connection_binding { * Lets us invoke callbacks on the python object without preventing the GC from cleaning it up. */ PyObject *self_proxy; + /* To not run into a segfault calling on_close with the connection being freed before the callback + * can be invoked, we need to keep the PyCapsule alive. */ + PyObject *self_capsule; + PyObject *on_connect; PyObject *on_any_publish; @@ -56,6 +60,10 @@ struct mqtt_connection_binding { }; static void s_mqtt_python_connection_finish_destruction(struct mqtt_connection_binding *py_connection) { + + /* Do not call the on_stopped callback if the python object is finished/destroyed */ + aws_mqtt_client_connection_set_connection_closed_handler(py_connection->native, NULL, NULL); + aws_mqtt_client_connection_release(py_connection->native); Py_DECREF(py_connection->self_proxy); @@ -69,7 +77,10 @@ static void s_mqtt_python_connection_destructor_on_disconnect( struct aws_mqtt_client_connection *connection, void *userdata) { - (void)connection; + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } + struct mqtt_connection_binding *py_connection = userdata; PyGILState_STATE state; @@ -87,6 +98,9 @@ static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) { PyCapsule_GetPointer(connection_capsule, s_capsule_name_mqtt_client_connection); assert(py_connection); + /* This is the destructor from Python - so we can ignore the closed callback here */ + aws_mqtt_client_connection_set_connection_closed_handler(py_connection->native, NULL, NULL); + if (aws_mqtt_client_connection_disconnect( py_connection->native, s_mqtt_python_connection_destructor_on_disconnect, py_connection)) { @@ -97,7 +111,9 @@ static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) { static void s_on_connection_interrupted(struct aws_mqtt_client_connection *connection, int error_code, void *userdata) { - (void)connection; + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } struct mqtt_connection_binding *py_connection = userdata; @@ -126,6 +142,10 @@ static void s_on_connection_resumed( bool session_present, void *userdata) { + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } + (void)connection; struct mqtt_connection_binding *py_connection = userdata; @@ -147,6 +167,54 @@ static void s_on_connection_resumed( } } + /* call _on_connection_success */ + PyObject *success_result = + PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present)); + if (success_result) { + Py_DECREF(success_result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + + PyGILState_Release(state); +} + +static void s_on_connection_closed( + struct aws_mqtt_client_connection *connection, + struct on_connection_closed_data *data, + void *userdata) { + + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } + (void)data; // Not used for anything currently, but in the future it could be. + + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + + struct mqtt_connection_binding *py_connection = userdata; + /* Ensure that python class is still alive */ + PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */ + if (self != Py_None) { + PyObject *result = PyObject_CallMethod(self, "_on_connection_closed", "()"); + if (result) { + Py_DECREF(result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + } + Py_DECREF(py_connection->self_proxy); + + /** Allow the PyCapsule to be freed like normal again. + * If this is the last reference (I.E customer code called disconnect and threw the Python object away) + * Then this will allow the MQTT311 class to be fully cleaned. + * If it is not the last reference (customer still has reference) then when the customer is done + * it will be freed like normal. + **/ + Py_DECREF(py_connection->self_capsule); + PyGILState_Release(state); } @@ -193,6 +261,12 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { goto set_interruption_failed; } + if (aws_mqtt_client_connection_set_connection_closed_handler( + py_connection->native, s_on_connection_closed, py_connection)) { + PyErr_SetAwsLastError(); + goto set_interruption_failed; + } + if (PyObject_IsTrue(use_websocket_py)) { if (aws_mqtt_client_connection_use_websockets( py_connection->native, @@ -219,6 +293,7 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) { /* From hereon, nothing will fail */ + py_connection->self_capsule = capsule; py_connection->self_proxy = self_proxy; py_connection->client = client_py; @@ -253,15 +328,18 @@ static void s_on_connect( bool session_present, void *user_data) { - (void)connection; + if (connection == NULL || user_data == NULL) { + return; // The connection is dead - skip! + } struct mqtt_connection_binding *py_connection = user_data; + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + if (py_connection->on_connect) { - PyGILState_STATE state; - if (aws_py_gilstate_ensure(&state)) { - return; /* Python has shut down. Nothing matters anymore, but don't crash */ - } PyObject *callback = py_connection->on_connect; py_connection->on_connect = NULL; @@ -275,15 +353,42 @@ static void s_on_connect( } Py_XDECREF(callback); + } - PyGILState_Release(state); + /* Call on_connection_success or failure based on the result */ + PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */ + if (self != Py_None) { + /* Successful connection - call _on_connection_success */ + if (error_code == AWS_ERROR_SUCCESS) { + PyObject *success_result = PyObject_CallMethod( + self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present)); + if (success_result) { + Py_DECREF(success_result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + /* Unsuccessful connection - call _on_connection_failure */ + } else { + PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code); + if (success_result) { + Py_DECREF(success_result); + } else { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + } } + + PyGILState_Release(state); } /* If unsuccessful, false is returned and a Python error has been set */ bool s_set_will(struct aws_mqtt_client_connection *connection, PyObject *will) { assert(will && (will != Py_None)); + if (connection == NULL) { + return false; // The connection is dead - skip! + } + bool success = false; /* These references all need to be cleaned up before function returns */ @@ -697,7 +802,10 @@ static void s_publish_complete( uint16_t packet_id, int error_code, void *userdata) { - (void)connection; + + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } struct publish_complete_userdata *metadata = userdata; assert(metadata); @@ -805,7 +913,9 @@ static void s_subscribe_callback( bool retain, void *user_data) { - (void)connection; + if (connection == NULL || user_data == NULL) { + return; // The connection is dead - skip! + } PyObject *callback = user_data; if (callback == Py_None) { @@ -858,7 +968,9 @@ static void s_suback_callback( int error_code, void *userdata) { - (void)connection; + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } PyObject *callback = userdata; AWS_FATAL_ASSERT(callback && callback != Py_None); @@ -967,7 +1079,10 @@ static void s_unsuback_callback( uint16_t packet_id, int error_code, void *userdata) { - (void)connection; + + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } PyObject *callback = userdata; @@ -1029,7 +1144,9 @@ static void s_suback_multi_callback( int error_code, void *userdata) { - (void)connection; + if (connection == NULL || userdata == NULL) { + return; // The connection is dead - skip! + } /* These must be DECREF'd when function ends */ PyObject *callback = userdata; @@ -1124,7 +1241,9 @@ PyObject *aws_py_mqtt_client_connection_resubscribe_existing_topics(PyObject *se static void s_on_disconnect(struct aws_mqtt_client_connection *connection, void *user_data) { - (void)connection; + if (connection == NULL || user_data == NULL) { + return; // The connection is dead - skip! + } PyObject *on_disconnect = user_data; @@ -1163,16 +1282,24 @@ PyObject *aws_py_mqtt_client_connection_disconnect(PyObject *self, PyObject *arg } Py_INCREF(on_disconnect); + Py_INCREF(connection->self_proxy); /* We need to keep self_proxy alive for on_closed, which will dec-ref this */ + Py_INCREF(connection->self_capsule); /* Do not allow the PyCapsule to be freed, we need it alive for on_closed */ int err = aws_mqtt_client_connection_disconnect(connection->native, s_on_disconnect, on_disconnect); if (err) { Py_DECREF(on_disconnect); + Py_DECREF(connection->self_proxy); + Py_DECREF(connection->self_capsule); return PyErr_AwsLastError(); } Py_RETURN_NONE; } +/******************************************************************************* + * Client Statistics + ******************************************************************************/ + PyObject *aws_py_mqtt_client_connection_get_stats(PyObject *self, PyObject *args) { (void)self; bool success = false; diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 67a6499ae..8397c6807 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -3,7 +3,7 @@ from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, Pkcs11Lib, TlsContextOptions from awscrt import http -from awscrt.mqtt import Client, Connection, QoS, Will +from awscrt.mqtt import Client, Connection, QoS, Will, OnConnectionClosedData, OnConnectionFailureData, OnConnectionSuccessData, ConnectReturnCode from test import NativeResourceTest from concurrent.futures import Future import os @@ -29,7 +29,15 @@ class MqttConnectionTest(NativeResourceTest): TEST_TOPIC = '/test/me/senpai/' + str(uuid.uuid4()) TEST_MSG = 'NOTICE ME!'.encode('utf8') - def _create_connection(self, endpoint, tls_context, use_static_singletons=False): + def _create_connection( + self, + endpoint, + tls_context, + port=8883, + use_static_singletons=False, + on_connection_success_callback=None, + on_connection_failure_callback=None, + on_connection_closed_callback=None): if use_static_singletons: client = Client(tls_ctx=tls_context) else: @@ -42,7 +50,10 @@ def _create_connection(self, endpoint, tls_context, use_static_singletons=False) client=client, client_id=create_client_id(), host_name=endpoint, - port=8883) + port=port, + on_connection_closed=on_connection_closed_callback, + on_connection_failure=on_connection_failure_callback, + on_connection_success=on_connection_success_callback) return connection def test_connect_disconnect(self): @@ -382,6 +393,75 @@ def test_connect_publish_statistics_wait_disconnect(self): # disconnect connection.disconnect().result(TIMEOUT) + def test_connect_disconnect_with_callbacks_happy(self): + test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST") + test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT") + test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY") + test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key) + test_tls = ClientTlsContext(test_tls_opts) + + onConnectionSuccessFuture = Future() + onConnectionClosedFuture = Future() + + def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData): + onConnectionSuccessFuture.set_result( + {'return_code': callback_data.return_code, "session_present": callback_data.session_present}) + + def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData): + pass + + def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData): + onConnectionClosedFuture.set_result({}) + + connection = self._create_connection( + endpoint=test_input_endpoint, + tls_context=test_tls, + on_connection_success_callback=on_connection_success_callback, + on_connection_failure_callback=on_connection_failure_callback, + on_connection_closed_callback=on_connection_closed_callback) + connection.connect().result(TIMEOUT) + successData = onConnectionSuccessFuture.result(TIMEOUT) + self.assertEqual(successData['return_code'], ConnectReturnCode.ACCEPTED) + self.assertEqual(successData['session_present'], False) + connection.disconnect().result(TIMEOUT) + onConnectionClosedFuture.result(TIMEOUT) + + def test_connect_disconnect_with_callbacks_unhappy(self): + test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST") + test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT") + test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY") + test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key) + test_tls = ClientTlsContext(test_tls_opts) + + onConnectionFailureFuture = Future() + + def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData): + pass + + def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData): + onConnectionFailureFuture.set_result({'error': callback_data.error}) + + def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData): + pass + + connection = self._create_connection( + endpoint=test_input_endpoint, + tls_context=test_tls, + port=1234, + on_connection_success_callback=on_connection_success_callback, + on_connection_failure_callback=on_connection_failure_callback, + on_connection_closed_callback=on_connection_closed_callback) + + exception_occurred = False + try: + connection.connect().result(TIMEOUT) + except Exception: + exception_occurred = True + self.assertTrue(exception_occurred, "Exception did not occur when connecting with invalid arguments!") + + failureData = onConnectionFailureFuture.result(TIMEOUT) + self.assertTrue(failureData['error'] is not None) + # ============================================================== # MOSQUITTO CONNECTION TESTS # ==============================================================