Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from awscrt import NativeResource
from awscrt._test import check_for_leaks
from awscrt.io import init_logging, LogLevel
import time
import unittest
import sys

TIMEOUT = 30.0

Expand Down Expand Up @@ -57,3 +56,23 @@ def tearDown(self):
except Exception:
NativeResourceTest._previous_test_failed = True
raise


MAX_RETRIES = 5


def _is_retryable_exception(e):
exception_text = str(e)
return "AWS_IO_TLS_NEGOTIATION_TIMEOUT" in exception_text or "AWS_IO_SOCKET_TIMEOUT" in exception_text


def test_retry_wrapper(test_function):
for i in range(MAX_RETRIES):
try:
test_function()
return
except Exception as e:
if _is_retryable_exception(e) and i + 1 < MAX_RETRIES:
time.sleep(1)
else:
raise
116 changes: 93 additions & 23 deletions test/test_mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, Pkcs11Lib, TlsContextOptions
from awscrt import http
from awscrt.mqtt import Client, Connection, QoS, Will, OnConnectionClosedData, OnConnectionFailureData, OnConnectionSuccessData, ConnectReturnCode
from test import NativeResourceTest
from test import test_retry_wrapper, NativeResourceTest
from concurrent.futures import Future
import os
import unittest
Expand Down Expand Up @@ -59,7 +59,7 @@ def _create_connection(
on_connection_resumed=on_connection_resumed_callback)
return connection

