Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use on_connection_* callbacks from aws-c-mqtt #490

Merged
merged 6 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 64 additions & 32 deletions source/mqtt_client_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,63 @@ static void s_mqtt_python_connection_destructor(PyObject *connection_capsule) {
}
}

static void s_on_connection_success(
struct aws_mqtt_client_connection *connection,
enum aws_mqtt_connect_return_code return_code,
bool session_present,
void *user_data) {

if (connection == NULL || user_data == NULL) {
return; // The connection is dead - skip!
}

struct mqtt_connection_binding *py_connection = user_data;

PyGILState_STATE state;
if (aws_py_gilstate_ensure(&state)) {
return; /* Python has shut down. Nothing matters anymore, but don't crash */
}

PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */
if (self != Py_None) {
PyObject *success_result =
PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present));
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
}

PyGILState_Release(state);
}

static void s_on_connection_failure(struct aws_mqtt_client_connection *connection, int error_code, void *user_data) {

if (connection == NULL || user_data == NULL) {
return; // The connection is dead - skip!
}

struct mqtt_connection_binding *py_connection = user_data;

PyGILState_STATE state;
if (aws_py_gilstate_ensure(&state)) {
return; /* Python has shut down. Nothing matters anymore, but don't crash */
}

PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */
if (self != Py_None) {
PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code);
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
}

PyGILState_Release(state);
}

static void s_on_connection_interrupted(struct aws_mqtt_client_connection *connection, int error_code, void *userdata) {

if (connection == NULL || userdata == NULL) {
Expand Down Expand Up @@ -167,15 +224,6 @@ static void s_on_connection_resumed(
}
}

/* call _on_connection_success */
PyObject *success_result =
PyObject_CallMethod(self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present));
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}

PyGILState_Release(state);
}

Expand Down Expand Up @@ -250,6 +298,12 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
goto connection_new_failed;
}

if (aws_mqtt_client_connection_set_connection_result_handlers(
py_connection->native, s_on_connection_success, py_connection, s_on_connection_failure, py_connection)) {
PyErr_SetAwsLastError();
goto set_connection_handlers_failed;
}

if (aws_mqtt_client_connection_set_connection_interruption_handlers(
py_connection->native,
s_on_connection_interrupted,
Expand Down Expand Up @@ -306,6 +360,7 @@ PyObject *aws_py_mqtt_client_connection_new(PyObject *self, PyObject *args) {
proxy_new_failed:
use_websockets_failed:
set_interruption_failed:
set_connection_handlers_failed:
aws_mqtt_client_connection_release(py_connection->native);
connection_new_failed:
aws_mem_release(allocator, py_connection);
Expand Down Expand Up @@ -355,29 +410,6 @@ static void s_on_connect(
Py_XDECREF(callback);
}

/* Call on_connection_success or failure based on the result */
PyObject *self = PyWeakref_GetObject(py_connection->self_proxy); /* borrowed reference */
if (self != Py_None) {
/* Successful connection - call _on_connection_success */
if (error_code == AWS_ERROR_SUCCESS) {
PyObject *success_result = PyObject_CallMethod(
self, "_on_connection_success", "(iN)", return_code, PyBool_FromLong(session_present));
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
/* Unsuccessful connection - call _on_connection_failure */
} else {
PyObject *success_result = PyObject_CallMethod(self, "_on_connection_failure", "(i)", error_code);
if (success_result) {
Py_DECREF(success_result);
} else {
PyErr_WriteUnraisable(PyErr_Occurred());
}
}
}

PyGILState_Release(state);
}

Expand Down
102 changes: 87 additions & 15 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def _create_connection(
tls_context,
port=8883,
use_static_singletons=False,
client_id=None,
on_connection_success_callback=None,
on_connection_failure_callback=None,
on_connection_closed_callback=None):
on_connection_closed_callback=None,
on_connection_resumed_callback=None):
if use_static_singletons:
client = Client(tls_ctx=tls_context)
else:
Expand All @@ -48,12 +50,13 @@ def _create_connection(

connection = Connection(
client=client,
client_id=create_client_id(),
client_id=client_id if client_id else create_client_id(),
host_name=endpoint,
port=port,
on_connection_closed=on_connection_closed_callback,
on_connection_failure=on_connection_failure_callback,
on_connection_success=on_connection_success_callback)
on_connection_success=on_connection_success_callback,
on_connection_resumed=on_connection_resumed_callback)
return connection

def test_connect_disconnect(self):
Expand Down Expand Up @@ -400,18 +403,18 @@ def test_connect_disconnect_with_callbacks_happy(self):
test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key)
test_tls = ClientTlsContext(test_tls_opts)

