diff --git a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py index 34dca98c3..056996dfa 100644 --- a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -def _validate_kwargs(**kwargs): +def _validate_kwargs(exclude=[], **kwargs): """Helper function to validate user provided kwargs. Raises TypeError if an invalid option has been provided""" valid_kwargs = [ @@ -30,11 +30,29 @@ def _validate_kwargs(**kwargs): "cipher", "server_verification_cert", "proxy_options", + "sastoken_ttl", ] for kwarg in kwargs: - if kwarg not in valid_kwargs: - raise TypeError("Got an unexpected keyword argument '{}'".format(kwarg)) + if (kwarg not in valid_kwargs) or (kwarg in exclude): + raise TypeError("Unsupported keyword argument: '{}'".format(kwarg)) + + +def _get_config_kwargs(**kwargs): + """Get the subset of kwargs which pertain the config object""" + valid_config_kwargs = [ + "product_info", + "websockets", + "cipher", + "server_verification_cert", + "proxy_options", + ] + + config_kwargs = {} + for kwarg in kwargs: + if kwarg in valid_config_kwargs: + config_kwargs[kwarg] = kwargs[kwarg] + return config_kwargs def _form_sas_uri(hostname, device_id, module_id=None): @@ -80,9 +98,11 @@ def create_from_connection_string(cls, connection_string, **kwargs): arbitrary product info which is appended to the user agent string. :param proxy_options: Options for sending traffic through proxy servers. :type proxy_options: :class:`azure.iot.device.ProxyOptions` + :param int sastoken_ttl: The time to live (in seconds) for the created SasToken used for + authentication. Default is 3600 seconds (1 hour) :raises: ValueError if given an invalid connection_string. - :raises: TypeError if given an unrecognized parameter. + :raises: TypeError if given an unsupported parameter. :returns: An instance of an IoTHub client that uses a connection string for authentication. """ @@ -101,20 +121,22 @@ def create_from_connection_string(cls, connection_string, **kwargs): signing_mechanism = auth.SymmetricKeySigningMechanism( key=connection_string[cs.SHARED_ACCESS_KEY] ) + token_ttl = kwargs.get("sastoken_ttl", 3600) try: - sastoken = st.SasToken(uri, signing_mechanism) + sastoken = st.SasToken(uri, signing_mechanism, ttl=token_ttl) except st.SasTokenError as e: - new_err = ValueError("Could not create a SasToken using provided connection string") + new_err = ValueError("Could not create a SasToken using provided values") new_err.__cause__ = e raise new_err # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.IoTHubPipelineConfig( device_id=connection_string[cs.DEVICE_ID], module_id=connection_string.get(cs.MODULE_ID), hostname=connection_string[cs.HOST_NAME], gateway_hostname=connection_string.get(cs.GATEWAY_HOST_NAME), sastoken=sastoken, - **kwargs + **config_kwargs ) if cls.__name__ == "IoTHubDeviceClient": pipeline_configuration.blob_upload = True @@ -194,16 +216,18 @@ def create_from_x509_certificate(cls, x509, hostname, device_id, **kwargs): :param proxy_options: Options for sending traffic through proxy servers. :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :raises: TypeError if given an unrecognized parameter. + :raises: TypeError if given an unsupported parameter. :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. """ # Ensure no invalid kwargs were passed by the user - _validate_kwargs(**kwargs) + excluded_kwargs = ["sastoken_ttl"] + _validate_kwargs(exclude=excluded_kwargs, **kwargs) # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, hostname=hostname, x509=x509, **kwargs + device_id=device_id, hostname=hostname, x509=x509, **config_kwargs ) pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients @@ -235,8 +259,10 @@ def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs) arbitrary product info which is appended to the user agent string. :param proxy_options: Options for sending traffic through proxy servers. :type proxy_options: :class:`azure.iot.device.ProxyOptions` + :param int sastoken_ttl: The time to live (in seconds) for the created SasToken used for + authentication. Default is 3600 seconds (1 hour) - :raises: TypeError if given an unrecognized parameter. + :raises: TypeError if given an unsupported parameter. :return: An instance of an IoTHub client that uses a symmetric key for authentication. """ @@ -246,16 +272,18 @@ def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs) # Create SasToken uri = _form_sas_uri(hostname=hostname, device_id=device_id) signing_mechanism = auth.SymmetricKeySigningMechanism(key=symmetric_key) + token_ttl = kwargs.get("sastoken_ttl", 3600) try: - sastoken = st.SasToken(uri, signing_mechanism) + sastoken = st.SasToken(uri, signing_mechanism, ttl=token_ttl) except st.SasTokenError as e: new_err = ValueError("Could not create a SasToken using provided values") new_err.__cause__ = e raise new_err # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, hostname=hostname, sastoken=sastoken, **kwargs + device_id=device_id, hostname=hostname, sastoken=sastoken, **config_kwargs ) pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients @@ -297,19 +325,19 @@ def create_from_edge_environment(cls, **kwargs): arbitrary product info which is appended to the user agent string. :param proxy_options: Options for sending traffic through proxy servers. :type proxy_options: :class:`azure.iot.device.ProxyOptions` + :param int sastoken_ttl: The time to live (in seconds) for the created SasToken used for + authentication. Default is 3600 seconds (1 hour) :raises: OSError if the IoT Edge container is not configured correctly. :raises: ValueError if debug variables are invalid. + :raises: TypeError if given an unsupported parameter. :returns: An instance of an IoTHub client that uses the IoT Edge environment for authentication. """ # Ensure no invalid kwargs were passed by the user - _validate_kwargs(**kwargs) - if kwargs.get("server_verification_cert"): - raise TypeError( - "'server_verification_cert' is not supported by clients using an IoT Edge environment" - ) + excluded_kwargs = ["server_verification_cert"] + _validate_kwargs(exclude=excluded_kwargs, **kwargs) # First try the regular Edge container variables try: @@ -382,16 +410,18 @@ def create_from_edge_environment(cls, **kwargs): # Create SasToken uri = _form_sas_uri(hostname=hostname, device_id=device_id, module_id=module_id) + token_ttl = kwargs.get("sastoken_ttl", 3600) try: - sastoken = st.SasToken(uri, signing_mechanism) + sastoken = st.SasToken(uri, signing_mechanism, ttl=token_ttl) except st.SasTokenError as e: new_err = ValueError( - "Could not create a SasToken using the values in the Edge environment" + "Could not create a SasToken using the values provided, or in the Edge environment" ) new_err.__cause__ = e raise new_err # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.IoTHubPipelineConfig( device_id=device_id, module_id=module_id, @@ -399,7 +429,7 @@ def create_from_edge_environment(cls, **kwargs): gateway_hostname=gateway_hostname, sastoken=sastoken, server_verification_cert=server_verification_cert, - **kwargs + **config_kwargs ) pipeline_configuration.method_invoke = ( True @@ -439,16 +469,18 @@ def create_from_x509_certificate(cls, x509, hostname, device_id, module_id, **kw :param proxy_options: Options for sending traffic through proxy servers. :type proxy_options: :class:`azure.iot.device.ProxyOptions` - :raises: TypeError if given an unrecognized parameter. + :raises: TypeError if given an unsupported parameter. :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. """ # Ensure no invalid kwargs were passed by the user - _validate_kwargs(**kwargs) + excluded_kwargs = ["sastoken_ttl"] + _validate_kwargs(exclude=excluded_kwargs, **kwargs) # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.IoTHubPipelineConfig( - device_id=device_id, module_id=module_id, hostname=hostname, x509=x509, **kwargs + device_id=device_id, module_id=module_id, hostname=hostname, x509=x509, **config_kwargs ) # Pipeline setup diff --git a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py index e269745db..f106dbde7 100644 --- a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py @@ -19,15 +19,26 @@ logger = logging.getLogger(__name__) -def _validate_kwargs(**kwargs): +def _validate_kwargs(exclude=[], **kwargs): """Helper function to validate user provided kwargs. Raises TypeError if an invalid option has been provided""" # TODO: add support for server_verification_cert - valid_kwargs = ["websockets", "cipher", "proxy_options"] + valid_kwargs = ["websockets", "cipher", "proxy_options", "sastoken_ttl"] for kwarg in kwargs: - if kwarg not in valid_kwargs: - raise TypeError("Got an unexpected keyword argument '{}'".format(kwarg)) + if (kwarg not in valid_kwargs) or (kwarg in exclude): + raise TypeError("Unsupported keyword argument '{}'".format(kwarg)) + + +def _get_config_kwargs(**kwargs): + """Get the subset of kwargs which pertain the config object""" + valid_config_kwargs = ["websockets", "cipher", "proxy_options"] + + config_kwargs = {} + for kwarg in kwargs: + if kwarg in valid_config_kwargs: + config_kwargs[kwarg] = kwargs[kwarg] + return config_kwargs def _form_sas_uri(id_scope, registration_id): @@ -98,20 +109,22 @@ def create_from_symmetric_key( # Create SasToken uri = _form_sas_uri(id_scope=id_scope, registration_id=registration_id) signing_mechanism = auth.SymmetricKeySigningMechanism(key=symmetric_key) + token_ttl = kwargs.get("sastoken_ttl", 3600) try: - sastoken = st.SasToken(uri, signing_mechanism) + sastoken = st.SasToken(uri, signing_mechanism, ttl=token_ttl) except st.SasTokenError as e: new_err = ValueError("Could not create a SasToken using the provided values") new_err.__cause__ = e raise new_err # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.ProvisioningPipelineConfig( hostname=provisioning_host, registration_id=registration_id, id_scope=id_scope, sastoken=sastoken, - **kwargs + **config_kwargs ) # Pipeline setup @@ -154,15 +167,17 @@ def create_from_x509_certificate( :returns: A ProvisioningDeviceClient which can register via Symmetric Key. """ # Ensure no invalid kwargs were passed by the user - _validate_kwargs(**kwargs) + excluded_kwargs = ["sastoken_ttl"] + _validate_kwargs(exclude=excluded_kwargs, **kwargs) # Pipeline Config setup + config_kwargs = _get_config_kwargs(**kwargs) pipeline_configuration = pipeline.ProvisioningPipelineConfig( hostname=provisioning_host, registration_id=registration_id, id_scope=id_scope, x509=x509, - **kwargs + **config_kwargs ) # Pipeline setup diff --git a/azure-iot-device/tests/iothub/shared_client_tests.py b/azure-iot-device/tests/iothub/shared_client_tests.py index 66e92401a..d0c18ecb2 100644 --- a/azure-iot-device/tests/iothub/shared_client_tests.py +++ b/azure-iot-device/tests/iothub/shared_client_tests.py @@ -253,6 +253,39 @@ def test_sastoken(self, mocker, client_class, connection_string): sastoken_mock = mocker.patch.object(st, "SasToken") cs_obj = cs.ConnectionString(connection_string) + custom_ttl = 1000 + client_class.create_from_connection_string(connection_string, sastoken_ttl=custom_ttl) + + # Determine expected URI based on class under test + if client_class.__name__ == "IoTHubDeviceClient": + expected_uri = "{hostname}/devices/{device_id}".format( + hostname=cs_obj[cs.HOST_NAME], device_id=cs_obj[cs.DEVICE_ID] + ) + else: + expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=cs_obj[cs.HOST_NAME], + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj[cs.MODULE_ID], + ) + + # SymmetricKeySigningMechanism created using the connection string's SharedAccessKey + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) + + # Token was created with a SymmetricKeySigningMechanism, the expected URI, and custom ttl + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=custom_ttl + ) + + @pytest.mark.it( + "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" + ) + def test_sastoken_default(self, mocker, client_class, connection_string): + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + cs_obj = cs.ConnectionString(connection_string) + client_class.create_from_connection_string(connection_string) # Determine expected URI based on class under test @@ -271,9 +304,11 @@ def test_sastoken(self, mocker, client_class, connection_string): assert sksm_mock.call_count == 1 assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) - # Token was created with a SymmetricKeySigningMechanism and the expected URI + # Token was created with a SymmetricKeySigningMechanism, the expected URI, and default ttl assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=3600 + ) @pytest.mark.it( "Creates MQTT and HTTP Pipelines with an IoTHubPipelineConfig object containing the SasToken and values from the connection string" @@ -410,6 +445,34 @@ def test_sastoken(self, mocker, client_class): hostname=self.hostname, device_id=self.device_id ) + custom_ttl = 1000 + client_class.create_from_symmetric_key( + symmetric_key=self.symmetric_key, + hostname=self.hostname, + device_id=self.device_id, + sastoken_ttl=custom_ttl, + ) + + # SymmetricKeySigningMechanism created using the provided symmetric key + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=self.symmetric_key) + + # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the custom ttl + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=custom_ttl + ) + + @pytest.mark.it( + "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" + ) + def test_sastoken_default(self, mocker, client_class): + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + expected_uri = "{hostname}/devices/{device_id}".format( + hostname=self.hostname, device_id=self.device_id + ) + client_class.create_from_symmetric_key( symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id ) @@ -418,9 +481,11 @@ def test_sastoken(self, mocker, client_class): assert sksm_mock.call_count == 1 assert sksm_mock.call_args == mocker.call(key=self.symmetric_key) - # SasToken created with the SymmetricKeySigningMechanism and the expected URI + # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the default ttl assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=3600 + ) @pytest.mark.it( "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values provided in parameters" @@ -530,6 +595,13 @@ def test_client_returned( assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value assert client._http_pipeline is mock_http_pipeline_init.return_value + @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") + def test_sastoken_ttl(self, client_class, x509): + with pytest.raises(TypeError): + client_class.create_from_x509_certificate( + x509=x509, hostname=self.hostname, device_id=self.device_id, sastoken_ttl=1000 + ) + ############################## # SHARED MODULE CLIENT TESTS # @@ -592,6 +664,13 @@ def test_client_returned( assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value assert client._http_pipeline is mock_http_pipeline_init.return_value + @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") + def test_sastoken_ttl(self, client_class, x509): + with pytest.raises(TypeError): + client_class.create_from_x509_certificate( + x509=x509, hostname=self.hostname, device_id=self.device_id, sastoken_ttl=1000 + ) + @pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") class SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( @@ -674,7 +753,7 @@ def option_test_required_patching(self, mocker, mock_edge_hsm, edge_container_en mocker.patch.dict(os.environ, edge_container_environment, clear=True) @pytest.mark.it( - "Creates a SasToken that uses an IoTEdgeHsm, from the values extracted from the Edge environment" + "Creates a SasToken that uses an IoTEdgeHsm, from the values extracted from the Edge environment and the user-provided TTL" ) def test_sastoken(self, mocker, client_class, mock_edge_hsm, edge_container_environment): mocker.patch.dict(os.environ, edge_container_environment, clear=True) @@ -686,6 +765,39 @@ def test_sastoken(self, mocker, client_class, mock_edge_hsm, edge_container_envi module_id=edge_container_environment["IOTEDGE_MODULEID"], ) + custom_ttl = 1000 + client_class.create_from_edge_environment(sastoken_ttl=custom_ttl) + + # IoTEdgeHsm created using the extracted values + assert mock_edge_hsm.call_count == 1 + assert mock_edge_hsm.call_args == mocker.call( + module_id=edge_container_environment["IOTEDGE_MODULEID"], + generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], + workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], + api_version=edge_container_environment["IOTEDGE_APIVERSION"], + ) + + # SasToken created with the IoTEdgeHsm, the expected URI and the custom ttl + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call( + expected_uri, mock_edge_hsm.return_value, ttl=custom_ttl + ) + + @pytest.mark.it( + "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" + ) + def test_sastoken_default( + self, mocker, client_class, mock_edge_hsm, edge_container_environment + ): + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + + expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], + device_id=edge_container_environment["IOTEDGE_DEVICEID"], + module_id=edge_container_environment["IOTEDGE_MODULEID"], + ) + client_class.create_from_edge_environment() # IoTEdgeHsm created using the extracted values @@ -697,9 +809,11 @@ def test_sastoken(self, mocker, client_class, mock_edge_hsm, edge_container_envi api_version=edge_container_environment["IOTEDGE_APIVERSION"], ) - # SasToken created with the IoTEdgeHsm and the expected URI + # SasToken created with the IoTEdgeHsm, the expected URI, and the default ttl assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(expected_uri, mock_edge_hsm.return_value) + assert sastoken_mock.call_args == mocker.call( + expected_uri, mock_edge_hsm.return_value, ttl=3600 + ) @pytest.mark.it( "Uses an IoTEdgeHsm as the SasToken signing mechanism even if any Edge local debug environment variables may also be present" @@ -728,7 +842,9 @@ def test_hybrid_env( api_version=edge_container_environment["IOTEDGE_APIVERSION"], ) assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(mocker.ANY, mock_edge_hsm.return_value) + assert sastoken_mock.call_args == mocker.call( + mocker.ANY, mock_edge_hsm.return_value, ttl=3600 + ) @pytest.mark.it( "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the Edge environment" @@ -849,7 +965,7 @@ def mock_open(self, mocker): return mocker.patch.object(io, "open") @pytest.mark.it( - "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values in the connection string extracted from the Edge local debug environment" + "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values in the connection string extracted from the Edge local debug environment, as well as the user-provided TTL" ) def test_sastoken(self, mocker, client_class, mock_open, edge_local_debug_environment): mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) @@ -862,15 +978,44 @@ def test_sastoken(self, mocker, client_class, mock_open, edge_local_debug_enviro module_id=cs_obj[cs.MODULE_ID], ) + custom_ttl = 1000 + client_class.create_from_edge_environment(sastoken_ttl=custom_ttl) + + # SymmetricKeySigningMechanism created using the connection string's Shared Access Key + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) + + # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the custom ttl + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=custom_ttl + ) + + @pytest.mark.it( + "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" + ) + def test_sastoken_default(self, mocker, client_class, mock_open, edge_local_debug_environment): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + cs_obj = cs.ConnectionString(edge_local_debug_environment["EdgeHubConnectionString"]) + expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=cs_obj[cs.HOST_NAME], + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj[cs.MODULE_ID], + ) + client_class.create_from_edge_environment() # SymmetricKeySigningMechanism created using the connection string's Shared Access Key assert sksm_mock.call_count == 1 assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) - # SasToken created with the SymmetricKeySigningMechanism and the expected URI + # SasToken created with the SymmetricKeySigningMechanism, the expected URI and default ttl assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=3600 + ) @pytest.mark.it( "Only uses Edge local debug variables if no Edge container variables are present in the environment" @@ -903,7 +1048,9 @@ def test_auth_provider_and_pipeline_hybrid_env( api_version=edge_container_environment["IOTEDGE_APIVERSION"], ) assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(mocker.ANY, mock_edge_hsm.return_value) + assert sastoken_mock.call_args == mocker.call( + mocker.ANY, mock_edge_hsm.return_value, ttl=3600 + ) @pytest.mark.it( "Extracts the server verification certificate from the file indicated by the filepath extracted from the Edge local debug environment" diff --git a/azure-iot-device/tests/provisioning/shared_client_tests.py b/azure-iot-device/tests/provisioning/shared_client_tests.py index 437773935..f5530c3e9 100644 --- a/azure-iot-device/tests/provisioning/shared_client_tests.py +++ b/azure-iot-device/tests/provisioning/shared_client_tests.py @@ -132,20 +132,51 @@ def test_sastoken(self, mocker, client_class): id_scope=fake_id_scope, registration_id=fake_registration_id ) + custom_ttl = 1000 client_class.create_from_symmetric_key( provisioning_host=fake_provisioning_host, registration_id=fake_registration_id, id_scope=fake_id_scope, symmetric_key=fake_symmetric_key, + sastoken_ttl=custom_ttl, ) # SymmetricKeySigningMechanism created using the provided symmetric key assert sksm_mock.call_count == 1 assert sksm_mock.call_args == mocker.call(key=fake_symmetric_key) - # SasToken created with the SymmetricKeySigningMechanism and the expected URI + # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the custom ttl assert sastoken_mock.call_count == 1 - assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=custom_ttl + ) + + @pytest.mark.it( + "Uses 3600 seconds (1 hour) as the default SasToken TTL if no custom TTL is provided" + ) + def test_sastoken_default(self, mocker, client_class): + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + expected_uri = "{id_scope}/registrations/{registration_id}".format( + id_scope=fake_id_scope, registration_id=fake_registration_id + ) + + client_class.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + # SymmetricKeySigningMechanism created using the provided symmetric key + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=fake_symmetric_key) + + # SasToken created with the SymmetricKeySigningMechanism, the expected URI, and the default ttl + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call( + expected_uri, sksm_mock.return_value, ttl=3600 + ) @pytest.mark.it( "Creates an MQTT pipeline with a ProvisioningPipelineConfig object containing the SasToken and values provided in the parameters" @@ -249,3 +280,14 @@ def test_client_returned(self, mocker, client_class, x509, mock_pipeline_init): assert isinstance(client, client_class) assert client._pipeline is mock_pipeline_init.return_value + + @pytest.mark.it("Raises a TypeError if the 'sastoken_ttl' kwarg is supplied by the user") + def test_sastoken_ttl(self, client_class, x509): + with pytest.raises(TypeError): + client_class.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + sastoken_ttl=1000, + )