def test_connect_disconnect(self):
def _test_connect_disconnect(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand All @@ -71,7 +71,10 @@ def test_connect_disconnect(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_ecc_connect_disconnect(self):
def test_connect_disconnect(self):
test_retry_wrapper(self._test_connect_disconnect)

def _test_ecc_connect_disconnect(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_ECC_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_ECC_KEY")
Expand All @@ -83,7 +86,10 @@ def test_ecc_connect_disconnect(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_pkcs11(self):
def test_ecc_connect_disconnect(self):
test_retry_wrapper(self._test_ecc_connect_disconnect)

def _test_pkcs11(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_pkcs11_lib = _get_env_variable("AWS_TEST_PKCS11_LIB")
test_input_pkcs11_pin = _get_env_variable("AWS_TEST_PKCS11_PIN")
Expand All @@ -105,7 +111,10 @@ def test_pkcs11(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_pub_sub(self):
def test_pkcs11(self):
test_retry_wrapper(self._test_pkcs11)

def _test_pub_sub(self):
self.TEST_TOPIC = '/test/me/senpai/' + str(uuid.uuid4())
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
Expand Down Expand Up @@ -148,7 +157,10 @@ def on_message(**kwargs):
# disconnect
connection.disconnect().result(TIMEOUT)

def test_will(self):
def test_pub_sub(self):
test_retry_wrapper(self._test_pub_sub)

def _test_will(self):
self.TEST_TOPIC = '/test/me/senpai/' + str(uuid.uuid4())
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
Expand Down Expand Up @@ -239,7 +251,10 @@ def on_message(**kwargs):
# disconnect
subscriber.disconnect().result(TIMEOUT)

def test_on_message(self):
def test_will(self):
test_retry_wrapper(self._test_will)

def _test_on_message(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand Down Expand Up @@ -275,7 +290,10 @@ def on_message(**kwargs):
# disconnect
connection.disconnect().result(TIMEOUT)

def test_on_message_old_fn_signature(self):
def test_on_message(self):
test_retry_wrapper(self._test_on_message)

def _test_on_message_old_fn_signature(self):
# ensure that message-received callbacks with the old function signature still work

test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
Expand Down Expand Up @@ -320,7 +338,10 @@ def on_sub_message(topic, payload):
# disconnect
connection.disconnect().result(TIMEOUT)

def test_connect_disconnect_with_default_singletons(self):
def test_on_message_old_fn_signature(self):
test_retry_wrapper(self._test_on_message_old_fn_signature)

def _test_connect_disconnect_with_default_singletons(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand All @@ -337,7 +358,10 @@ def test_connect_disconnect_with_default_singletons(self):
EventLoopGroup.release_static_default()
DefaultHostResolver.release_static_default()

def test_connect_publish_wait_statistics_disconnect(self):
def test_connect_disconnect_with_default_singletons(self):
test_retry_wrapper(self._test_connect_disconnect_with_default_singletons)

def _test_connect_publish_wait_statistics_disconnect(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand Down Expand Up @@ -369,7 +393,10 @@ def test_connect_publish_wait_statistics_disconnect(self):
# disconnect
connection.disconnect().result(TIMEOUT)

def test_connect_publish_statistics_wait_disconnect(self):
def test_connect_publish_wait_statistics_disconnect(self):
test_retry_wrapper(self._test_connect_publish_wait_statistics_disconnect)

def _test_connect_publish_statistics_wait_disconnect(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand Down Expand Up @@ -409,7 +436,10 @@ def test_connect_publish_statistics_wait_disconnect(self):
# disconnect
connection.disconnect().result(TIMEOUT)

def test_connect_disconnect_with_callbacks_happy(self):
def test_connect_publish_statistics_wait_disconnect(self):
test_retry_wrapper(self._test_connect_publish_statistics_wait_disconnect)

def _test_connect_disconnect_with_callbacks_happy(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand Down Expand Up @@ -442,7 +472,10 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD
connection.disconnect().result(TIMEOUT)
on_connection_closed_future.result(TIMEOUT)

def test_connect_disconnect_with_callbacks_unhappy(self):
def test_connect_disconnect_with_callbacks_happy(self):
test_retry_wrapper(self._test_connect_disconnect_with_callbacks_happy)

def _test_connect_disconnect_with_callbacks_unhappy(self):
test_input_endpoint = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
test_input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
test_input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
Expand Down Expand Up @@ -478,7 +511,10 @@ def on_connection_closed_callback(connection, callback_data: OnConnectionClosedD
failure_data = on_onnection_failure_future.result(TIMEOUT)
self.assertTrue(failure_data['error'] is not None)

def test_connect_disconnect_with_callbacks_happy_on_resume(self):
def test_connect_disconnect_with_callbacks_unhappy(self):
test_retry_wrapper(self._test_connect_disconnect_with_callbacks_unhappy)

def _test_connect_disconnect_with_callbacks_happy_on_resume(self):
# Check that an on_connection_success callback fires on a resumed connection.

# NOTE Since there is no mocked server available on this abstraction level, the only sensible approach
Expand Down Expand Up @@ -517,6 +553,10 @@ def on_connection_resumed_callback(connection, return_code: ConnectReturnCode, s
self.assertEqual(success_data['return_code'], ConnectReturnCode.ACCEPTED)
self.assertEqual(success_data['session_present'], False)

# Putting a sleep here helps prevent a "race" condition in IoT Core where the second connection can get
# rejected rather than the first disconnected.
time.sleep(5)

# Reset the future for the reconnect attempt.
on_connection_success_future = Future()

Expand Down Expand Up @@ -547,11 +587,14 @@ def on_connection_success_callback_dup(connection, callback_data: OnConnectionSu
connection.disconnect().result(TIMEOUT)
on_connection_closed_future.result(TIMEOUT)

def test_connect_disconnect_with_callbacks_happy_on_resume(self):
test_retry_wrapper(self._test_connect_disconnect_with_callbacks_happy_on_resume)

# ==============================================================
# MOSQUITTO CONNECTION TESTS
# ==============================================================

def test_mqtt311_direct_connect_minimum(self):
def _test_mqtt311_direct_connect_minimum(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_PORT"))

Expand All @@ -567,7 +610,10 @@ def test_mqtt311_direct_connect_minimum(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_direct_connect_basic_auth(self):
def test_mqtt311_direct_connect_minimum(self):
test_retry_wrapper(self._test_mqtt311_direct_connect_minimum)

def _test_mqtt311_direct_connect_basic_auth(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_BASIC_AUTH_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_BASIC_AUTH_PORT"))
input_username = _get_env_variable("AWS_TEST_MQTT311_BASIC_AUTH_USERNAME")
Expand All @@ -587,7 +633,10 @@ def test_mqtt311_direct_connect_basic_auth(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_direct_connect_tls(self):
def test_mqtt311_direct_connect_basic_auth(self):
test_retry_wrapper(self._test_mqtt311_direct_connect_basic_auth)

def _test_mqtt311_direct_connect_tls(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_TLS_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_TLS_PORT"))

Expand All @@ -605,7 +654,10 @@ def test_mqtt311_direct_connect_tls(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_direct_connect_mutual_tls(self):
def test_mqtt311_direct_connect_tls(self):
test_retry_wrapper(self._test_mqtt311_direct_connect_tls)

def _test_mqtt311_direct_connect_mutual_tls(self):
input_cert = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_CERT")
input_key = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_RSA_KEY")
input_host = _get_env_variable("AWS_TEST_MQTT311_IOT_CORE_HOST")
Expand All @@ -626,7 +678,10 @@ def test_mqtt311_direct_connect_mutual_tls(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_direct_connect_http_proxy_tls(self):
def test_mqtt311_direct_connect_mutual_tls(self):
test_retry_wrapper(self._test_mqtt311_direct_connect_mutual_tls)

def _test_mqtt311_direct_connect_http_proxy_tls(self):
input_proxy_host = _get_env_variable("AWS_TEST_MQTT311_PROXY_HOST")
input_proxy_port = int(_get_env_variable("AWS_TEST_MQTT311_PROXY_PORT"))
input_host_name = _get_env_variable("AWS_TEST_MQTT311_DIRECT_MQTT_TLS_HOST")
Expand Down Expand Up @@ -655,7 +710,10 @@ def test_mqtt311_direct_connect_http_proxy_tls(self):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_websocket_connect_minimum(self):
def test_mqtt311_direct_connect_http_proxy_tls(self):
test_retry_wrapper(self._test_mqtt311_direct_connect_http_proxy_tls)

def _test_mqtt311_websocket_connect_minimum(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT311_WS_MQTT_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT311_WS_MQTT_PORT"))

Expand All @@ -677,7 +735,10 @@ def sign_function(transform_args, **kwargs):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_websocket_connect_basic_auth(self):
def test_mqtt311_websocket_connect_minimum(self):
test_retry_wrapper(self._test_mqtt311_websocket_connect_minimum)

def _test_mqtt311_websocket_connect_basic_auth(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT311_WS_MQTT_BASIC_AUTH_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT311_WS_MQTT_BASIC_AUTH_PORT"))
input_username = _get_env_variable("AWS_TEST_MQTT311_BASIC_AUTH_USERNAME")
Expand All @@ -703,7 +764,10 @@ def sign_function(transform_args, **kwargs):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_websocket_connect_tls(self):
def test_mqtt311_websocket_connect_basic_auth(self):
test_retry_wrapper(self._test_mqtt311_websocket_connect_basic_auth)

def _test_mqtt311_websocket_connect_tls(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT311_WS_MQTT_TLS_HOST")
input_port = int(_get_env_variable("AWS_TEST_MQTT311_WS_MQTT_TLS_PORT"))

Expand All @@ -727,7 +791,10 @@ def sign_function(transform_args, **kwargs):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_websocket_connect_http_proxy_tls(self):
def test_mqtt311_websocket_connect_tls(self):
test_retry_wrapper(self._test_mqtt311_websocket_connect_tls)

def _test_mqtt311_websocket_connect_http_proxy_tls(self):
input_proxy_host = _get_env_variable("AWS_TEST_MQTT311_PROXY_HOST")
input_proxy_port = int(_get_env_variable("AWS_TEST_MQTT311_PROXY_PORT"))
input_host_name = _get_env_variable("AWS_TEST_MQTT311_WS_MQTT_TLS_HOST")
Expand Down Expand Up @@ -761,6 +828,9 @@ def sign_function(transform_args, **kwargs):
connection.connect().result(TIMEOUT)
connection.disconnect().result(TIMEOUT)

def test_mqtt311_websocket_connect_http_proxy_tls(self):
test_retry_wrapper(self._test_mqtt311_websocket_connect_http_proxy_tls)


if __name__ == 'main':
unittest.main()
Loading