onConnectionSuccessFuture = Future()
onConnectionClosedFuture = Future()
on_connection_success_future = Future()
on_connection_closed_future = Future()

def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData):
onConnectionSuccessFuture.set_result(
on_connection_success_future.set_result(
{'return_code': callback_data.return_code, "session_present": callback_data.session_present})

def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData):
pass

def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData):
onConnectionClosedFuture.set_result({})
on_connection_closed_future.set_result({})

connection = self._create_connection(
endpoint=test_input_endpoint,
Expand All @@ -420,11 +423,11 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD
on_connection_failure_callback=on_connection_failure_callback,
on_connection_closed_callback=on_connection_closed_callback)
connection.connect().result(TIMEOUT)
successData = onConnectionSuccessFuture.result(TIMEOUT)
self.assertEqual(successData['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(successData['session_present'], False)
success_data = on_connection_success_future.result(TIMEOUT)
self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)
connection.disconnect().result(TIMEOUT)
onConnectionClosedFuture.result(TIMEOUT)
on_connection_closed_future.result(TIMEOUT)

def test_connect_disconnect_with_callbacks_unhappy(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
Expand All @@ -433,13 +436,13 @@ def test_connect_disconnect_with_callbacks_unhappy(self):
test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key)
test_tls = ClientTlsContext(test_tls_opts)

onConnectionFailureFuture = Future()
on_onnection_failure_future = Future()

def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData):
pass

def on_connection_failure_callback(connection, callback_data: OnConnectionFailureData):
onConnectionFailureFuture.set_result({'error': callback_data.error})
on_onnection_failure_future.set_result({'error': callback_data.error})

def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData):
pass
Expand All @@ -459,8 +462,77 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD
exception_occurred = True
self.assertTrue(exception_occurred, "Exception did not occur when connecting with invalid arguments!")

failureData = onConnectionFailureFuture.result(TIMEOUT)
self.assertTrue(failureData['error'] is not None)
failure_data = on_onnection_failure_future.result(TIMEOUT)
self.assertTrue(failure_data['error'] is not None)

def test_connect_disconnect_with_callbacks_happy_on_resume(self):
# Check that an on_connection_success callback fires on a resumed connection.

# NOTE Since there is no mocked server available on this abstraction level, the only sensible approach
# is to interrupt a connection, and wait for it to be resumed automatically. For that, another client
# with the same client_id connects to the server and then immediately disconnects.

test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
test_tls_opts = TlsContextOptions.create_client_with_mtls_from_path(test_input_cert, test_input_key)
test_tls = ClientTlsContext(test_tls_opts)

on_connection_success_future = Future()
on_connection_closed_future = Future()
on_connection_resumed_future = Future()

def on_connection_success_callback(connection, callback_data: OnConnectionSuccessData):
on_connection_success_future.set_result(
{'return_code': callback_data.return_code, "session_present": callback_data.session_present})

def on_connection_closed_callback(connection, callback_data: OnConnectionClosedData):
on_connection_closed_future.set_result({})

def on_connection_resumed_callback(connection, return_code: ConnectReturnCode, session_present):
on_connection_resumed_future.set_result(
{'return_code': return_code, "session_present": session_present})

connection = self._create_connection(
endpoint=test_input_endpoint,
tls_context=test_tls,
on_connection_success_callback=on_connection_success_callback,
on_connection_closed_callback=on_connection_closed_callback,
on_connection_resumed_callback=on_connection_resumed_callback)
connection.connect().result(TIMEOUT)
success_data = on_connection_success_future.result(TIMEOUT)
self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)

# Reset the future for the reconnect attempt.
on_connection_success_future = Future()
sfod marked this conversation as resolved.
Show resolved Hide resolved

on_connection_success_future_dup = Future()

def on_connection_success_callback_dup(connection, callback_data: OnConnectionSuccessData):
on_connection_success_future_dup.set_result({})

# Reuse the same client_id to displace the first client.
connection_dup = self._create_connection(
endpoint=test_input_endpoint,
tls_context=test_tls,
client_id=connection.client_id,
on_connection_success_callback=on_connection_success_callback_dup)

connection_dup.connect().result(TIMEOUT)
on_connection_success_future_dup.result(TIMEOUT)
connection_dup.disconnect().result(TIMEOUT)

# After the second client disconnects, the first one should reconnect,
# and on_connection_success callback should be fired once again.
on_connection_resumed_future.result(TIMEOUT)
success_data = on_connection_success_future.result(TIMEOUT)

self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)

connection.disconnect().result(TIMEOUT)
on_connection_closed_future.result(TIMEOUT)

# ==============================================================
# MOSQUITTO CONNECTION TESTS
Expand Down
Loading