Skip to content

Commit

Permalink
Use on_connection_* callbacks from aws-c-mqtt (#490)
Browse files Browse the repository at this point in the history
* Use on_connection_success and on_connection_failure handlers
* Add test for on_connection_success fired on resume
  • Loading branch information
sfod committed Jul 18, 2023
1 parent b17685b commit 18e9169
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 47 deletions.
96 changes: 64 additions & 32 deletions source/mqtt_client_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,63 @@ static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) {
}
}

static void s_on_connection_success(
struct aws_mqtt_client_connection *connection,
enum aws_mqtt_connect_return_code return_code,
bool session_present,
void *user_data) {

if (connection == NULL || user_data == NULL) {
return; // The connection is dead - skip!
}

struct mqtt_connection_binding *py_connection = user_data;

PyGILState_STATE state;
if (aws_py_gilstate_ensure(&state)) {
return; /* Python has shut down. Nothing matters anymore, but don't crash */
}

PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */
if (self != Py_None) {
PyObject *success_result =
PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present));
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
}

PyGILState_Release(state);
}

static void s_on_connection_failure(struct aws_mqtt_client_connection *connection, int error_code, void *user_data) {

if (connection == NULL || user_data == NULL) {
return; // The connection is dead - skip!
}

struct mqtt_connection_binding *py_connection = user_data;

PyGILState_STATE state;
if (aws_py_gilstate_ensure(&state)) {
return; /* Python has shut down. Nothing matters anymore, but don't crash */
}

PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */
if (self != Py_None) {
PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code);
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
}

PyGILState_Release(state);
}

static void s_on_connection_interrupted(struct aws_mqtt_client_connection *connection, int error_code, void *userdata) {

if (connection == NULL || userdata == NULL) {
Expand Down Expand Up @@ -167,15 +224,6 @@ static void s_on_connection_resumed(
}
}

/* call _on_connection_success */
PyObject *success_result =
PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present));
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}

PyGILState_Release(state);
}

Expand Down Expand Up @@ -250,6 +298,12 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
goto connection_new_failed;
}

if (aws_mqtt_client_connection_set_connection_result_handlers(
py_connection->native, s_on_connection_success, py_connection, s_on_connection_failure, py_connection)) {
PyErr_SetAwsLastError();
goto set_connection_handlers_failed;
}

if (aws_mqtt_client_connection_set_connection_interruption_handlers(
py_connection->native,
s_on_connection_interrupted,
Expand Down Expand Up @@ -306,6 +360,7 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
proxy_new_failed:
use_websockets_failed:
set_interruption_failed:
set_connection_handlers_failed:
aws_mqtt_client_connection_release(py_connection->native);
connection_new_failed:
aws_mem_release(allocator, py_connection);
Expand Down Expand Up @@ -355,29 +410,6 @@ static void s_on_connect(
Py_XDECREF(callback);
}

/* Call on_connection_success or failure based on the result */
PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */
if (self != Py_None) {
/* Successful connection - call _on_connection_success */
if (error_code == AWS_ERROR_SUCCESS) {
PyObject *success_result = PyObject_CallMethod(
self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present));
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
/* Unsuccessful connection - call _on_connection_failure */
} else {
PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code);
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
}
}

PyGILState_Release(state);
}

Expand Down
102 changes: 87 additions & 15 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def _create_connection(
tls_context,
port=8883,
use_static_singletons=False,
client_id=None,
on_connection_success_callback=None,
on_connection_failure_callback=None,
on_connection_closed_callback=None):
on_connection_closed_callback=None,
on_connection_resumed_callback=None):
if use_static_singletons:
client = Client(tls_ctx=tls_context)
else:
Expand All @@ -48,12 +50,13 @@ def _create_connection(

connection = Connection(
client=client,
client_id=create_client_id(),
client_id=client_id if client_id else create_client_id(),
host_name=endpoint,
port=port,
on_connection_closed=on_connection_closed_callback,
on_connection_failure=on_connection_failure_callback,
on_connection_success=on_connection_success_callback)
on_connection_success=on_connection_success_callback,
on_connection_resumed=on_connection_resumed_callback)
return connection

def test_connect_disconnect(self):
Expand Down Expand Up @@ -400,18 +403,18 @@ def test_connect_disconnect_with_callbacks_happy(self):
test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key)
test_tls = ClientTlsContext(test_tls_opts)

onConnectionSuccessFuture = Future()
onConnectionClosedFuture = Future()
on_connection_success_future = Future()
on_connection_closed_future = Future()

def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData):
onConnectionSuccessFuture.set_result(
on_connection_success_future.set_result(
{'return_code': callback_data.return_code, "session_present": callback_data.session_present})

def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData):
pass

def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData):
onConnectionClosedFuture.set_result({})
on_connection_closed_future.set_result({})

connection = self._create_connection(
endpoint=test_input_endpoint,
Expand All @@ -420,11 +423,11 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD
on_connection_failure_callback=on_connection_failure_callback,
on_connection_closed_callback=on_connection_closed_callback)
connection.connect().result(TIMEOUT)
successData = onConnectionSuccessFuture.result(TIMEOUT)
self.assertEqual(successData['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(successData['session_present'], False)
success_data = on_connection_success_future.result(TIMEOUT)
self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)
connection.disconnect().result(TIMEOUT)
onConnectionClosedFuture.result(TIMEOUT)
on_connection_closed_future.result(TIMEOUT)

def test_connect_disconnect_with_callbacks_unhappy(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
Expand All @@ -433,13 +436,13 @@ def test_connect_disconnect_with_callbacks_unhappy(self):
test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key)
test_tls = ClientTlsContext(test_tls_opts)

onConnectionFailureFuture = Future()
on_onnection_failure_future = Future()

def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData):
pass

def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData):
onConnectionFailureFuture.set_result({'error': callback_data.error})
on_onnection_failure_future.set_result({'error': callback_data.error})

def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData):
pass
Expand All @@ -459,8 +462,77 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD
exception_occurred = True
self.assertTrue(exception_occurred, "Exception did not occur when connecting with invalid arguments!")

failureData = onConnectionFailureFuture.result(TIMEOUT)
self.assertTrue(failureData['error'] is not None)
failure_data = on_onnection_failure_future.result(TIMEOUT)
self.assertTrue(failure_data['error'] is not None)

def test_connect_disconnect_with_callbacks_happy_on_resume(self):
# Check that an on_connection_success callback fires on a resumed connection.

# NOTE Since there is no mocked server available on this abstraction level, the only sensible approach
# is to interrupt a connection, and wait for it to be resumed automatically. For that, another client
# with the same client_id connects to the server and then immediately disconnects.

test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key)
test_tls = ClientTlsContext(test_tls_opts)

on_connection_success_future = Future()
on_connection_closed_future = Future()
on_connection_resumed_future = Future()

def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData):
on_connection_success_future.set_result(
{'return_code': callback_data.return_code, "session_present": callback_data.session_present})

def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData):
on_connection_closed_future.set_result({})

def on_connection_resumed_callback(connection, return_code: ConnectReturnCode, session_present):
on_connection_resumed_future.set_result(
{'return_code': return_code, "session_present": session_present})

connection = self._create_connection(
endpoint=test_input_endpoint,
tls_context=test_tls,
on_connection_success_callback=on_connection_success_callback,
on_connection_closed_callback=on_connection_closed_callback,
on_connection_resumed_callback=on_connection_resumed_callback)
connection.connect().result(TIMEOUT)
success_data = on_connection_success_future.result(TIMEOUT)
self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)

# Reset the future for the reconnect attempt.
on_connection_success_future = Future()

on_connection_success_future_dup = Future()

def on_connection_success_callback_dup(connection, callback_data: OnConnectionSuccessData):
on_connection_success_future_dup.set_result({})

# Reuse the same client_id to displace the first client.
connection_dup = self._create_connection(
endpoint=test_input_endpoint,
tls_context=test_tls,
client_id=connection.client_id,
on_connection_success_callback=on_connection_success_callback_dup)

connection_dup.connect().result(TIMEOUT)
on_connection_success_future_dup.result(TIMEOUT)
connection_dup.disconnect().result(TIMEOUT)

# After the second client disconnects, the first one should reconnect,
# and on_connection_success callback should be fired once again.
on_connection_resumed_future.result(TIMEOUT)
success_data = on_connection_success_future.result(TIMEOUT)

self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)

connection.disconnect().result(TIMEOUT)
on_connection_closed_future.result(TIMEOUT)

# ==============================================================
# MOSQUITTO CONNECTION TESTS
Expand Down

0 comments on commit 18e9169

Please sign in to comment.